mirror of
https://github.com/Rudd-O/qubes-network-server.git
synced 2025-03-01 14:22:35 +01:00
380 lines
9.7 KiB
Python
380 lines
9.7 KiB
Python
#!/usr/bin/python3
|
|
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
|
|
from typing import TypedDict, Any, cast, Literal, Union
|
|
|
|
|
|
# nftables address-family keywords this module operates on: "ip" (IPv4) and
# "ip6" (IPv6).  ``Literal["ip", "ip6"]`` is the PEP 586 shorthand for the union
# of the two single-value literals.
ADDRESS_FAMILIES = Literal["ip", "ip6"]
|
|
|
|
|
|
class _ChainRequired(TypedDict):
    """Keys that ``nft -j`` reports for every chain (private base of ``Chain``)."""

    name: str  # chain name, e.g. "forward"
    family: str  # address-family keyword, e.g. "ip" or "ip6"
    table: str  # name of the table that owns the chain
    handle: int  # numeric handle reported by nft


class Chain(_ChainRequired, total=False):
    """A chain object, as wrapped under the ``"chain"`` key of ``nft -j`` output.

    The base-chain attributes declared here are optional.  nft only reports
    them for chains attached to a hook.  Regular chains omit them, such as the
    ones ``add_chain`` creates without any type or hook.  Marking them
    non-required keeps the type honest about what may actually be present.
    """

    type: str  # base-chain type, e.g. "filter"
    hook: str  # hook the base chain is attached to
    prio: int  # base-chain priority
    policy: str  # default verdict of the base chain
|
|
|
|
|
|
class Table(TypedDict):
    """A table object, as wrapped under the ``"table"`` key of ``nft -j`` output."""

    family: str  # address-family keyword, e.g. "ip" or "ip6"
    name: str  # table name, e.g. "qubes"
    handle: int  # numeric handle reported by nft
|
|
|
|
|
|
class Metainfo(TypedDict):
    """Producer metadata, as wrapped under the ``"metainfo"`` key of ``nft -j`` output."""

    version: str  # version string of the nft that produced the output
    release_name: str  # release name reported alongside the version
    json_schema_version: int  # version of the JSON output schema
|
|
|
|
|
|
class Rule(TypedDict):
    """A rule object, as wrapped under the ``"rule"`` key of ``nft -j`` output.

    ``expr`` lists the rule's expressions and statements in order.  Each is a
    mapping keyed by its kind, for example ``{"match": {...}}``,
    ``{"jump": {"target": ...}}`` or ``{"drop": None}``.
    """

    family: str  # address-family keyword, e.g. "ip" or "ip6"
    table: str  # name of the table holding the rule
    chain: str  # name of the chain holding the rule
    handle: int  # numeric handle; used to position new rules and to delete this one
    expr: list[dict[str, Any]]
|
|
|
|
|
|
class ChainContainer(TypedDict):
    """Element of the ``nftables`` list that wraps a single ``Chain``."""

    chain: Chain
|
|
|
|
|
|
class MetainfoContainer(TypedDict):
    """Element of the ``nftables`` list that wraps the ``Metainfo`` object."""

    metainfo: Metainfo
|
|
|
|
|
|
class TableContainer(TypedDict):
    """Element of the ``nftables`` list that wraps a single ``Table``."""

    table: Table
|
|
|
|
|
|
class RuleContainer(TypedDict):
    """Element of the ``nftables`` list that wraps a single ``Rule``."""

    rule: Rule
|
|
|
|
|
|
class NFTablesOutput(TypedDict):
    """Top-level JSON document printed by ``nft -j list ...``.

    ``nftables`` is a heterogeneous list.  Each element is a container whose
    single key (``"metainfo"``, ``"table"``, ``"chain"`` or ``"rule"``) names
    the kind of object it wraps.

    NOTE(review): real output may also carry other object kinds (e.g. sets);
    those are not modelled by this union.
    """

    nftables: list[
        ChainContainer | MetainfoContainer | TableContainer | RuleContainer
    ]
|
|
|
|
|
|
# nftables address-family keywords.
ADDRESS_FAMILY_IPV4: Literal["ip"] = "ip"
ADDRESS_FAMILY_IPV6: Literal["ip6"] = "ip6"

# Table and chain names.  "forward" and "postrouting" must already exist in the
# table; the two routing-manager chains are created on demand.
TABLE_NAME = "qubes"
FORWARD_CHAIN_NAME = "forward"
POSTROUTING_CHAIN_NAME = "postrouting"
ROUTING_MANAGER_CHAIN_NAME = "qubes-routing-manager"
ROUTING_MANAGER_POSTROUTING_CHAIN_NAME = "qubes-routing-manager-postrouting"

# Command used for every nftables invocation.
NFTABLES_CMD = "nft"
|
|
|
|
|
|
def get_table(address_family: ADDRESS_FAMILIES, table: str) -> NFTablesOutput:
    """Return the parsed ``nft -n -j list table`` listing of one table.

    The JSON printed by nft is decoded and *cast* to ``NFTablesOutput``.  No
    structural validation is performed.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
        json.JSONDecodeError: if nft's output is not valid JSON.
    """
    argv = [NFTABLES_CMD, "-n", "-j", "list", "table", address_family, table]
    raw_output = subprocess.check_output(argv, text=True)
    document = json.loads(raw_output)
    return cast(NFTablesOutput, document)
|
|
|
|
|
|
def add_chain(address_family: ADDRESS_FAMILIES, table: str, chain: str) -> None:
    """Add chain *chain* to *table* by running ``nft -n -j add chain``.

    No chain type, hook or policy is supplied.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
    """
    command = [NFTABLES_CMD, "-n", "-j", "add", "chain"]
    command += [address_family, table, chain]
    subprocess.check_output(command, text=True)
|
|
|
|
|
|
def append_rule_at_end(
    address_family: ADDRESS_FAMILIES, table: str, chain: str, *rest: str
) -> None:
    """Append a rule to the end of *chain* via ``nft -n -j add rule``.

    *rest* holds the rule's tokens.  They are passed to nft verbatim as
    separate arguments, e.g. ``"ip", "daddr", "10.137.0.10", "accept"``.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
    """
    subprocess.check_output(
        [NFTABLES_CMD, "-n", "-j", "add", "rule", address_family, table, chain, *rest],
        text=True,
    )
|
|
|
|
|
|
def append_counter_at_end(
    address_family: ADDRESS_FAMILIES, table: str, chain: str, *rest: str
) -> None:
    """Append a rule that starts with a ``counter`` statement to *chain*.

    Tokens in *rest* follow ``counter`` in the rule.  With none given, the rule
    is a bare counter.  The nft command line is identical to
    ``append_rule_at_end`` with ``"counter"`` prepended to *rest*.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
    """
    append_rule_at_end(address_family, table, chain, "counter", *rest)
|
|
|
|
|
|
def _append_or_insert_rule(
    where: Literal["add"] | Literal["insert"],
    address_family: ADDRESS_FAMILIES,
    table: str,
    chain: str,
    handle: int,
    *rest: str,
) -> None:
    """Place a rule next to the existing rule identified by *handle*.

    *where* is the nft verb used.  ``append_rule_after`` passes ``"add"`` and
    ``insert_rule_before`` passes ``"insert"``.  *rest* holds the new rule's
    tokens.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
    """
    argv = [NFTABLES_CMD, "-n", "-j", where, "rule"]
    argv.extend((address_family, table, chain))
    argv.extend(("position", str(handle)))
    argv.extend(rest)
    subprocess.check_output(argv, text=True)
|
|
|
|
|
|
def append_rule_after(
    address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int, *rest: str
) -> None:
    """Add the rule given by *rest* right after the rule with handle *handle*."""
    location = (address_family, table, chain, handle)
    _append_or_insert_rule("add", *location, *rest)
|
|
|
|
|
|
def insert_rule_before(
    address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int, *rest: str
) -> None:
    """Insert the rule given by *rest* right before the rule with handle *handle*."""
    location = (address_family, table, chain, handle)
    _append_or_insert_rule("insert", *location, *rest)
|
|
|
|
|
|
def delete_rule(
    address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int
) -> None:
    """Delete the rule with handle *handle* from *chain* via ``nft delete rule``.

    Raises:
        subprocess.CalledProcessError: if ``nft`` exits with a non-zero status.
    """
    target = [address_family, table, chain]
    selector = ["handle", str(handle)]
    subprocess.check_output(
        [NFTABLES_CMD, "-n", "-j", "delete", "rule", *target, *selector],
        text=True,
    )
|
|
|
|
|
|
def setup_plain_forwarding_for_address(source: str, enable: bool, family: int) -> None:
    """Enable or disable plain forwarding for *source* in the ``qubes`` table.

    The ``qubes`` table of the selected address family is listed once with
    ``nft -n -j`` and then brought into the desired state:

    1. Its ``forward`` and ``postrouting`` chains must already exist.  If either
       is missing, a warning is logged and nothing is changed.
    2. The ``qubes-routing-manager`` and ``qubes-routing-manager-postrouting``
       chains are created if absent, each with a counter rule.
    3. Unless already present, jumps are inserted into the parent chains:
       - A ``jump`` to ``qubes-routing-manager`` goes into ``forward``, just
         before the rule that matches ``oifgroup == 2`` and ends in ``drop``.
       - A ``jump`` to ``qubes-routing-manager-postrouting`` goes into
         ``postrouting``, just before the rule consisting solely of
         ``masquerade``.

       If such an anchor rule cannot be found, a warning is logged and the
       function returns without touching any per-address rules.
    4. With *enable* true, ``<af> daddr <source> accept`` is appended to
       ``qubes-routing-manager`` and ``<af> saddr <source> accept`` to
       ``qubes-routing-manager-postrouting``, each only when not already
       present.  With *enable* false, every such rule is deleted.

    Args:
        source: Address to forward.  Existing rules are recognised by exact
            string comparison with the address nft reports.
        enable: Whether the per-address accept rules should exist.
        family: ``6`` selects the ``ip6`` family; any other value selects ``ip``.

    Raises:
        subprocess.CalledProcessError: if any ``nft`` invocation exits non-zero.
    """
    logging.info("Handling forwarding for address %s family %s.", source, family)

    af = cast(
        ADDRESS_FAMILIES,
        ADDRESS_FAMILY_IPV6 if family == 6 else ADDRESS_FAMILY_IPV4,
    )

    # Example of the table being inspected:
    # table ip qubes {
    #     set downstream {
    #         type ipv4_addr
    #         elements = { 10.137.0.10, 10.250.4.13 }
    #     }
    #     ...
    # The table is listed only once.  Chains and rules added below are not
    # reflected in these snapshots.
    existing_table_items = get_table(af, TABLE_NAME)["nftables"]
    existing_chains = [x["chain"] for x in existing_table_items if "chain" in x]  # type: ignore
    existing_rules = [x["rule"] for x in existing_table_items if "rule" in x]  # type: ignore

    def find_chain(name: str) -> Chain | None:
        """Return the first listed chain called *name*, or ``None``."""
        return next((c for c in existing_chains if c["name"] == name), None)

    # The pre-existing forward/postrouting chains are required.  The jumps into
    # the routing-manager chains are anchored on rules inside them.
    if (
        find_chain(FORWARD_CHAIN_NAME) is None
        or find_chain(POSTROUTING_CHAIN_NAME) is None
    ):
        logging.warning(
            "No forward or postrouting chains in table %s, not setting up forwarding",
            TABLE_NAME,
        )
        return

    # Create the routing-manager chains on first use, each seeded with a counter.
    for chain_name in (ROUTING_MANAGER_CHAIN_NAME, ROUTING_MANAGER_POSTROUTING_CHAIN_NAME):
        if find_chain(chain_name) is None:
            logging.info(
                "Adding %s chain to table %s and counter to chain",
                chain_name,
                TABLE_NAME,
            )
            add_chain(af, TABLE_NAME, chain_name)
            append_counter_at_end(af, TABLE_NAME, chain_name)

    def is_oifgroup_2(rule: Rule) -> bool:
        """Match the 3-expression ``forward`` rule "oifgroup == 2 ... drop"."""
        if rule["chain"] != FORWARD_CHAIN_NAME or len(rule["expr"]) != 3:
            return False
        match = rule["expr"][0].get("match", {})
        return (
            match.get("op") == "=="
            and match.get("left", {}).get("meta", {}).get("key") == "oifgroup"
            and match.get("right") == 2
            # nft encodes the verdict as {"drop": None}.
            and rule["expr"][-1].get("drop", "not none") is None
        )

    def is_postrouting_masquerade(rule: Rule) -> bool:
        """Match a ``postrouting`` rule consisting solely of a masquerade statement."""
        return (
            rule["chain"] == POSTROUTING_CHAIN_NAME
            and len(rule["expr"]) == 1
            and "masquerade" in rule["expr"][0]
        )

    # Hook each routing-manager chain into its parent chain with a jump placed
    # immediately before the anchor rule, unless such a jump already exists.
    for parent_chain_name, child_chain_name, is_anchor_rule in (
        (FORWARD_CHAIN_NAME, ROUTING_MANAGER_CHAIN_NAME, is_oifgroup_2),
        (
            POSTROUTING_CHAIN_NAME,
            ROUTING_MANAGER_POSTROUTING_CHAIN_NAME,
            is_postrouting_masquerade,
        ),
    ):
        jump_exists = any(
            rule["chain"] == parent_chain_name
            and rule["family"] == af
            and len(rule["expr"]) == 1
            and rule["expr"][0].get("jump", {}).get("target") == child_chain_name
            for rule in existing_rules
        )
        if jump_exists:
            continue

        anchor_rule = next((r for r in existing_rules if is_anchor_rule(r)), None)
        if anchor_rule is None:
            logging.warning(
                "Cannot find appropriate previous rule in chain %s of table %s, not setting up forwarding",
                parent_chain_name,
                TABLE_NAME,
            )
            # Without an anchor there is no handle to position the jump against.
            # Stop here instead of inserting at an undefined or unrelated handle.
            return

        logging.info(
            "Adding rule to jump from chain %s to chain %s in table %s",
            parent_chain_name,
            child_chain_name,
            TABLE_NAME,
        )
        insert_rule_before(
            af,
            TABLE_NAME,
            parent_chain_name,
            anchor_rule["handle"],
            "jump",
            child_chain_name,
        )

    def detect_ip_rule(rule: Rule, chain_name: str, ip: str, mode: str) -> bool:
        """Match exactly ``<af> <mode> <ip> accept`` (two expressions) in *chain_name*."""
        if rule["chain"] != chain_name or len(rule["expr"]) != 2:
            return False
        match = rule["expr"][0].get("match", {})
        if match.get("op", {}) != "==":
            return False
        payload = match.get("left", {}).get("payload", {})
        return (
            payload.get("protocol", "") == af
            and payload.get("field", "") == mode
            and match.get("right", []) == ip
            and "accept" in rule["expr"][1]
        )

    # Add (enable) or remove (disable) the per-address accept rules.  The
    # destination match goes in the forward-side chain and the source match in
    # the postrouting-side chain.
    for chain_name, mode in (
        (ROUTING_MANAGER_CHAIN_NAME, "daddr"),
        (ROUTING_MANAGER_POSTROUTING_CHAIN_NAME, "saddr"),
    ):
        address_rules = [
            r for r in existing_rules if detect_ip_rule(r, chain_name, source, mode)
        ]

        if enable and not address_rules:
            logging.info(
                "Adding accept rule on chain %s for %s.",
                chain_name,
                source,
            )
            append_rule_at_end(af, TABLE_NAME, chain_name, af, mode, source, "accept")
        elif not enable and address_rules:
            logging.info(
                "Removing %s accept rules from chain %s for %s.",
                len(address_rules),
                chain_name,
                source,
            )
            for rule in sorted(address_rules, key=lambda r: r["handle"], reverse=True):
                delete_rule(af, TABLE_NAME, chain_name, rule["handle"])
|