2024-02-20 08:57:43 +00:00

372 lines
9.5 KiB
Python

#!/usr/bin/python3
import json
import logging
import subprocess
from typing import TypedDict, Any, cast, Literal
ADDRESS_FAMILIES = Literal["ip"] | Literal["ip6"]
class Chain(TypedDict):
name: str
family: str
table: str
handle: int
type: str
hook: str
prio: int
policy: str
class Table(TypedDict):
family: str
name: str
handle: int
class Metainfo(TypedDict):
version: str
release_name: str
json_schema_version: int
class Rule(TypedDict):
family: str
table: str
chain: str
handle: int
expr: list[dict[str, Any]]
class ChainContainer(TypedDict):
chain: Chain
class MetainfoContainer(TypedDict):
metainfo: Metainfo
class TableContainer(TypedDict):
table: Table
class RuleContainer(TypedDict):
rule: Rule
class NFTablesOutput(TypedDict):
nftables: list[ChainContainer | MetainfoContainer | TableContainer | RuleContainer]
ADDRESS_FAMILY_IPV6 = "ip6"
ADDRESS_FAMILY_IPV4 = "ip"
TABLE_NAME = "qubes"
FORWARD_CHAIN_NAME = "forward"
POSTROUTING_CHAIN_NAME = "postrouting"
ROUTING_MANAGER_CHAIN_NAME = "qubes-routing-manager"
ROUTING_MANAGER_POSTROUTING_CHAIN_NAME = "qubes-routing-manager-postrouting"
NFTABLES_CMD = "nft"
ADD_FORWARD_RULE_AFTER_THIS_RULE = "custom-forward"
def get_table(address_family: ADDRESS_FAMILIES, table: str) -> NFTablesOutput:
return cast(
NFTablesOutput,
json.loads(
subprocess.check_output(
[NFTABLES_CMD, "-n", "-j", "list", "table", address_family, table],
text=True,
)
),
)
def add_chain(address_family: ADDRESS_FAMILIES, table: str, chain: str) -> None:
subprocess.check_output(
[
NFTABLES_CMD,
"-n",
"-j",
"add",
"chain",
address_family,
table,
chain,
],
text=True,
)
def append_rule_at_end(
address_family: ADDRESS_FAMILIES, table: str, chain: str, *rest: str
) -> None:
subprocess.check_output(
[
NFTABLES_CMD,
"-n",
"-j",
"add",
"rule",
address_family,
table,
chain,
]
+ list(rest),
text=True,
)
def append_counter_at_end(
address_family: ADDRESS_FAMILIES, table: str, chain: str, *rest: str
) -> None:
subprocess.check_output(
[
NFTABLES_CMD,
"-n",
"-j",
"add",
"rule",
address_family,
table,
chain,
"counter",
]
+ list(rest),
text=True,
)
def _append_or_insert_rule(
where: Literal["add"] | Literal["insert"],
address_family: ADDRESS_FAMILIES,
table: str,
chain: str,
handle: int,
*rest: str,
) -> None:
subprocess.check_output(
[
NFTABLES_CMD,
"-n",
"-j",
where,
"rule",
address_family,
table,
chain,
"position",
str(handle),
]
+ list(rest),
text=True,
)
def append_rule_after(
address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int, *rest: str
) -> None:
_append_or_insert_rule("add", address_family, table, chain, handle, *rest)
def insert_rule_before(
address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int, *rest: str
) -> None:
_append_or_insert_rule("insert", address_family, table, chain, handle, *rest)
def delete_rule(
address_family: ADDRESS_FAMILIES, table: str, chain: str, handle: int
) -> None:
subprocess.check_output(
[
NFTABLES_CMD,
"-n",
"-j",
"delete",
"rule",
address_family,
table,
chain,
"handle",
str(handle),
],
text=True,
)
def setup_plain_forwarding_for_address(source: str, enable: bool, family: int) -> None:
logging.info("Handling forwarding for address %s family %s.", source, family)
af = cast(
ADDRESS_FAMILIES,
ADDRESS_FAMILY_IPV6 if family == 6 else ADDRESS_FAMILY_IPV4,
)
# table ip qubes {
# set downstream {
# type ipv4_addr
# elements = { 10.137.0.10, 10.250.4.13 }
# }
# ...
existing_table_output = get_table(af, TABLE_NAME)
existing_table_items = existing_table_output["nftables"]
existing_chains = [x["chain"] for x in existing_table_items if "chain" in x] # type: ignore
existing_rules = [x["rule"] for x in existing_table_items if "rule" in x] # type: ignore
try:
forward_chain = [x for x in existing_chains if x["name"] == FORWARD_CHAIN_NAME][
0
]
postrouting_chain = [
x for x in existing_chains if x["name"] == POSTROUTING_CHAIN_NAME
][0]
except IndexError:
logging.warn(
"No forward or postrouting chains in table %s, not setting up forwarding",
TABLE_NAME,
)
return
for chain_name in [
ROUTING_MANAGER_CHAIN_NAME,
ROUTING_MANAGER_POSTROUTING_CHAIN_NAME,
]:
chain: None | Chain = None
try:
chain = [x for x in existing_chains if x["name"] == chain_name].pop()
except IndexError:
pass
if not chain:
logging.info(
"Adding %s chain to table %s and counter to chain",
chain_name,
TABLE_NAME,
)
add_chain(af, TABLE_NAME, chain_name)
append_counter_at_end(
af,
TABLE_NAME,
chain_name,
)
def is_forward_jump_to_custom_forward(rule):
return (
rule["chain"] == forward_chain["name"]
and len(rule["expr"]) == 1
and rule["expr"][0].get("jump", {}).get("target")
== ADD_FORWARD_RULE_AFTER_THIS_RULE
)
def is_postrouting_masquerade(rule):
return (
rule["chain"] == postrouting_chain["name"]
and len(rule["expr"]) == 1
and "masquerade" in rule["expr"][0]
)
for parent_chain, child_chain_name, previous_rule_detector, insertor in [
(
forward_chain,
ROUTING_MANAGER_CHAIN_NAME,
is_forward_jump_to_custom_forward,
append_rule_after,
),
(
postrouting_chain,
ROUTING_MANAGER_POSTROUTING_CHAIN_NAME,
is_postrouting_masquerade,
insert_rule_before,
),
]:
jump_rule: None | Rule = None
try:
jump_rule = [
x
for x in existing_rules
if x["chain"] == parent_chain["name"]
and x["family"] == af
and len(x["expr"]) == 1
and x["expr"][0].get("jump", {}).get("target") == child_chain_name
].pop()
except IndexError:
pass
if not jump_rule:
try:
previous_rule = [
x for x in existing_rules if previous_rule_detector(x)
][0]
except IndexError:
logging.warn(
"Cannot find appropriate previous rule in chain %s of table %s, not setting up forwarding",
parent_chain["name"],
TABLE_NAME,
)
logging.info(
"Adding rule to jump from chain %s to chain %s in table %s",
parent_chain["name"],
child_chain_name,
TABLE_NAME,
)
insertor(
af,
TABLE_NAME,
parent_chain["name"],
previous_rule["handle"],
"jump",
child_chain_name,
)
def detect_ip_rule(rule: Rule, chain_name: str, ip: str, mode: str):
return (
rule["chain"] == chain_name
and len(rule["expr"]) == 2
and rule["expr"][0].get("match", {}).get("op", {}) == "=="
and rule["expr"][0]["match"]
.get("left", {})
.get("payload", {})
.get("protocol", "")
== af
and rule["expr"][0]["match"]["left"]["payload"].get("field", "") == mode
and rule["expr"][0].get("match", {}).get("right", []) == ip
and "accept" in rule["expr"][1]
)
for chain_name, mode in [
(ROUTING_MANAGER_CHAIN_NAME, "daddr"),
(ROUTING_MANAGER_POSTROUTING_CHAIN_NAME, "saddr"),
]:
address_rules = [
x for x in existing_rules if detect_ip_rule(x, chain_name, source, mode)
]
if enable and not address_rules:
logging.info(
"Adding accept rule on chain %s for %s.",
chain_name,
source,
)
append_rule_at_end(
af,
TABLE_NAME,
chain_name,
af,
mode,
source,
"accept",
)
elif not enable and address_rules:
logging.info(
"Removing %s accept rules from chain %s for %s.",
len(address_rules),
chain_name,
source,
)
for rule in reversed(sorted(address_rules, key=lambda r: r["handle"])):
delete_rule(af, TABLE_NAME, chain_name, rule["handle"])