From b25dd1e314a85da0085d542ab167f5d38b9e7435 Mon Sep 17 00:00:00 2001 From: Miguel Jacq Date: Thu, 14 May 2026 15:16:36 +1000 Subject: [PATCH 01/36] * Add support for capturing ipset and iptables configuration files * Add support for generating ipset and iptables configuration files from runtime, if the former weren't present (`firewall_runtime` role) * Dependency updates --- CHANGELOG.md | 11 ++ README.md | 4 + debian/changelog | 7 + enroll/explain.py | 22 ++- enroll/harvest.py | 288 +++++++++++++++++++++++++++++++- enroll/manifest.py | 192 +++++++++++++++++++++ enroll/schema/state.schema.json | 78 ++++++++- enroll/validate.py | 31 ++++ poetry.lock | 12 +- pyproject.toml | 2 +- rpm/enroll.spec | 5 +- tests/test_harvest_helpers.py | 118 +++++++++++++ tests/test_manifest.py | 97 +++++++++++ 13 files changed, 856 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b1428c..ef94a82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# 0.6.0 + + * Add support for capturing ipset and iptables configuration files + * Add support for generating ipset and iptables configuration files from runtime, if the former weren't present (`firewall_runtime` role) + * Dependency updates + +# 0.5.0 + + * Add support for templating `sshd_config`, if a compatible version of JinjaTurtle is also present. + * Dependency updates + # 0.4.4 * Update cryptography dependency diff --git a/README.md b/README.md index c2843fd..d2d51ad 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ - Defensively excludes likely secrets (path denylist + content sniff + size caps). - Captures non-system users and their SSH public keys and any .bashrc or .bash_aliases or .profile files that deviate from the skel defaults. - Captures miscellaneous `/etc` files it can't attribute to a package and installs them in an `etc_custom` role. +- Captures live ipset and iptables runtime state into a fallback `firewall_runtime` role, when active ipsets/iptables rules are present *and* no corresponding persistent ipset/iptables *files* were found. - Captures symlinks in common applications that rely on them, e.g apache2/nginx 'sites-enabled' - Ditto for /usr/local/bin (for non-binary files) and /usr/local/etc - Avoids trying to start systemd services that were detected as inactive during harvest. @@ -70,6 +71,8 @@ Harvest state about a host and write a harvest bundle. - Changed-from-default config (plus related custom/unowned files under service dirs) - Non-system users + SSH public keys - Misc `/etc` that can't be attributed to a package (`etc_custom` role) +- Static firewall config files such as nftables, UFW, firewalld, `/etc/iptables/rules.v4`, `/etc/iptables/rules.v6`, and `/etc/ipset*` +- Live kernel ipset/iptables state via `ipset save`, `iptables-save`, and `ip6tables-save` as a fallback, but only when the corresponding persistent config was not found (`firewall_runtime` role at manifest time) - Optional user-specified extra files/dirs via `--include-path` (emitted as an `extra_paths` role at manifest time) **Common flags** @@ -531,6 +534,7 @@ Roles collected - packages: 232 package snapshot(s), 41 file(s), 0 excluded - apt_config: 26 file(s), 7 dir(s), 10 excluded - dnf_config: 0 file(s), 0 dir(s), 0 excluded +- firewall_runtime: 2 snapshot(s), 1 ipset(s) - etc_custom: 70 file(s), 20 dir(s), 0 excluded - usr_local_custom: 35 file(s), 1 dir(s), 0 excluded - extra_paths: 0 file(s), 0 dir(s), 0 excluded diff --git a/debian/changelog b/debian/changelog index ee732b6..5292e0e 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +enroll (0.6.0) unstable; urgency=medium + + * Add support for capturing ipset and iptables configuration files + * Add support for generating ipset and iptables configuration files from runtime, if the former weren't present ('firewall_runtime' role) + + -- Miguel Jacq Thu, 14 May 2026 15:00 +1000 + enroll (0.5.0) unstable; urgency=medium * Add ssh config support where JinjaTurtle is used diff --git a/enroll/explain.py b/enroll/explain.py index 835f207..131f2df 100644 --- a/enroll/explain.py +++ b/enroll/explain.py @@ -72,7 +72,7 @@ _MANAGED_FILE_REASONS: Dict[str, ReasonInfo] = { ), "system_firewall": ReasonInfo( "Firewall configuration", - "Firewall rules/configuration (ufw, nftables, iptables, etc.).", + "Firewall rules/configuration (ufw, nftables, iptables, ipset, etc.).", ), "system_sysctl": ReasonInfo( "sysctl configuration", @@ -211,6 +211,10 @@ _OBSERVED_VIA: Dict[str, ReasonInfo] = { "Referenced by package role", "Package was referenced by an enroll packages snapshot/role.", ), + "firewall_runtime": ReasonInfo( + "Referenced by firewall runtime role", + "Package was referenced by captured live ipset/iptables runtime state.", + ), } @@ -359,6 +363,22 @@ def explain_state( } ) + # Runtime firewall snapshot + firewall_obj = roles.get("firewall_runtime") or {} + if isinstance(firewall_obj, dict) and firewall_obj: + captures = [ + key + for key in ("ipset_save", "iptables_v4_save", "iptables_v6_save") + if firewall_obj.get(key) + ] + role_summaries.append( + { + "role": "firewall_runtime", + "summary": f"{len(captures)} snapshot(s), {len(firewall_obj.get('ipset_sets') or [])} ipset(s)", + "notes": firewall_obj.get("notes") or [], + } + ) + # Single snapshots for rname in [ "apt_config", diff --git a/enroll/harvest.py b/enroll/harvest.py index ff62fb7..b64862e 100644 --- a/enroll/harvest.py +++ b/enroll/harvest.py @@ -5,10 +5,12 @@ import json import os import re import shutil +import shlex import stat +import subprocess # nosec import time from dataclasses import dataclass, asdict, field -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from .systemd import ( list_enabled_services, @@ -148,6 +150,17 @@ class ExtraPathsSnapshot: notes: List[str] = field(default_factory=list) +@dataclass +class FirewallRuntimeSnapshot: + role_name: str + packages: List[str] = field(default_factory=list) + ipset_save: Optional[str] = None + ipset_sets: List[str] = field(default_factory=list) + iptables_v4_save: Optional[str] = None + iptables_v6_save: Optional[str] = None + notes: List[str] = field(default_factory=list) + + ALLOWED_UNOWNED_EXTS = { ".cfg", ".cnf", @@ -653,6 +666,13 @@ _SYSTEM_CAPTURE_GLOBS: List[tuple[str, str]] = [ ("/etc/nftables.d/*", "system_firewall"), ("/etc/iptables/rules.v4", "system_firewall"), ("/etc/iptables/rules.v6", "system_firewall"), + ("/etc/sysconfig/iptables", "system_firewall"), + ("/etc/sysconfig/ip6tables", "system_firewall"), + ("/etc/ipset.conf", "system_firewall"), + ("/etc/ipset/*", "system_firewall"), + ("/etc/ipset.d/*", "system_firewall"), + ("/etc/sysconfig/ipset", "system_firewall"), + ("/etc/default/ipset", "system_firewall"), ("/etc/ufw/*", "system_firewall"), ("/etc/default/ufw", "system_firewall"), ("/etc/firewalld/*", "system_firewall"), @@ -664,6 +684,46 @@ _SYSTEM_CAPTURE_GLOBS: List[tuple[str, str]] = [ ] +# Persistent firewall files that are treated as authoritative for their +# respective runtime state. If any matching file exists, the runtime capture +# for that family is retained only as static managed-file harvest output and +# not duplicated through the generated firewall_runtime role. +_PERSISTENT_IPTABLES_V4_GLOBS = [ + "/etc/iptables/rules.v4", + "/etc/sysconfig/iptables", +] + +_PERSISTENT_IPTABLES_V6_GLOBS = [ + "/etc/iptables/rules.v6", + "/etc/sysconfig/ip6tables", +] + +_PERSISTENT_IPSET_GLOBS = [ + "/etc/ipset.conf", + "/etc/ipset/*", + "/etc/ipset.d/*", + "/etc/sysconfig/ipset", +] + + +def _persistent_firewall_files(globs: List[str]) -> List[str]: + """Return persistent firewall files matching ``globs``. + + This intentionally uses the same file walking helper as the static system + capture path so the runtime fallback decision matches what Enroll can + harvest as managed files. + """ + seen: Set[str] = set() + out: List[str] = [] + for spec in globs: + for path in _iter_matching_files(spec): + if path in seen: + continue + seen.add(path) + out.append(path) + return sorted(out) + + def _iter_matching_files(spec: str, *, cap: int = MAX_FILES_CAP) -> List[str]: """Expand a glob spec and also walk directories to collect files.""" out: List[str] = [] @@ -854,6 +914,200 @@ def _iter_system_capture_paths() -> List[tuple[str, str]]: return uniq +_FIREWALL_CAPTURE_COMMANDS: Dict[str, Tuple[str, ...]] = { + "ipset_save": ("ipset", "save"), + "iptables_v4_save": ("iptables-save",), + "iptables_v6_save": ("ip6tables-save",), +} + + +def _run_capture_command( + command_key: str, *, timeout: int = 10 +) -> tuple[Optional[str], Optional[str]]: + """Return (stdout, error_note) for an allowlisted local state command. + + The command key is resolved through ``_FIREWALL_CAPTURE_COMMANDS`` so this + helper never executes caller-supplied argv. Commands are run with + ``shell=False`` explicitly to avoid shell interpretation. + """ + argv = _FIREWALL_CAPTURE_COMMANDS.get(command_key) + if argv is None: + return None, f"Unknown capture command: {command_key}" + + exe = argv[0] + if shutil.which(exe) is None: + return None, f"{exe} not found on PATH." + + try: + proc = subprocess.run( # nosec + argv, + shell=False, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + ) + except Exception as e: # noqa: BLE001 + return None, f"{' '.join(argv)} failed: {e!r}" + + if proc.returncode != 0: + stderr = (proc.stderr or "").strip() + if len(stderr) > 300: + stderr = stderr[:297] + "..." + return ( + None, + f"{' '.join(argv)} exited {proc.returncode}: {stderr or '(no stderr)'}", + ) + + return proc.stdout or "", None + + +def _write_generated_artifact( + bundle_dir: str, role_name: str, src_rel: str, content: str +) -> None: + """Write a generated harvest artifact that did not exist as a file on disk.""" + dst = os.path.join(bundle_dir, "artifacts", role_name, src_rel) + os.makedirs(os.path.dirname(dst), exist_ok=True) + with open(dst, "w", encoding="utf-8") as f: + f.write(content) + + +def _ipset_save_has_state(text: str) -> bool: + for raw in (text or "").splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + if line.startswith(("create ", "add ")): + return True + return False + + +def _parse_ipset_set_names(text: str) -> List[str]: + names: List[str] = [] + seen: Set[str] = set() + for raw in (text or "").splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + try: + toks = shlex.split(line) + except ValueError: + toks = line.split() + if len(toks) >= 2 and toks[0] == "create" and toks[1] not in seen: + seen.add(toks[1]) + names.append(toks[1]) + return names + + +def _iptables_save_has_state(text: str) -> bool: + """Return True when iptables-save output contains non-default state.""" + for raw in (text or "").splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + if line.startswith("*") or line == "COMMIT": + continue + if line.startswith(":"): + parts = line.split() + chain_name = parts[0][1:] if parts else "" + policy = parts[1] if len(parts) >= 2 else "" + # Built-in empty chains usually look like ':INPUT ACCEPT [0:0]'. + # A changed policy, or any custom chain, is meaningful state. + if policy not in ("ACCEPT", "-"): + return True + if policy == "-" and chain_name: + return True + continue + if line.startswith(("-A ", "-I ", "-N ", "-P ", "-R ")): + return True + return False + + +def _collect_firewall_runtime_snapshot( + bundle_dir: str, + *, + persistent_ipset_files: Optional[List[str]] = None, + persistent_iptables_v4_files: Optional[List[str]] = None, + persistent_iptables_v6_files: Optional[List[str]] = None, +) -> FirewallRuntimeSnapshot: + """Capture live kernel firewall state only when no persistent config exists. + + Enroll also harvests persistent firewall files such as + /etc/iptables/rules.v4, /etc/iptables/rules.v6, and /etc/ipset.conf as + managed files. The generated runtime restore role is therefore a fallback: + it captures each firewall family only when that family has no persistent + file to avoid generating two roles that try to manage the same state. + """ + role_name = "firewall_runtime" + packages: Set[str] = set() + notes: List[str] = [] + ipset_save_rel: Optional[str] = None + ipset_sets: List[str] = [] + iptables_v4_rel: Optional[str] = None + iptables_v6_rel: Optional[str] = None + + persistent_ipset_files = persistent_ipset_files or [] + persistent_iptables_v4_files = persistent_iptables_v4_files or [] + persistent_iptables_v6_files = persistent_iptables_v6_files or [] + + if persistent_ipset_files: + notes.append( + "Live ipset runtime capture skipped because persistent ipset " + f"configuration was found: {', '.join(persistent_ipset_files)}" + ) + else: + ipset_out, ipset_err = _run_capture_command("ipset_save") + if ipset_err: + notes.append(ipset_err) + elif ipset_out is not None and _ipset_save_has_state(ipset_out): + ipset_save_rel = "firewall/ipset.save" + _write_generated_artifact(bundle_dir, role_name, ipset_save_rel, ipset_out) + ipset_sets = _parse_ipset_set_names(ipset_out) + packages.add("ipset") + + if persistent_iptables_v4_files: + notes.append( + "Live IPv4 iptables runtime capture skipped because persistent " + f"IPv4 iptables configuration was found: {', '.join(persistent_iptables_v4_files)}" + ) + else: + ipt4_out, ipt4_err = _run_capture_command("iptables_v4_save") + if ipt4_err: + notes.append(ipt4_err) + elif ipt4_out is not None and _iptables_save_has_state(ipt4_out): + iptables_v4_rel = "firewall/iptables.v4" + _write_generated_artifact(bundle_dir, role_name, iptables_v4_rel, ipt4_out) + packages.add("iptables") + + if persistent_iptables_v6_files: + notes.append( + "Live IPv6 iptables runtime capture skipped because persistent " + f"IPv6 iptables configuration was found: {', '.join(persistent_iptables_v6_files)}" + ) + else: + ipt6_out, ipt6_err = _run_capture_command("iptables_v6_save") + if ipt6_err: + notes.append(ipt6_err) + elif ipt6_out is not None and _iptables_save_has_state(ipt6_out): + iptables_v6_rel = "firewall/iptables.v6" + _write_generated_artifact(bundle_dir, role_name, iptables_v6_rel, ipt6_out) + packages.add("iptables") + + # Package names are intentionally added only when matching live state was + # captured. Merely having iptables/ipset installed should not create a role. + + return FirewallRuntimeSnapshot( + role_name=role_name, + packages=sorted(packages), + ipset_save=ipset_save_rel, + ipset_sets=ipset_sets, + iptables_v4_save=iptables_v4_rel, + iptables_v6_save=iptables_v6_rel, + notes=notes, + ) + + def harvest( bundle_dir: str, policy: Optional[IgnorePolicy] = None, @@ -907,6 +1161,29 @@ def harvest( installed_pkgs = backend.installed_packages() or {} installed_names: Set[str] = set(installed_pkgs.keys()) + persistent_ipset_files = _persistent_firewall_files(_PERSISTENT_IPSET_GLOBS) + persistent_iptables_v4_files = _persistent_firewall_files( + _PERSISTENT_IPTABLES_V4_GLOBS + ) + persistent_iptables_v6_files = _persistent_firewall_files( + _PERSISTENT_IPTABLES_V6_GLOBS + ) + + if hasattr(os, "geteuid") and os.geteuid() != 0: + firewall_runtime_snapshot = FirewallRuntimeSnapshot( + role_name="firewall_runtime", + notes=[ + "Live ipset/iptables runtime capture skipped because harvest is not running as root." + ], + ) + else: + firewall_runtime_snapshot = _collect_firewall_runtime_snapshot( + bundle_dir, + persistent_ipset_files=persistent_ipset_files, + persistent_iptables_v4_files=persistent_iptables_v4_files, + persistent_iptables_v6_files=persistent_iptables_v6_files, + ) + def _pick_installed(cands: List[str]) -> Optional[str]: for c in cands: if c in installed_names: @@ -2121,6 +2398,7 @@ def harvest( pkg_names |= manual_set pkg_names |= set(pkg_units.keys()) pkg_names |= {ps.package for ps in pkg_snaps} + pkg_names |= set(firewall_runtime_snapshot.packages or []) packages_inventory: Dict[str, Dict[str, object]] = {} for pkg in sorted(pkg_names): @@ -2136,6 +2414,13 @@ def harvest( observed.append({"kind": "systemd_unit", "ref": unit}) for rn in sorted(set(pkg_role_names.get(pkg, []))): observed.append({"kind": "package_role", "ref": rn}) + if pkg in set(firewall_runtime_snapshot.packages or []): + observed.append( + {"kind": "firewall_runtime", "ref": firewall_runtime_snapshot.role_name} + ) + pkg_roles_map.setdefault(pkg, set()).add( + firewall_runtime_snapshot.role_name + ) roles = sorted(pkg_roles_map.get(pkg, set())) @@ -2219,6 +2504,7 @@ def harvest( "packages": [asdict(p) for p in pkg_snaps], "apt_config": asdict(apt_config_snapshot), "dnf_config": asdict(dnf_config_snapshot), + "firewall_runtime": asdict(firewall_runtime_snapshot), "etc_custom": asdict(etc_custom_snapshot), "usr_local_custom": asdict(usr_local_custom_snapshot), "extra_paths": asdict(extra_paths_snapshot), diff --git a/enroll/manifest.py b/enroll/manifest.py index 0186621..99adbb7 100644 --- a/enroll/manifest.py +++ b/enroll/manifest.py @@ -582,6 +582,97 @@ def _render_install_packages_tasks(role: str, var_prefix: str) -> str: """ +def _render_firewall_runtime_tasks(var_prefix: str) -> str: + """Render tasks for live ipset/iptables snapshots.""" + return f"""- name: Ensure firewall runtime snapshot directory exists + ansible.builtin.file: + path: /etc/enroll/firewall + state: directory + owner: root + group: root + mode: "0750" + +- name: Deploy captured ipset snapshot + vars: + _enroll_ff: + files: + - "{{{{ inventory_dir }}}}/host_vars/{{{{ inventory_hostname }}}}/{{{{ role_name }}}}/.files/{{{{ {var_prefix}_ipset_save }}}}" + - "{{{{ role_path }}}}/files/{{{{ {var_prefix}_ipset_save }}}}" + ansible.builtin.copy: + src: "{{{{ lookup('ansible.builtin.first_found', _enroll_ff) }}}}" + dest: /etc/enroll/firewall/ipset.save + owner: root + group: root + mode: "0600" + when: ({var_prefix}_ipset_save | default('') | length) > 0 + +- name: Flush captured ipsets before restoring members + ansible.builtin.command: + cmd: "ipset flush {{{{ item }}}}" + loop: "{{{{ {var_prefix}_ipset_sets | default([]) }}}}" + register: _enroll_ipset_flush + failed_when: false + changed_when: false + when: + - ({var_prefix}_ipset_save | default('') | length) > 0 + - {var_prefix}_sync_ipsets_exact | default(true) | bool + +- name: Restore captured ipsets + ansible.builtin.shell: "ipset restore -exist < /etc/enroll/firewall/ipset.save" + args: + executable: /bin/sh + register: _enroll_ipset_restore + changed_when: _enroll_ipset_restore.rc == 0 + when: ({var_prefix}_ipset_save | default('') | length) > 0 + +- name: Deploy captured IPv4 iptables snapshot + vars: + _enroll_ff: + files: + - "{{{{ inventory_dir }}}}/host_vars/{{{{ inventory_hostname }}}}/{{{{ role_name }}}}/.files/{{{{ {var_prefix}_iptables_v4_save }}}}" + - "{{{{ role_path }}}}/files/{{{{ {var_prefix}_iptables_v4_save }}}}" + ansible.builtin.copy: + src: "{{{{ lookup('ansible.builtin.first_found', _enroll_ff) }}}}" + dest: /etc/enroll/firewall/iptables.v4 + owner: root + group: root + mode: "0600" + when: ({var_prefix}_iptables_v4_save | default('') | length) > 0 + +- name: Restore captured IPv4 iptables rules + ansible.builtin.command: + cmd: iptables-restore /etc/enroll/firewall/iptables.v4 + register: _enroll_iptables_v4_restore + changed_when: _enroll_iptables_v4_restore.rc == 0 + when: + - ({var_prefix}_iptables_v4_save | default('') | length) > 0 + - {var_prefix}_restore_iptables | default(true) | bool + +- name: Deploy captured IPv6 iptables snapshot + vars: + _enroll_ff: + files: + - "{{{{ inventory_dir }}}}/host_vars/{{{{ inventory_hostname }}}}/{{{{ role_name }}}}/.files/{{{{ {var_prefix}_iptables_v6_save }}}}" + - "{{{{ role_path }}}}/files/{{{{ {var_prefix}_iptables_v6_save }}}}" + ansible.builtin.copy: + src: "{{{{ lookup('ansible.builtin.first_found', _enroll_ff) }}}}" + dest: /etc/enroll/firewall/iptables.v6 + owner: root + group: root + mode: "0600" + when: ({var_prefix}_iptables_v6_save | default('') | length) > 0 + +- name: Restore captured IPv6 iptables rules + ansible.builtin.command: + cmd: ip6tables-restore /etc/enroll/firewall/iptables.v6 + register: _enroll_iptables_v6_restore + changed_when: _enroll_iptables_v6_restore.rc == 0 + when: + - ({var_prefix}_iptables_v6_save | default('') | length) > 0 + - {var_prefix}_restore_iptables | default(true) | bool +""" + + def _prepare_bundle_dir( bundle: str, *, @@ -746,6 +837,7 @@ def _manifest_from_bundle_dir( users_snapshot: Dict[str, Any] = roles.get("users", {}) apt_config_snapshot: Dict[str, Any] = roles.get("apt_config", {}) dnf_config_snapshot: Dict[str, Any] = roles.get("dnf_config", {}) + firewall_runtime_snapshot: Dict[str, Any] = roles.get("firewall_runtime", {}) etc_custom_snapshot: Dict[str, Any] = roles.get("etc_custom", {}) usr_local_custom_snapshot: Dict[str, Any] = roles.get("usr_local_custom", {}) extra_paths_snapshot: Dict[str, Any] = roles.get("extra_paths", {}) @@ -782,6 +874,7 @@ def _manifest_from_bundle_dir( manifested_users_roles: List[str] = [] manifested_apt_config_roles: List[str] = [] manifested_dnf_config_roles: List[str] = [] + manifested_firewall_runtime_roles: List[str] = [] manifested_etc_custom_roles: List[str] = [] manifested_usr_local_custom_roles: List[str] = [] manifested_extra_paths_roles: List[str] = [] @@ -1332,6 +1425,104 @@ DNF/YUM configuration harvested from the system (repos, config files, and RPM GP manifested_dnf_config_roles.append(role) + # ------------------------- + # firewall_runtime role (live ipset/iptables kernel state) + # ------------------------- + if firewall_runtime_snapshot and ( + firewall_runtime_snapshot.get("ipset_save") + or firewall_runtime_snapshot.get("iptables_v4_save") + or firewall_runtime_snapshot.get("iptables_v6_save") + ): + role = firewall_runtime_snapshot.get("role_name", "firewall_runtime") + role_dir = os.path.join(roles_root, role) + _write_role_scaffold(role_dir) + + var_prefix = role + packages = firewall_runtime_snapshot.get("packages", []) or [] + ipset_save = firewall_runtime_snapshot.get("ipset_save") or "" + ipset_sets = firewall_runtime_snapshot.get("ipset_sets", []) or [] + iptables_v4_save = firewall_runtime_snapshot.get("iptables_v4_save") or "" + iptables_v6_save = firewall_runtime_snapshot.get("iptables_v6_save") or "" + notes = firewall_runtime_snapshot.get("notes", []) or [] + + # Generated firewall snapshots are host-specific in site mode. + if site_mode: + _copy_artifacts( + bundle_dir, + role, + _host_role_files_dir(out_dir, fqdn or "", role), + ) + else: + _copy_artifacts(bundle_dir, role, os.path.join(role_dir, "files")) + + vars_map: Dict[str, Any] = { + f"{var_prefix}_packages": packages, + f"{var_prefix}_ipset_save": ipset_save, + f"{var_prefix}_ipset_sets": ipset_sets, + f"{var_prefix}_iptables_v4_save": iptables_v4_save, + f"{var_prefix}_iptables_v6_save": iptables_v6_save, + f"{var_prefix}_sync_ipsets_exact": True, + f"{var_prefix}_restore_iptables": True, + } + + if site_mode: + _write_role_defaults( + role_dir, + { + f"{var_prefix}_packages": [], + f"{var_prefix}_ipset_save": "", + f"{var_prefix}_ipset_sets": [], + f"{var_prefix}_iptables_v4_save": "", + f"{var_prefix}_iptables_v6_save": "", + f"{var_prefix}_sync_ipsets_exact": True, + f"{var_prefix}_restore_iptables": True, + }, + ) + _write_hostvars(out_dir, fqdn or "", role, vars_map) + else: + _write_role_defaults(role_dir, vars_map) + + tasks = ( + "---\n" + + _render_install_packages_tasks(role, var_prefix) + + _render_firewall_runtime_tasks(var_prefix) + ) + with open( + os.path.join(role_dir, "tasks", "main.yml"), "w", encoding="utf-8" + ) as f: + f.write(tasks.rstrip() + "\n") + + with open( + os.path.join(role_dir, "meta", "main.yml"), "w", encoding="utf-8" + ) as f: + f.write("---\ndependencies: []\n") + + readme = f"""# {role} + +Generated from live firewall runtime state captured during harvest. + +This role restores live ipset and iptables state only for firewall families where Enroll did not find corresponding persistent configuration on the source host. Static firewall configuration files, such as `/etc/iptables/rules.v4`, `/etc/iptables/rules.v6`, UFW, nftables, firewalld, or `/etc/ipset*`, are harvested separately as managed files and treated as authoritative for their respective family. + +## Captured snapshots +- ipset: {ipset_save or "(none)"} +- iptables IPv4: {iptables_v4_save or "(none)"} +- iptables IPv6: {iptables_v6_save or "(none)"} + +## Captured ipsets +{os.linesep.join("- " + x for x in ipset_sets) or "- (none)"} + +## Notes +{os.linesep.join("- " + n for n in notes) or "- (none)"} + +## Safety notes +- `firewall_runtime_sync_ipsets_exact` defaults to `true`; it flushes captured set members before replaying the saved members so stale entries are removed. This applies only when no persistent ipset config was found. +- `firewall_runtime_restore_iptables` defaults to `true`; `iptables-restore`/`ip6tables-restore` replace only captured families. A family is captured only when no corresponding persistent iptables config was found. +""" + with open(os.path.join(role_dir, "README.md"), "w", encoding="utf-8") as f: + f.write(readme) + + manifested_firewall_runtime_roles.append(role) + # ------------------------- # etc_custom role (unowned /etc not already attributed) # ------------------------- @@ -2012,6 +2203,7 @@ Generated for package `{pkg}`. + manifested_extra_paths_roles + manifested_users_roles + tail_roles + + manifested_firewall_runtime_roles ) if site_mode: diff --git a/enroll/schema/state.schema.json b/enroll/schema/state.schema.json index 083f90f..d0bde52 100644 --- a/enroll/schema/state.schema.json +++ b/enroll/schema/state.schema.json @@ -60,7 +60,7 @@ "enum": [ "user_excluded", "unreadable", - "backup_file", + "backup_file", "log_file", "denied_path", "too_large", @@ -315,6 +315,23 @@ "ref" ], "type": "object" + }, + { + "additionalProperties": false, + "properties": { + "kind": { + "const": "firewall_runtime" + }, + "ref": { + "minLength": 1, + "type": "string" + } + }, + "required": [ + "kind", + "ref" + ], + "type": "object" } ] }, @@ -579,6 +596,62 @@ } ], "unevaluatedProperties": false + }, + "FirewallRuntimeSnapshot": { + "additionalProperties": false, + "properties": { + "role_name": { + "const": "firewall_runtime" + }, + "packages": { + "items": { + "minLength": 1, + "type": "string" + }, + "type": "array" + }, + "ipset_save": { + "type": [ + "string", + "null" + ] + }, + "ipset_sets": { + "items": { + "minLength": 1, + "type": "string" + }, + "type": "array" + }, + "iptables_v4_save": { + "type": [ + "string", + "null" + ] + }, + "iptables_v6_save": { + "type": [ + "string", + "null" + ] + }, + "notes": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "required": [ + "role_name", + "packages", + "ipset_save", + "ipset_sets", + "iptables_v4_save", + "iptables_v6_save", + "notes" + ], + "type": "object" } }, "$id": "https://enroll.sh/schema/state.schema.json", @@ -686,6 +759,9 @@ }, "usr_local_custom": { "$ref": "#/$defs/UsrLocalCustomSnapshot" + }, + "firewall_runtime": { + "$ref": "#/$defs/FirewallRuntimeSnapshot" } }, "required": [ diff --git a/enroll/validate.py b/enroll/validate.py index 5a8fa88..f3291e9 100644 --- a/enroll/validate.py +++ b/enroll/validate.py @@ -197,6 +197,37 @@ def validate_harvest( f"artifact is not a file for role {role_name}: artifacts/{role_name}/{src_rel}" ) + # Runtime firewall snapshots are generated artifacts rather than managed files. + fw = (state.get("roles") or {}).get("firewall_runtime") or {} + if isinstance(fw, dict): + for key in ("ipset_save", "iptables_v4_save", "iptables_v6_save"): + src_rel = str(fw.get(key) or "") + if not src_rel: + continue + if src_rel.startswith("/") or ".." in src_rel.split("/"): + errors.append( + f"firewall_runtime {key} has suspicious src_rel: {src_rel!r}" + ) + continue + referenced.add( + (str(fw.get("role_name") or "firewall_runtime"), src_rel) + ) + p = ( + artifacts_dir + / str(fw.get("role_name") or "firewall_runtime") + / src_rel + ) + if not p.exists(): + errors.append( + "missing firewall runtime artifact: " + f"artifacts/{fw.get('role_name') or 'firewall_runtime'}/{src_rel}" + ) + elif not p.is_file(): + errors.append( + "firewall runtime artifact is not a file: " + f"artifacts/{fw.get('role_name') or 'firewall_runtime'}/{src_rel}" + ) + # Warn if there are extra files in artifacts not referenced. if artifacts_dir.exists() and artifacts_dir.is_dir(): for fp in artifacts_dir.rglob("*"): diff --git a/poetry.lock b/poetry.lock index c94436e..b338a10 100644 --- a/poetry.lock +++ b/poetry.lock @@ -562,13 +562,13 @@ test = ["pytest (>=6)"] [[package]] name = "idna" -version = "3.14" +version = "3.15" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.8" files = [ - {file = "idna-3.14-py3-none-any.whl", hash = "sha256:e677eaf072e290f7b725f9acf0b3a2bd55f9fd6f7c70abe5f0e34823d0accf69"}, - {file = "idna-3.14.tar.gz", hash = "sha256:466d810d7a2cc1022bea9b037c39728d51ae7dad40d480fc9b7d7ecf98ba8ee3"}, + {file = "idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8"}, + {file = "idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc"}, ] [package.extras] @@ -897,13 +897,13 @@ typing-extensions = {version = ">=4.4.0", markers = "python_version < \"3.13\""} [[package]] name = "requests" -version = "2.34.0" +version = "2.34.1" description = "Python HTTP for Humans." optional = false python-versions = ">=3.10" files = [ - {file = "requests-2.34.0-py3-none-any.whl", hash = "sha256:917520a21b767485ce7c588f4ebb917c436b24a31231b44228715eaeb5a52c60"}, - {file = "requests-2.34.0.tar.gz", hash = "sha256:7d62fe92f50eb82c529b0916bb445afa1531a566fc8f35ffdc64446e771b856a"}, + {file = "requests-2.34.1-py3-none-any.whl", hash = "sha256:bf38a3ff993960d3dd819c08862c40b3c703306eb7c744fcd9f4ddbb95b548f0"}, + {file = "requests-2.34.1.tar.gz", hash = "sha256:0fc5669f2b69704449fe1552360bd2a73a54512dfd03e65529157f1513322beb"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 4afda15..a7a83d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "enroll" -version = "0.5.0" +version = "0.6.0" description = "Enroll a server's running state retrospectively into Ansible" authors = ["Miguel Jacq "] license = "GPL-3.0-or-later" diff --git a/rpm/enroll.spec b/rpm/enroll.spec index 2980f32..0e83c84 100644 --- a/rpm/enroll.spec +++ b/rpm/enroll.spec @@ -1,4 +1,4 @@ -%global upstream_version 0.5.0 +%global upstream_version 0.6.0 Name: enroll Version: %{upstream_version} @@ -43,6 +43,9 @@ Enroll a server's running state retrospectively into Ansible. %{_bindir}/enroll %changelog +* Thu May 14 2026 Miguel Jacq - %{version}-%{release} +- Add support for capturing ipset and iptables configuration files +- Add support for generating ipset and iptables configuration files from runtime, if the former weren't present ('firewall_runtime' role) * Tue May 12 2026 Miguel Jacq - %{version}-%{release} - Add ssh config support where JinjaTurtle is used * Tue Feb 16 2026 Miguel Jacq - %{version}-%{release} diff --git a/tests/test_harvest_helpers.py b/tests/test_harvest_helpers.py index 531a62c..a0d2c91 100644 --- a/tests/test_harvest_helpers.py +++ b/tests/test_harvest_helpers.py @@ -168,3 +168,121 @@ def test_iter_system_capture_paths_dedupes_first_reason(monkeypatch): ) out = h._iter_system_capture_paths() assert out == [("/dup", "r1")] + + +def test_ipset_and_iptables_state_helpers(tmp_path: Path): + ipset_save = """create blocklist hash:ip family inet hashsize 1024 maxelem 65536 +add blocklist 203.0.113.10 +create nets hash:net family inet +""" + assert h._ipset_save_has_state(ipset_save) + assert h._parse_ipset_set_names(ipset_save) == ["blocklist", "nets"] + assert not h._ipset_save_has_state("# empty\n") + + empty_iptables = """*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +COMMIT +""" + assert not h._iptables_save_has_state(empty_iptables) + + native_rule = empty_iptables.replace( + "COMMIT", "-A INPUT -p tcp --dport 22 -j ACCEPT\nCOMMIT" + ) + assert h._iptables_save_has_state(native_rule) + + changed_policy = empty_iptables.replace(":INPUT ACCEPT", ":INPUT DROP") + assert h._iptables_save_has_state(changed_policy) + + +def test_collect_firewall_runtime_snapshot_writes_generated_artifacts( + monkeypatch, tmp_path: Path +): + outputs = { + "ipset_save": ( + "create blocklist hash:ip family inet\nadd blocklist 203.0.113.10\n", + None, + ), + "iptables_v4_save": ( + "*filter\n:INPUT DROP [0:0]\n-A INPUT -m set --match-set blocklist src -j DROP\nCOMMIT\n", + None, + ), + "iptables_v6_save": ("*filter\n:INPUT ACCEPT [0:0]\nCOMMIT\n", None), + } + + def fake_run(command_key, *, timeout=10): + return outputs[command_key] + + monkeypatch.setattr(h, "_run_capture_command", fake_run) + + snap = h._collect_firewall_runtime_snapshot(str(tmp_path)) + assert snap.role_name == "firewall_runtime" + assert snap.packages == ["ipset", "iptables"] + assert snap.ipset_save == "firewall/ipset.save" + assert snap.ipset_sets == ["blocklist"] + assert snap.iptables_v4_save == "firewall/iptables.v4" + assert snap.iptables_v6_save is None + + assert ( + (tmp_path / "artifacts" / "firewall_runtime" / "firewall" / "ipset.save") + .read_text(encoding="utf-8") + .startswith("create blocklist") + ) + assert ( + (tmp_path / "artifacts" / "firewall_runtime" / "firewall" / "iptables.v4") + .read_text(encoding="utf-8") + .startswith("*filter") + ) + + +def test_collect_firewall_runtime_snapshot_is_per_family_fallback( + monkeypatch, tmp_path: Path +): + calls = [] + outputs = { + "ipset_save": ( + "create blocklist hash:ip family inet\nadd blocklist 203.0.113.10\n", + None, + ), + "iptables_v4_save": ( + "*filter\n:INPUT DROP [0:0]\n-A INPUT -p tcp --dport 22 -j ACCEPT\nCOMMIT\n", + None, + ), + "iptables_v6_save": ( + "*filter\n:INPUT DROP [0:0]\n-A INPUT -p tcp --dport 22 -j ACCEPT\nCOMMIT\n", + None, + ), + } + + def fake_run(command_key, *, timeout=10): + calls.append(command_key) + return outputs[command_key] + + monkeypatch.setattr(h, "_run_capture_command", fake_run) + + snap = h._collect_firewall_runtime_snapshot( + str(tmp_path), + persistent_ipset_files=["/etc/ipset.conf"], + persistent_iptables_v4_files=["/etc/iptables/rules.v4"], + persistent_iptables_v6_files=[], + ) + + assert "ipset_save" not in calls + assert "iptables_v4_save" not in calls + assert "iptables_v6_save" in calls + assert snap.ipset_save is None + assert snap.iptables_v4_save is None + assert snap.iptables_v6_save == "firewall/iptables.v6" + assert snap.packages == ["iptables"] + assert any("persistent ipset configuration" in note for note in snap.notes) + assert any("persistent IPv4 iptables configuration" in note for note in snap.notes) + assert not ( + tmp_path / "artifacts" / "firewall_runtime" / "firewall" / "ipset.save" + ).exists() + assert not ( + tmp_path / "artifacts" / "firewall_runtime" / "firewall" / "iptables.v4" + ).exists() + assert ( + tmp_path / "artifacts" / "firewall_runtime" / "firewall" / "iptables.v6" + ).exists() diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 073fd6d..658d77f 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -795,3 +795,100 @@ def test_manifest_applies_jinjaturtle_to_jinjifyable_managed_file( assert not ( out_dir / "roles" / "apt_config" / "files" / "etc" / "apt" / "foo.ini" ).exists() + + +def test_manifest_writes_firewall_runtime_role(tmp_path: Path): + bundle = tmp_path / "bundle" + out = tmp_path / "ansible" + (bundle / "artifacts" / "firewall_runtime" / "firewall").mkdir( + parents=True, exist_ok=True + ) + (bundle / "artifacts" / "firewall_runtime" / "firewall" / "ipset.save").write_text( + "create blocklist hash:ip family inet\nadd blocklist 203.0.113.10\n", + encoding="utf-8", + ) + (bundle / "artifacts" / "firewall_runtime" / "firewall" / "iptables.v4").write_text( + "*filter\n:INPUT DROP [0:0]\n-A INPUT -m set --match-set blocklist src -j DROP\nCOMMIT\n", + encoding="utf-8", + ) + + state = { + "schema_version": 3, + "host": {"hostname": "test", "os": "debian", "pkg_backend": "dpkg"}, + "inventory": {"packages": {}}, + "roles": { + "users": { + "role_name": "users", + "users": [], + "managed_files": [], + "excluded": [], + "notes": [], + }, + "services": [], + "packages": [], + "apt_config": { + "role_name": "apt_config", + "managed_files": [], + "excluded": [], + "notes": [], + }, + "dnf_config": { + "role_name": "dnf_config", + "managed_files": [], + "excluded": [], + "notes": [], + }, + "firewall_runtime": { + "role_name": "firewall_runtime", + "packages": ["ipset", "iptables"], + "ipset_save": "firewall/ipset.save", + "ipset_sets": ["blocklist"], + "iptables_v4_save": "firewall/iptables.v4", + "iptables_v6_save": None, + "notes": [], + }, + "etc_custom": { + "role_name": "etc_custom", + "managed_files": [], + "excluded": [], + "notes": [], + }, + "usr_local_custom": { + "role_name": "usr_local_custom", + "managed_files": [], + "excluded": [], + "notes": [], + }, + "extra_paths": { + "role_name": "extra_paths", + "include_patterns": [], + "exclude_patterns": [], + "managed_files": [], + "excluded": [], + "notes": [], + }, + }, + } + (bundle / "state.json").write_text(json.dumps(state, indent=2), encoding="utf-8") + + manifest.manifest(str(bundle), str(out)) + + tasks = (out / "roles" / "firewall_runtime" / "tasks" / "main.yml").read_text( + encoding="utf-8" + ) + assert "ipset restore -exist" in tasks + assert "iptables-restore /etc/enroll/firewall/iptables.v4" in tasks + assert "ipset flush {{ item }}" in tasks + + defaults = (out / "roles" / "firewall_runtime" / "defaults" / "main.yml").read_text( + encoding="utf-8" + ) + assert "firewall_runtime_ipset_sets:" in defaults + assert "- blocklist" in defaults + assert "firewall_runtime_restore_iptables: true" in defaults + + pb = (out / "playbook.yml").read_text(encoding="utf-8") + assert "role: firewall_runtime" in pb + assert ( + out / "roles" / "firewall_runtime" / "files" / "firewall" / "ipset.save" + ).exists() From 1544dc0295c19aa870c023daac4b6f0631abdcfe Mon Sep 17 00:00:00 2001 From: Miguel Jacq Date: Sun, 31 May 2026 16:50:57 +1000 Subject: [PATCH 02/36] more test coverage --- .gitignore | 1 + tests/test_accounts.py | 144 ++++++ tests/test_debian.py | 241 +++++++++ tests/test_diff_bundle.py | 992 ++++++++++++++++++++++++++++++++++++ tests/test_harvest.py | 147 ++++++ tests/test_ignore.py | 240 +++++++++ tests/test_manifest.py | 172 +++++++ tests/test_misc_coverage.py | 416 --------------- tests/test_pathfilter.py | 154 ++++++ tests/test_platform.py | 173 +++++++ tests/test_remote.py | 449 ++++++++++++++++ tests/test_rpm.py | 31 ++ tests/test_sopsutil.py | 54 ++ tests/test_systemd.py | 129 ++++- tests/test_validate.py | 231 +++++++++ 15 files changed, 3150 insertions(+), 424 deletions(-) delete mode 100644 tests/test_misc_coverage.py create mode 100644 tests/test_sopsutil.py diff --git a/.gitignore b/.gitignore index 07c956d..4ef962d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ dist *.pdf *.csv *.html +coverage.xml diff --git a/tests/test_accounts.py b/tests/test_accounts.py index d5cc267..9e60b57 100644 --- a/tests/test_accounts.py +++ b/tests/test_accounts.py @@ -141,3 +141,147 @@ def test_collect_non_system_users(monkeypatch, tmp_path: Path): assert u.primary_group == "users" assert u.supplementary_groups == ["admins"] assert u.ssh_files == ["/home/alice/.ssh/authorized_keys"] + + +def test_parse_login_defs_file_not_found(tmp_path: Path): + from enroll.accounts import parse_login_defs + + nonexistent = tmp_path / "nonexistent" / "login.defs" + vals = parse_login_defs(str(nonexistent)) + assert vals == {} + + +def test_parse_login_defs_handles_invalid_numbers(tmp_path: Path): + from enroll.accounts import parse_login_defs + + p = tmp_path / "login.defs" + p.write_text("UID_MIN not_a_number\nUID_MAX 60000\n", encoding="utf-8") + vals = parse_login_defs(str(p)) + assert "UID_MIN" not in vals + assert vals["UID_MAX"] == 60000 + + +def test_parse_group_handles_invalid_gid(tmp_path: Path): + from enroll.accounts import parse_group + + p = tmp_path / "group" + p.write_text( + "valid:x:1000:user1\n" "invalid_gid:x:notanint:user2\n", + encoding="utf-8", + ) + gid_to_name, name_to_gid, members = parse_group(str(p)) + assert 1000 in gid_to_name + assert gid_to_name[1000] == "valid" + assert "invalid_gid" not in name_to_gid + + +def test_parse_group_line_too_short(tmp_path: Path): + from enroll.accounts import parse_group + + p = tmp_path / "group" + p.write_text( + "valid:x:1000:user1\n" "shortline:x:1001\n", + encoding="utf-8", + ) + gid_to_name, name_to_gid, members = parse_group(str(p)) + assert 1000 in gid_to_name + assert 1001 not in gid_to_name + + +def test_is_human_user_filters_by_uid_and_shell(): + from enroll.accounts import is_human_user + + assert is_human_user(1000, "/bin/bash", 1000) is True + assert is_human_user(999, "/bin/bash", 1000) is False + assert is_human_user(1000, "/usr/sbin/nologin", 1000) is False + assert is_human_user(1000, "/usr/bin/nologin", 1000) is False + assert is_human_user(1000, "/bin/false", 1000) is False + assert is_human_user(1000, "", 1000) is True + + +def test_find_user_ssh_files_no_ssh_dir(tmp_path: Path): + from enroll.accounts import find_user_ssh_files + + home = tmp_path / "home" / "user" + home.mkdir(parents=True) + assert find_user_ssh_files(str(home)) == [] + + +def test_find_user_ssh_files_ignores_symlink(tmp_path: Path): + from enroll.accounts import find_user_ssh_files + + home = tmp_path / "home" / "user" + sshdir = home / ".ssh" + sshdir.mkdir(parents=True) + target = sshdir / "real_file" + target.write_text("x", encoding="utf-8") + os.symlink(str(target), str(sshdir / "authorized_keys")) + + result = find_user_ssh_files(str(home)) + assert result == [] + + +def test_find_user_ssh_files_handles_home_not_starting_with_slash(): + from enroll.accounts import find_user_ssh_files + + assert find_user_ssh_files("relative/path") == [] + assert find_user_ssh_files("") == [] + + +def test_collect_non_system_users_skips_nologin_users(tmp_path: Path): + import enroll.accounts as a + + orig_parse_login_defs = a.parse_login_defs + orig_parse_passwd = a.parse_passwd + orig_parse_group = a.parse_group + + passwd = tmp_path / "passwd" + passwd.write_text( + "root:x:0:0:root:/root:/bin/bash\n" + "alice:x:1000:1000:Alice:/home/alice:/bin/bash\n" + "nobody:x:65534:65534:nobody:/nonexistent:/usr/sbin/nologin\n" + "sysuser:x:100:100:Sys:/home/sys:/bin/bash\n", + encoding="utf-8", + ) + group = tmp_path / "group" + group.write_text("users:x:1000:alice\n", encoding="utf-8") + defs = tmp_path / "login.defs" + defs.write_text("UID_MIN 1000\n", encoding="utf-8") + + monkeypatch_wrapper = lambda fn, p: lambda path=str(p): fn(path) + + a.parse_login_defs = monkeypatch_wrapper(orig_parse_login_defs, defs) + a.parse_passwd = monkeypatch_wrapper(orig_parse_passwd, passwd) + a.parse_group = monkeypatch_wrapper(orig_parse_group, group) + a.find_user_ssh_files = lambda home: [] + + users = a.collect_non_system_users() + assert [u.name for u in users] == ["alice"] + + +def test_collect_non_system_users_skips_below_uid_min(tmp_path: Path): + import enroll.accounts as a + + orig_parse_login_defs = a.parse_login_defs + orig_parse_passwd = a.parse_passwd + orig_parse_group = a.parse_group + + passwd = tmp_path / "passwd" + passwd.write_text( + "root:x:0:0:root:/root:/bin/bash\n" + "sysuser:x:999:999:Sys:/home/sys:/bin/bash\n" + "alice:x:1000:1000:Alice:/home/alice:/bin/bash\n", + encoding="utf-8", + ) + group = tmp_path / "group" + group.write_text("users:x:1000:alice\n", encoding="utf-8") + defs = tmp_path / "login.defs" + defs.write_text("UID_MIN 1000\n", encoding="utf-8") + + a.parse_login_defs = lambda path=str(defs): orig_parse_login_defs(path) + a.parse_passwd = lambda path=str(passwd): orig_parse_passwd(path) + a.parse_group = lambda path=str(group): orig_parse_group(path) + a.find_user_ssh_files = lambda home: [] + + users = a.collect_non_system_users() + assert [u.name for u in users] == ["alice"] diff --git a/tests/test_debian.py b/tests/test_debian.py index abad361..818ee8a 100644 --- a/tests/test_debian.py +++ b/tests/test_debian.py @@ -96,3 +96,244 @@ def test_parse_status_conffiles_handles_continuations(tmp_path: Path): assert m["nginx"]["/etc/nginx/nginx.conf"] == "abcdef" assert m["nginx"]["/etc/nginx/mime.types"] == "123456" assert "other" not in m + + +def test_dpkg_owner_returns_none_on_diversion_only(monkeypatch): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + def fake_run(cmd, text, capture_output): + return P(0, "diversion by foo from: /etc/something\n") + + monkeypatch.setattr(d.subprocess, "run", fake_run) + assert d.dpkg_owner("/etc/something") is None + + +def test_dpkg_owner_handles_line_without_colon(monkeypatch): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + def fake_run(cmd, text, capture_output): + return P(0, "invalid line without colon\n") + + monkeypatch.setattr(d.subprocess, "run", fake_run) + assert d.dpkg_owner("/etc/foo") is None + + +def test_list_manual_packages_returns_empty_on_error(monkeypatch): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + def fake_run(cmd, text, capture_output): + return P(1, "error") + + monkeypatch.setattr(d.subprocess, "run", fake_run) + assert d.list_manual_packages() == [] + + +def test_list_installed_packages_handles_exception(monkeypatch): + import enroll.debian as d + + def fake_run(*args, **kwargs): + raise Exception("simulated error") + + monkeypatch.setattr(d.subprocess, "run", fake_run) + assert d.list_installed_packages() == {} + + +def test_list_installed_packages_parses_output(): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + original_run = d.subprocess.run + + def fake_run(cmd, text, capture_output, check): + return P(0, "nginx\t1.18.0\tamd64\nvim\t8.2\tamd64\n") + + d.subprocess.run = fake_run + try: + result = d.list_installed_packages() + assert "nginx" in result + assert result["nginx"][0]["version"] == "1.18.0" + assert result["nginx"][0]["arch"] == "amd64" + assert "vim" in result + finally: + d.subprocess.run = original_run + + +def test_list_installed_packages_skips_invalid_lines(): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + original_run = d.subprocess.run + + def fake_run(cmd, text, capture_output, check): + return P(0, "nginx\t1.18.0\tamd64\ninvalid_line\n\t1.0\tamd64\n") + + d.subprocess.run = fake_run + try: + result = d.list_installed_packages() + assert "nginx" in result + assert "invalid_line" not in result + finally: + d.subprocess.run = original_run + + +def test_list_installed_packages_handles_empty_name(): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + original_run = d.subprocess.run + + def fake_run(cmd, text, capture_output, check): + return P(0, "\t1.0\tamd64\nnginx\t1.18.0\tamd64\n") + + d.subprocess.run = fake_run + try: + result = d.list_installed_packages() + assert "" not in result + assert "nginx" in result + finally: + d.subprocess.run = original_run + + +def test_list_installed_packages_sorts_output(): + import enroll.debian as d + + class P: + def __init__(self, rc: int, out: str): + self.returncode = rc + self.stdout = out + self.stderr = "" + + original_run = d.subprocess.run + + def fake_run(cmd, text, capture_output, check): + return P(0, "nginx\t1.18.0\tamd64\nnginx\t1.19.0\tarm64\n") + + d.subprocess.run = fake_run + try: + result = d.list_installed_packages() + assert len(result["nginx"]) == 2 + assert result["nginx"][0]["arch"] == "amd64" + assert result["nginx"][1]["arch"] == "arm64" + finally: + d.subprocess.run = original_run + + +def test_build_dpkg_etc_index_handles_missing_file(tmp_path: Path): + import enroll.debian as d + + info = tmp_path / "info" + info.mkdir() + # Don't create any .list files + + owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info)) + assert owned == set() + assert owner_map == {} + assert topdir_to_pkgs == {} + assert pkg_to_etc == {} + + +def test_build_dpkg_etc_index_skips_non_etc_paths(tmp_path: Path): + import enroll.debian as d + + info = tmp_path / "info" + info.mkdir() + (info / "foo.list").write_text("/usr/bin/foo\n/etc/bar\n", encoding="utf-8") + + owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info)) + assert "/usr/bin/foo" not in owned + assert "/etc/bar" in owned + assert "foo" not in topdir_to_pkgs + + +def test_parse_status_conffiles_handles_empty_status(tmp_path: Path): + import enroll.debian as d + + status = tmp_path / "status" + status.write_text("", encoding="utf-8") + m = d.parse_status_conffiles(str(status)) + assert m == {} + + +def test_parse_status_conffiles_handles_package_without_conffiles(tmp_path: Path): + import enroll.debian as d + + status = tmp_path / "status" + status.write_text( + "Package: nginx\nVersion: 1\nStatus: install ok installed\n", + encoding="utf-8", + ) + m = d.parse_status_conffiles(str(status)) + assert m == {} + + +def test_read_pkg_md5sums_returns_empty_if_file_not_exists(tmp_path: Path): + import enroll.debian as d + + result = d.read_pkg_md5sums("nonexistent_package") + assert result == {} + + +def test_read_pkg_md5sums_parses_md5sums_file(tmp_path: Path, monkeypatch): + import enroll.debian as d + + info_dir = tmp_path / "info" + info_dir.mkdir() + md5_file = info_dir / "nginx.md5sums" + md5_file.write_text( + "abcdef1234567890abcdef1234567890 etc/nginx/nginx.conf\n" + "1234567890abcdef1234567890abcdef etc/nginx/sites-enabled/default\n", + encoding="utf-8", + ) + + def fake_exists(path): + return str(path).endswith("nginx.md5sums") + + monkeypatch.setattr(d.os.path, "exists", fake_exists) + + original_open = open + + def fake_open(path, *args, **kwargs): + if "nginx.md5sums" in str(path): + return original_open(md5_file, *args, **kwargs) + return original_open(path, *args, **kwargs) + + monkeypatch.setattr("builtins.open", fake_open, raising=False) + + result = d.read_pkg_md5sums("nginx") + assert result["etc/nginx/nginx.conf"] == "abcdef1234567890abcdef1234567890" + assert ( + result["etc/nginx/sites-enabled/default"] == "1234567890abcdef1234567890abcdef" + ) diff --git a/tests/test_diff_bundle.py b/tests/test_diff_bundle.py index 66ef094..ae12187 100644 --- a/tests/test_diff_bundle.py +++ b/tests/test_diff_bundle.py @@ -87,3 +87,995 @@ def test_bundle_from_input_missing_path(tmp_path: Path): with pytest.raises(RuntimeError, match="not found"): d._bundle_from_input(str(tmp_path / "nope"), sops_mode=False) + + +import json +import sys + + +from enroll.diff import ( + _bundle_from_input, + _file_index, + _iter_managed_files, + _load_state, + _pkg_version_display, + _pkg_version_key, + _progress_enabled, + _roles, + _service_units, + _sha256, + _users_by_name, + compare_harvests, +) +from enroll.sopsutil import SopsError + + +def test_progress_enabled_when_tty(monkeypatch): + monkeypatch.setattr(sys.stderr, "isatty", lambda: True) + monkeypatch.delenv("ENROLL_NO_PROGRESS", raising=False) + assert _progress_enabled() is True + + +def test_progress_enabled_when_not_tty(monkeypatch): + monkeypatch.setattr(sys.stderr, "isatty", lambda: False) + monkeypatch.delenv("ENROLL_NO_PROGRESS", raising=False) + assert _progress_enabled() is False + + +def test_progress_enabled_with_env_var(monkeypatch): + monkeypatch.setattr(sys.stderr, "isatty", lambda: True) + monkeypatch.setenv("ENROLL_NO_PROGRESS", "1") + assert _progress_enabled() is False + + monkeypatch.setenv("ENROLL_NO_PROGRESS", "true") + assert _progress_enabled() is False + + monkeypatch.setenv("ENROLL_NO_PROGRESS", "yes") + assert _progress_enabled() is False + + +def test_sha256(tmp_path: Path): + test_file = tmp_path / "test.txt" + test_file.write_text("hello world", encoding="utf-8") + hash_result = _sha256(test_file) + assert len(hash_result) == 64 + + +def test_sha256_empty_file(tmp_path: Path): + test_file = tmp_path / "empty.txt" + test_file.write_bytes(b"") + hash_result = _sha256(test_file) + assert ( + hash_result + == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ) + + +def test_bundle_from_input_directory(tmp_path: Path): + result = _bundle_from_input(str(tmp_path), sops_mode=False) + assert result.dir == tmp_path + assert result.tempdir is None + + +def test_bundle_from_input_state_json_path(tmp_path: Path): + state_file = tmp_path / "state.json" + state_file.write_text("{}", encoding="utf-8") + result = _bundle_from_input(str(state_file), sops_mode=False) + assert result.dir == tmp_path + assert result.tempdir is None + + +def test_bundle_from_input_not_found(): + with pytest.raises(RuntimeError) as exc_info: + _bundle_from_input("/nonexistent/path", sops_mode=False) + assert "not found" in str(exc_info.value).lower() + + +def test_bundle_from_input_tarball(tmp_path: Path): + bundle_dir = tmp_path / "bundle" + bundle_dir.mkdir() + state_file = bundle_dir / "state.json" + state_file.write_text("{}", encoding="utf-8") + + tar_path = tmp_path / "bundle.tar.gz" + with tarfile.open(tar_path, "w:gz") as tf: + tf.add(bundle_dir, arcname="bundle") + + result = _bundle_from_input(str(tar_path), sops_mode=False) + assert result.dir.exists() + assert result.tempdir is not None + result.tempdir.cleanup() + + +def test_bundle_from_input_invalid_type(tmp_path: Path): + test_file = tmp_path / "test.txt" + test_file.write_text("not a bundle", encoding="utf-8") + + with pytest.raises(RuntimeError) as exc_info: + _bundle_from_input(str(test_file), sops_mode=False) + assert "not a directory" in str(exc_info.value).lower() + + +def test_load_state(tmp_path: Path): + state_file = tmp_path / "state.json" + state_file.write_text('{"host": {"hostname": "test"}}', encoding="utf-8") + result = _load_state(tmp_path) + assert result["host"]["hostname"] == "test" + + +def test_roles_empty_state(): + assert _roles({}) == {} + + +def test_roles_with_roles(): + state = {"roles": {"users": {}, "services": []}} + result = _roles(state) + assert "users" in result + + +def test_service_units_empty(): + assert _service_units({}) == {} + + +def test_service_units_with_services(): + state = { + "roles": { + "services": [ + {"unit": "nginx.service", "active_state": "active"}, + {"unit": "ssh.service", "active_state": "inactive"}, + ] + } + } + result = _service_units(state) + assert "nginx.service" in result + assert "ssh.service" in result + assert result["nginx.service"]["active_state"] == "active" + + +def test_users_by_name_empty(): + assert _users_by_name({}) == {} + + +def test_users_by_name_with_users(): + state = { + "roles": { + "users": { + "users": [ + {"name": "alice", "uid": 1000}, + {"name": "bob", "uid": 1001}, + ] + } + } + } + result = _users_by_name(state) + assert "alice" in result + assert "bob" in result + assert result["alice"]["uid"] == 1000 + + +def test_pkg_version_key_with_version(): + entry = {"version": "1.2.3"} + assert _pkg_version_key(entry) == "1.2.3" + + +def test_pkg_version_key_with_installations(): + entry = { + "installations": [ + {"arch": "x86_64", "version": "1.2.3"}, + {"arch": "aarch64", "version": "1.2.3"}, + ] + } + result = _pkg_version_key(entry) + assert "x86_64:1.2.3" in result + assert "aarch64:1.2.3" in result + + +def test_pkg_version_key_with_empty_version(): + entry = {"version": None} + assert _pkg_version_key(entry) is None + + +def test_pkg_version_key_with_invalid_installations(): + entry = {"installations": ["not_a_dict", {"arch": "x86_64", "version": "1.0"}]} + result = _pkg_version_key(entry) + assert "x86_64:1.0" in result + + +def test_pkg_version_display_with_version(): + entry = {"version": "1.2.3"} + assert _pkg_version_display(entry) == "1.2.3" + + +def test_pkg_version_display_with_installations(): + entry = { + "installations": [ + {"arch": "x86_64", "version": "1.2.3"}, + ] + } + assert _pkg_version_display(entry) == "1.2.3 (x86_64)" + + +def test_pkg_version_display_empty(): + assert _pkg_version_display({}) is None + + +def test_iter_managed_files_empty(): + state = {"roles": {}} + files = list(_iter_managed_files(state)) + assert files == [] + + +def test_iter_managed_files_services(): + state = { + "roles": { + "services": [ + { + "role_name": "nginx", + "managed_files": [ + {"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"} + ], + } + ] + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0] == ( + "nginx", + {"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"}, + ) + + +def test_iter_managed_files_packages(): + state = { + "roles": { + "packages": [ + { + "role_name": "vim", + "managed_files": [{"path": "/usr/bin/vim", "src_rel": "bin/vim"}], + } + ] + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "vim" + + +def test_iter_managed_files_users(): + state = { + "roles": { + "users": { + "role_name": "users", + "managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}], + } + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "users" + + +def test_iter_managed_files_apt_config(): + state = { + "roles": { + "apt_config": { + "role_name": "apt_config", + "managed_files": [ + {"path": "/etc/apt/sources.list", "src_rel": "sources.list"} + ], + } + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "apt_config" + + +def test_iter_managed_files_etc_custom(): + state = { + "roles": { + "etc_custom": { + "role_name": "etc_custom", + "managed_files": [ + {"path": "/etc/custom.conf", "src_rel": "custom.conf"} + ], + } + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "etc_custom" + + +def test_iter_managed_files_usr_local_custom(): + state = { + "roles": { + "usr_local_custom": { + "role_name": "usr_local_custom", + "managed_files": [ + {"path": "/usr/local/bin/script", "src_rel": "bin/script"} + ], + } + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "usr_local_custom" + + +def test_iter_managed_files_extra_paths(): + state = { + "roles": { + "extra_paths": { + "role_name": "extra_paths", + "managed_files": [{"path": "/opt/app/config", "src_rel": "config"}], + } + } + } + files = list(_iter_managed_files(state)) + assert len(files) == 1 + assert files[0][0] == "extra_paths" + + +def test_file_index_empty(): + state = {"roles": {}} + index = _file_index(Path("/tmp"), state) + assert index == {} + + +def test_file_index_with_files(tmp_path: Path): + state = { + "roles": { + "users": { + "managed_files": [ + {"path": "/etc/passwd", "src_rel": "passwd", "owner": "root"}, + ] + } + } + } + index = _file_index(tmp_path, state) + assert "/etc/passwd" in index + assert index["/etc/passwd"].role == "users" + assert index["/etc/passwd"].owner == "root" + + +def test_file_index_duplicates_first_wins(tmp_path: Path): + state = { + "roles": { + "users": { + "managed_files": [ + {"path": "/etc/passwd", "src_rel": "passwd"}, + ] + }, + "etc_custom": { + "managed_files": [ + {"path": "/etc/passwd", "src_rel": "custom_passwd"}, + ] + }, + } + } + index = _file_index(tmp_path, state) + assert "/etc/passwd" in index + assert index["/etc/passwd"].src_rel == "passwd" + + +def test_file_index_skips_missing_path_or_src_rel(tmp_path: Path): + state = { + "roles": { + "users": { + "managed_files": [ + {"path": "/etc/passwd"}, # missing src_rel + {"src_rel": "passwd"}, # missing path + ] + } + } + } + index = _file_index(tmp_path, state) + assert index == {} + + +def test_compare_harvests_no_changes(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {"vim": {"version": "1.0"}}}, + "roles": {}, + } + ), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {"vim": {"version": "1.0"}}}, + "roles": {}, + } + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is False + assert report["packages"]["added"] == [] + assert report["packages"]["removed"] == [] + + +def test_compare_harvests_package_added(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps({"inventory": {"packages": {}}, "roles": {}}), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is True + assert "vim" in report["packages"]["added"] + + +def test_compare_harvests_package_removed(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps({"inventory": {"packages": {}}, "roles": {}}), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is True + assert "vim" in report["packages"]["removed"] + + +def test_compare_harvests_package_version_changed(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "2.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is True + assert len(report["packages"]["version_changed"]) == 1 + + +def test_compare_harvests_ignore_package_versions(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + {"inventory": {"packages": {"vim": {"version": "2.0"}}}, "roles": {}} + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests( + str(old_bundle), str(new_bundle), ignore_package_versions=True + ) + assert report["packages"]["version_changed_ignored_count"] == 1 + + +def test_compare_harvests_service_added(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps({"inventory": {"packages": {}}, "roles": {"services": []}}), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {}}, + "roles": {"services": [{"unit": "nginx.service"}]}, + } + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is True + assert "nginx.service" in report["services"]["enabled_added"] + + +def test_compare_harvests_user_added(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + (old_bundle / "state.json").write_text( + json.dumps({"inventory": {"packages": {}}, "roles": {"users": {"users": []}}}), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + (new_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {}}, + "roles": {"users": {"users": [{"name": "alice", "uid": 1000}]}}, + } + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests(str(old_bundle), str(new_bundle)) + assert has_changes is True + assert "alice" in report["users"]["added"] + + +def test_compare_harvests_with_exclude_paths(tmp_path: Path): + old_bundle = tmp_path / "old" + old_bundle.mkdir() + old_artifacts = old_bundle / "artifacts" / "users" + old_artifacts.mkdir(parents=True) + (old_artifacts / "passwd").write_text("old", encoding="utf-8") + (old_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {}}, + "roles": { + "users": { + "managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}] + } + }, + } + ), + encoding="utf-8", + ) + + new_bundle = tmp_path / "new" + new_bundle.mkdir() + new_artifacts = new_bundle / "artifacts" / "users" + new_artifacts.mkdir(parents=True) + (new_artifacts / "passwd").write_text("new", encoding="utf-8") + (new_bundle / "state.json").write_text( + json.dumps( + { + "inventory": {"packages": {}}, + "roles": { + "users": { + "managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}] + } + }, + } + ), + encoding="utf-8", + ) + + report, has_changes = compare_harvests( + str(old_bundle), str(new_bundle), exclude_paths=["/etc/passwd"] + ) + assert "/etc/passwd" not in [f["path"] for f in report["files"]["added"]] + assert "/etc/passwd" not in [f["path"] for f in report["files"]["removed"]] + assert "/etc/passwd" not in [f["path"] for f in report["files"]["changed"]] + + +from enroll.diff import ( + _Spinner, + _enforcement_plan, + has_enforceable_drift, + _role_tag, + _utc_now_iso, + _report_markdown, +) + + +def test_utc_now_iso(): + result = _utc_now_iso() + assert "T" in result + assert "+" in result or "Z" in result + + +def test_spinner_start_stop(monkeypatch): + # Mock sys.stderr to avoid actual writes + class FakeStderr: + def write(self, s): + pass + + def flush(self): + pass + + def isatty(self): + return True + + monkeypatch.setattr(sys, "stderr", FakeStderr()) + + spinner = _Spinner("Test") + spinner.start() + spinner.stop(final_line="Done") + # Should not raise + + +def test_spinner_stop_without_start(): + spinner = _Spinner("Test") + spinner.stop(final_line="Done") + # Should not raise + + +def test_spinner_run_exception(monkeypatch): + class FakeStderr: + def write(self, s): + raise Exception("Write error") + + def flush(self): + pass + + monkeypatch.setattr(sys, "stderr", FakeStderr()) + + spinner = _Spinner("Test") + spinner.start() + spinner.stop() + + +def test_spinner_double_start(): + spinner = _Spinner("Test") + spinner.start() + spinner.start() # Should not raise or spawn another thread + spinner.stop() + + +def test_role_tag_normal(): + assert _role_tag("nginx") == "role_nginx" + assert _role_tag("my-app") == "role_my-app" + + +def test_role_tag_with_special_chars(): + assert _role_tag("my.app") == "role_my_app" + assert _role_tag("my app") == "role_my_app" + + +def test_role_tag_empty(): + assert _role_tag("") == "role_other" + assert _role_tag(" ") == "role_other" + + +def test_has_enforceable_drift_packages_removed(): + report = {"packages": {"removed": ["vim"]}} + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_services_removed(): + report = {"services": {"enabled_removed": ["nginx.service"]}} + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_service_changed(): + report = { + "services": { + "changed": [ + { + "unit": "nginx.service", + "changes": {"active_state": {"old": "active", "new": "inactive"}}, + } + ] + } + } + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_service_package_only_changed(): + # Service changed only in packages - should NOT be enforceable + report = { + "services": { + "changed": [ + { + "unit": "nginx.service", + "changes": {"packages": {"added": ["nginx-extra"]}}, + } + ] + } + } + assert has_enforceable_drift(report) is False + + +def test_has_enforceable_drift_users_removed(): + report = {"users": {"removed": ["alice"]}} + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_users_changed(): + report = { + "users": { + "changed": [ + {"name": "alice", "changes": {"uid": {"old": 1000, "new": 1001}}} + ] + } + } + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_files_removed(): + report = { + "files": { + "removed": [{"path": "/etc/passwd", "role": "users", "reason": "conffile"}] + } + } + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_files_changed(): + report = { + "files": { + "changed": [ + { + "path": "/etc/passwd", + "changes": {"content": {"old": "sha1", "new": "sha2"}}, + } + ] + } + } + assert has_enforceable_drift(report) is True + + +def test_has_enforceable_drift_no_drift(): + report = { + "packages": {"added": ["newpkg"]}, + "services": {"enabled_added": ["new.service"]}, + "users": {"added": ["bob"]}, + "files": {"added": ["/opt/newfile"]}, + } + assert has_enforceable_drift(report) is False + + +def test_enforcement_plan_packages_removed(monkeypatch, tmp_path: Path): + old_state = { + "roles": { + "services": [{"role_name": "nginx", "packages": ["nginx"]}], + "packages": [{"role_name": "vim", "package": "vim"}], + } + } + report = {"packages": {"removed": ["nginx", "vim"]}} + + result = _enforcement_plan(report, old_state, tmp_path) + assert "nginx" in result.get("roles", []) + assert "vim" in result.get("roles", []) + assert "role_nginx" in result.get("tags", []) + + +def test_enforcement_plan_users_changed(): + old_state = { + "roles": {"users": {"role_name": "users", "users": [{"name": "alice"}]}} + } + report = {"users": {"changed": [{"name": "alice", "changes": {"uid": {}}}]}} + + result = _enforcement_plan(report, old_state, Path("/tmp")) + assert "users" in result.get("roles", []) + + +def test_enforcement_plan_files_removed(tmp_path: Path): + # Create the artifacts directory structure that _file_index expects + artifacts_dir = tmp_path / "artifacts" / "etc_custom" + artifacts_dir.mkdir(parents=True) + + old_state = { + "roles": { + "etc_custom": { + "role_name": "etc_custom", + "managed_files": [ + {"path": "/etc/custom.conf", "src_rel": "custom.conf"} + ], + } + } + } + report = { + "files": {"removed": [{"path": "/etc/custom.conf", "role": "etc_custom"}]} + } + + result = _enforcement_plan(report, old_state, tmp_path) + assert "etc_custom" in result.get("roles", []) + + +def test_enforcement_plan_no_drift(): + old_state = {"roles": {}} + report = {"packages": {"added": ["newpkg"]}} + + result = _enforcement_plan(report, old_state, Path("/tmp")) + assert result.get("roles", []) == [] + + +def test_bundle_from_input_tgz(monkeypatch, tmp_path: Path): + bundle_dir = tmp_path / "bundle" + bundle_dir.mkdir() + state_file = bundle_dir / "state.json" + state_file.write_text("{}", encoding="utf-8") + + tar_path = tmp_path / "bundle.tgz" + with tarfile.open(tar_path, "w:gz") as tf: + tf.add(bundle_dir, arcname="bundle") + + result = _bundle_from_input(str(tar_path), sops_mode=False) + assert result.dir.exists() + assert result.tempdir is not None + result.tempdir.cleanup() + + +def test_bundle_from_input_sops_mode_no_sops(monkeypatch, tmp_path: Path): + # Create a fake .sops file + sops_file = tmp_path / "harvest.sops" + sops_file.write_bytes(b"encrypted") + + def fake_require(): + raise SopsError("sops not found") + + import enroll.diff as d + + monkeypatch.setattr(d, "require_sops_cmd", fake_require) + + with pytest.raises(SopsError): + _bundle_from_input(str(sops_file), sops_mode=True) + + +def test_report_markdown_basic(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz", "host": "host1"}, + "new": {"input": "new.tar.gz", "host": "host2"}, + "packages": {"added": ["vim"], "removed": [], "version_changed": []}, + "services": {"enabled_added": [], "enabled_removed": [], "changed": []}, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + } + result = _report_markdown(report) + assert "## Packages" in result + assert "+ vim" in result + + +def test_report_markdown_with_enforcement_applied(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": {"added": [], "removed": [], "version_changed": []}, + "services": {"enabled_added": [], "enabled_removed": [], "changed": []}, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + "enforcement": { + "status": "applied", + "tags": ["role_users"], + "returncode": 0, + "finished_at": "2024-01-01T00:01:00Z", + }, + } + result = _report_markdown(report) + assert "Applied old harvest" in result + assert "role_users" in result + + +def test_report_markdown_with_enforcement_failed(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": {"added": [], "removed": [], "version_changed": []}, + "services": {"enabled_added": [], "enabled_removed": [], "changed": []}, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + "enforcement": { + "status": "failed", + "returncode": 1, + }, + } + result = _report_markdown(report) + assert "ansible-playbook failed" in result + + +def test_report_markdown_with_enforcement_skipped(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": {"added": [], "removed": [], "version_changed": []}, + "services": {"enabled_added": [], "enabled_removed": [], "changed": []}, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + "enforcement": { + "status": "skipped", + "reason": "no drift", + }, + } + result = _report_markdown(report) + assert "Skipped" in result + assert "no drift" in result + + +def test_report_markdown_with_version_ignored(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": { + "added": [], + "removed": [], + "version_changed": [{"package": "vim", "old": "1.0", "new": "2.0"}], + "version_changed_ignored_count": 1, + }, + "services": {"enabled_added": [], "enabled_removed": [], "changed": []}, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + } + result = _report_markdown(report) + assert "ignored 1" in result + + +def test_report_markdown_with_service_package_changes(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": {"added": [], "removed": [], "version_changed": []}, + "services": { + "enabled_added": [], + "enabled_removed": [], + "changed": [ + { + "unit": "nginx.service", + "changes": {"packages": {"added": ["nginx-extra"], "removed": []}}, + } + ], + }, + "users": {"added": [], "removed": [], "changed": []}, + "files": {"added": [], "removed": [], "changed": []}, + } + result = _report_markdown(report) + assert "packages added" in result + + +def test_report_markdown_empty(): + report = { + "generated_at": "2024-01-01T00:00:00Z", + "old": {"input": "old.tar.gz"}, + "new": {"input": "new.tar.gz"}, + "packages": {}, + "services": {}, + "users": {}, + "files": {}, + } + result = _report_markdown(report) + assert "## Packages" in result + assert "## Services" in result diff --git a/tests/test_harvest.py b/tests/test_harvest.py index 1b884aa..33b5302 100644 --- a/tests/test_harvest.py +++ b/tests/test_harvest.py @@ -1,4 +1,5 @@ import json +import enroll.harvest as harvest from pathlib import Path import enroll.harvest as h @@ -367,3 +368,149 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic( assert all( mf["path"] != "/etc/cron.d/ntpsec" for mf in svc_apparmor["managed_files"] ) + + +def test_files_differ_same_content(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file2 = tmp_path / "file2.txt" + file1.write_text("same content", encoding="utf-8") + file2.write_text("same content", encoding="utf-8") + assert harvest._files_differ(str(file1), str(file2)) is False + + +def test_files_differ_different_content(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file2 = tmp_path / "file2.txt" + file1.write_text("content1", encoding="utf-8") + file2.write_text("content2", encoding="utf-8") + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_files_differ_missing_file(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file1.write_text("content", encoding="utf-8") + file2 = tmp_path / "file2.txt" + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_files_differ_both_missing(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file2 = tmp_path / "file2.txt" + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_files_differ_binary(tmp_path: Path): + file1 = tmp_path / "file1.bin" + file2 = tmp_path / "file2.bin" + file1.write_bytes(b"\x00\x01\x02\x03") + file2.write_bytes(b"\x00\x01\x02\x03") + assert harvest._files_differ(str(file1), str(file2)) is False + + +def test_files_differ_binary_different(tmp_path: Path): + file1 = tmp_path / "file1.bin" + file2 = tmp_path / "file2.bin" + file1.write_bytes(b"\x00\x01\x02\x03") + file2.write_bytes(b"\x00\x01\x02\x04") + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_files_differ_non_regular_a(tmp_path: Path): + directory = tmp_path / "dir" + directory.mkdir() + file1 = tmp_path / "file1.txt" + file1.write_text("content", encoding="utf-8") + assert harvest._files_differ(str(directory), str(file1)) is True + + +def test_files_differ_non_regular_b(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file1.write_text("content", encoding="utf-8") + directory = tmp_path / "dir" + directory.mkdir() + assert harvest._files_differ(str(file1), str(directory)) is True + + +def test_files_differ_size_mismatch(tmp_path: Path): + file1 = tmp_path / "file1.txt" + file1.write_text("short", encoding="utf-8") + file2 = tmp_path / "file2.txt" + file2.write_text("much longer content", encoding="utf-8") + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_files_differ_large_files(tmp_path: Path): + file1 = tmp_path / "file1.bin" + file2 = tmp_path / "file2.bin" + file1.write_bytes(b"x" * 3_000_000) + file2.write_bytes(b"x" * 3_000_000) + assert harvest._files_differ(str(file1), str(file2)) is True + + +def test_is_confish_with_conf(tmp_path: Path): + file1 = tmp_path / "test.conf" + file1.write_text("content", encoding="utf-8") + assert harvest._is_confish(str(file1)) is True + + +def test_is_confish_with_yaml(tmp_path: Path): + file1 = tmp_path / "test.yaml" + file1.write_text("content", encoding="utf-8") + assert harvest._is_confish(str(file1)) is True + + +def test_is_confish_with_json(tmp_path: Path): + file1 = tmp_path / "test.json" + file1.write_text("{}", encoding="utf-8") + assert harvest._is_confish(str(file1)) is True + + +def test_is_confish_with_service(tmp_path: Path): + file1 = tmp_path / "test.service" + file1.write_text("[Unit]", encoding="utf-8") + assert harvest._is_confish(str(file1)) is True + + +def test_is_confish_with_extensionless(tmp_path: Path): + file1 = tmp_path / "default" + file1.write_text("OPTIONS=", encoding="utf-8") + assert harvest._is_confish(str(file1)) is True + + +def test_is_confish_not_config(tmp_path: Path): + file1 = tmp_path / "test.log" + file1.write_text("log", encoding="utf-8") + assert harvest._is_confish(str(file1)) is False + + +def test_is_confish_nonexistent(): + assert harvest._is_confish("/nonexistent/file.xyz") is False + + +def test_topdirs_for_package_with_multiple_paths(): + pkg_to_etc_paths = { + "nginx": ["/etc/nginx/nginx.conf", "/etc/nginx/sites-enabled/default"], + } + result = harvest._topdirs_for_package("nginx", pkg_to_etc_paths) + assert result == {"nginx"} + + +def test_topdirs_for_package_with_multiple_topdirs(): + pkg_to_etc_paths = { + "multi": ["/etc/nginx/nginx.conf", "/etc/ssh/sshd_config"], + } + result = harvest._topdirs_for_package("multi", pkg_to_etc_paths) + assert result == {"nginx", "ssh"} + + +def test_topdirs_for_package_empty(): + result = harvest._topdirs_for_package("empty", {}) + assert result == set() + + +def test_topdirs_for_package_no_etc(): + pkg_to_etc_paths = { + "other": ["/usr/share/doc/file"], + } + result = harvest._topdirs_for_package("other", pkg_to_etc_paths) + assert result == set() diff --git a/tests/test_ignore.py b/tests/test_ignore.py index 1eaae01..2ba9a90 100644 --- a/tests/test_ignore.py +++ b/tests/test_ignore.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +import os +from pathlib import Path + from enroll.ignore import IgnorePolicy @@ -8,3 +13,238 @@ def test_ignore_policy_denies_common_backup_files(): assert pol.deny_reason("/etc/group-") == "backup_file" assert pol.deny_reason("/etc/something~") == "backup_file" assert pol.deny_reason("/foobar") == "unreadable" + + +def test_deny_reason_dir_with_denied_path(): + pol = IgnorePolicy() + assert pol.deny_reason_dir("/etc/ssl/private/key") == "denied_path" + assert pol.deny_reason_dir("/etc/ssh/ssh_host_key") == "denied_path" + assert pol.deny_reason_dir("/etc/ssh") is None + + +def test_deny_reason_dir_unreadable(tmp_path: Path): + pol = IgnorePolicy() + nonexistent = tmp_path / "nonexistent" + assert pol.deny_reason_dir(str(nonexistent)) == "unreadable" + + +def test_deny_reason_dir_symlink(tmp_path: Path): + pol = IgnorePolicy() + real_dir = tmp_path / "real" + real_dir.mkdir() + link = tmp_path / "link" + os.symlink(str(real_dir), str(link)) + assert pol.deny_reason_dir(str(link)) == "symlink" + + +def test_deny_reason_dir_not_directory(tmp_path: Path): + pol = IgnorePolicy() + regular_file = tmp_path / "file.txt" + regular_file.write_text("content", encoding="utf-8") + assert pol.deny_reason_dir(str(regular_file)) == "not_directory" + + +def test_deny_reason_dir_dangerous_mode(tmp_path: Path): + pol = IgnorePolicy(dangerous=True) + real_dir = tmp_path / "private" + real_dir.mkdir() + assert pol.deny_reason_dir(str(real_dir)) is None + + +def test_deny_reason_link_basic(tmp_path: Path): + pol = IgnorePolicy() + real_file = tmp_path / "real" + real_file.write_text("content", encoding="utf-8") + link = tmp_path / "link" + os.symlink(str(real_file), str(link)) + assert pol.deny_reason_link(str(link)) is None + + +def test_deny_reason_link_denied_path(): + pol = IgnorePolicy() + assert pol.deny_reason_link("/etc/ssh/ssh_host_rsa_key") == "denied_path" + + +def test_deny_reason_link_unreadable(tmp_path: Path): + pol = IgnorePolicy() + # Create a symlink in a directory that doesn't exist + # This simulates an unreadable path + broken_link = tmp_path / "broken_link" + os.symlink("/nonexistent/target", str(broken_link)) + # Broken symlinks are still readable (we can readlink them) + # So they return None (allowed) unless they match deny globs + result = pol.deny_reason_link(str(broken_link)) + # Broken symlinks are allowed - we can still read the link target + assert result is None + + +def test_deny_reason_link_not_symlink(tmp_path: Path): + pol = IgnorePolicy() + regular_file = tmp_path / "file.txt" + regular_file.write_text("content", encoding="utf-8") + assert pol.deny_reason_link(str(regular_file)) == "not_symlink" + + +def test_deny_reason_link_log_file(): + pol = IgnorePolicy() + assert pol.deny_reason_link("/var/log/something.log") == "log_file" + + +def test_deny_reason_link_backup_file(): + pol = IgnorePolicy() + assert pol.deny_reason_link("/etc/passwd-") == "backup_file" + assert pol.deny_reason_link("/etc/something~") == "backup_file" + + +def test_deny_reason_link_dangerous_mode(tmp_path: Path): + pol = IgnorePolicy(dangerous=True) + real_file = tmp_path / "real" + real_file.write_text("content", encoding="utf-8") + link = tmp_path / "link" + os.symlink(str(real_file), str(link)) + assert pol.deny_reason_link(str(link)) is None + + +def test_iter_effective_lines_with_comments(): + pol = IgnorePolicy() + content = b""" +# This is a comment +; This is also a comment +* continuation +def main(): + pass +""" + lines = list(pol.iter_effective_lines(content)) + assert b"def main():" in lines + assert b"# This is a comment" not in lines + + +def test_iter_effective_lines_with_block_comments(): + pol = IgnorePolicy() + content = b""" +/* This is a block comment + spanning multiple lines */ +int x = 5; +""" + lines = list(pol.iter_effective_lines(content)) + assert b"int x = 5;" in lines + assert b"/*" not in lines + + +def test_iter_effective_lines_empty(): + pol = IgnorePolicy() + content = b"" + lines = list(pol.iter_effective_lines(content)) + assert lines == [] + + +def test_deny_reason_binary_not_allowed(tmp_path: Path): + pol = IgnorePolicy() + binary = tmp_path / "random.bin" + binary.write_bytes(b"\x00\x01\x02\x03") + reason = pol.deny_reason(str(binary)) + assert reason == "binary_like" + + +def test_deny_reason_sensitive_content(tmp_path: Path): + pol = IgnorePolicy() + config = tmp_path / "config.txt" + config.write_text("password=secret123", encoding="utf-8") + reason = pol.deny_reason(str(config)) + assert reason == "sensitive_content" + + +def test_deny_reason_sensitive_api_key(tmp_path: Path): + pol = IgnorePolicy() + config = tmp_path / "config.txt" + config.write_text("api_key=abc123", encoding="utf-8") + reason = pol.deny_reason(str(config)) + assert reason == "sensitive_content" + + +def test_deny_reason_private_key(tmp_path: Path): + pol = IgnorePolicy() + key = tmp_path / "key.pem" + key.write_text( + "-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEA...", encoding="utf-8" + ) + reason = pol.deny_reason(str(key)) + assert reason == "sensitive_content" + + +def test_deny_reason_too_large(tmp_path: Path): + pol = IgnorePolicy(max_file_bytes=100) + large = tmp_path / "large.txt" + large.write_bytes(b"x" * 200) + reason = pol.deny_reason(str(large)) + assert reason == "too_large" + + +def test_deny_reason_unreadable(tmp_path: Path): + pol = IgnorePolicy() + nonexistent = tmp_path / "nonexistent" + reason = pol.deny_reason(str(nonexistent)) + assert reason == "unreadable" + + +def test_deny_reason_not_regular_file(tmp_path: Path): + pol = IgnorePolicy() + directory = tmp_path / "dir" + directory.mkdir() + reason = pol.deny_reason(str(directory)) + assert reason == "not_regular_file" + + +def test_deny_reason_symlink_file(tmp_path: Path): + pol = IgnorePolicy() + real_file = tmp_path / "real" + real_file.write_text("content", encoding="utf-8") + link = tmp_path / "link" + os.symlink(str(real_file), str(link)) + reason = pol.deny_reason(str(link)) + assert reason == "not_regular_file" + + +def test_deny_reason_logs(tmp_path: Path): + pol = IgnorePolicy() + log = tmp_path / "test.log" + log.write_text("log content", encoding="utf-8") + assert pol.deny_reason(str(log)) == "log_file" + + +def test_deny_reason_backup_file(tmp_path: Path): + pol = IgnorePolicy() + backup = tmp_path / "file~" + backup.write_text("backup", encoding="utf-8") + assert pol.deny_reason(str(backup)) == "backup_file" + + +def test_deny_reason_shadow_file(): + pol = IgnorePolicy() + assert pol.deny_reason("/etc/shadow") == "denied_path" + assert pol.deny_reason("/etc/gshadow") == "denied_path" + + +def test_deny_reason_ssl_private(): + pol = IgnorePolicy() + assert pol.deny_reason("/etc/ssl/private/key.pem") == "denied_path" + + +def test_deny_reason_ssh_host_keys(): + pol = IgnorePolicy() + assert pol.deny_reason("/etc/ssh/ssh_host_rsa_key") == "denied_path" + assert pol.deny_reason("/etc/ssh/ssh_host_ed25519_key") == "denied_path" + + +def test_deny_reason_letsencrypt(): + pol = IgnorePolicy() + assert ( + pol.deny_reason("/etc/letsencrypt/live/example.com/fullchain.pem") + == "denied_path" + ) + + +def test_deny_reason_shadow_backup(): + pol = IgnorePolicy() + assert pol.deny_reason("/etc/shadow-") == "backup_file" + assert pol.deny_reason("/etc/passwd-") == "backup_file" diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 658d77f..1b78bcf 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -892,3 +892,175 @@ def test_manifest_writes_firewall_runtime_role(tmp_path: Path): assert ( out / "roles" / "firewall_runtime" / "files" / "firewall" / "ipset.save" ).exists() + + +def test_try_yaml_with_yaml_installed(): + result = manifest._try_yaml() + # PyYAML should be installed for tests + if result is None: + pytest.skip("PyYAML not installed") + assert hasattr(result, "safe_load") + assert hasattr(result, "dump") + + +def test_yaml_load_mapping_with_yaml(tmp_path: Path): + text = """ +key1: value1 +key2: + nested: value +list: + - item1 + - item2 +""" + result = manifest._yaml_load_mapping(text) + assert result["key1"] == "value1" + assert result["key2"]["nested"] == "value" + assert result["list"] == ["item1", "item2"] + + +def test_yaml_load_mapping_empty(): + result = manifest._yaml_load_mapping("") + assert result == {} + + +def test_yaml_load_mapping_invalid(): + result = manifest._yaml_load_mapping("invalid: yaml: :") + assert result == {} + + +def test_yaml_load_mapping_not_dict(): + result = manifest._yaml_load_mapping("- item1\n- item2") + assert result == {} + + +def test_yaml_load_mapping_none(): + result = manifest._yaml_load_mapping("~") + assert result == {} + + +def test_yaml_dump_mapping_with_yaml(tmp_path: Path): + obj = {"key1": "value1", "key2": 123} + result = manifest._yaml_dump_mapping(obj) + assert "key1: value1" in result + assert "key2:" in result + + +def test_yaml_dump_mapping_empty(): + result = manifest._yaml_dump_mapping({}) + # Empty dict produces '{}' + assert result.strip() == "{}" + + +def test_yaml_dump_mapping_with_nested(tmp_path: Path): + obj = {"key1": {"nested": "value"}} + result = manifest._yaml_dump_mapping(obj) + assert "nested:" in result + + +def test_merge_mappings_overwrite_simple(): + existing = {"key1": "old", "key2": "keep"} + incoming = {"key1": "new", "key3": "added"} + result = manifest._merge_mappings_overwrite(existing, incoming) + assert result["key1"] == "new" + assert result["key2"] == "keep" + assert result["key3"] == "added" + + +def test_merge_mappings_overwrite_nested(): + existing = {"key1": {"a": 1}} + incoming = {"key1": {"b": 2}} + result = manifest._merge_mappings_overwrite(existing, incoming) + # Nested dicts are replaced, not merged + assert result["key1"] == {"b": 2} + + +def test_merge_mappings_overwrite_empty(): + result = manifest._merge_mappings_overwrite({}, {"key": "value"}) + assert result == {"key": "value"} + + result = manifest._merge_mappings_overwrite({"key": "value"}, {}) + assert result == {"key": "value"} + + +def test_copy2_replace(tmp_path: Path): + src = tmp_path / "src.txt" + src.write_text("content", encoding="utf-8") + dst = tmp_path / "dst" / "subdir" / "dst.txt" + + manifest._copy2_replace(str(src), str(dst)) + + assert dst.exists() + assert dst.read_text(encoding="utf-8") == "content" + + +def test_copy2_replace_preserves_metadata(tmp_path: Path): + src = tmp_path / "src.txt" + src.write_text("content", encoding="utf-8") + os.chmod(str(src), 0o644) + dst = tmp_path / "dst.txt" + + manifest._copy2_replace(str(src), str(dst)) + + assert dst.exists() + st = dst.stat() + assert stat.S_IMODE(st.st_mode) == 0o644 + + +def test_copy2_replace_atomic(tmp_path: Path): + src = tmp_path / "src.txt" + src.write_text("content", encoding="utf-8") + dst = tmp_path / "dst.txt" + + # Write initial content + dst.write_text("old", encoding="utf-8") + + manifest._copy2_replace(str(src), str(dst)) + + assert dst.read_text(encoding="utf-8") == "content" + + +def test_render_firewall_runtime_tasks_empty(): + state = {"roles": {}} + result = manifest._render_firewall_runtime_tasks(state) + # Function always returns at least a basic playbook structure + assert isinstance(result, str) + assert len(result) > 0 + + +def test_render_firewall_runtime_tasks_with_iptables(): + state = { + "roles": { + "firewall_runtime": { + "role_name": "firewall_runtime", + "iptables_v4_save": "artifacts/firewall_runtime/iptables.save", + } + } + } + result = manifest._render_firewall_runtime_tasks(state) + assert len(result) >= 1 + + +def test_render_firewall_runtime_tasks_with_ipset(): + state = { + "roles": { + "firewall_runtime": { + "role_name": "firewall_runtime", + "ipset_save": "artifacts/firewall_runtime/ipset.save", + } + } + } + result = manifest._render_firewall_runtime_tasks(state) + assert len(result) >= 1 + + +def test_render_firewall_runtime_tasks_with_ipv6(): + state = { + "roles": { + "firewall_runtime": { + "role_name": "firewall_runtime", + "iptables_v6_save": "artifacts/firewall_runtime/ip6tables.save", + } + } + } + result = manifest._render_firewall_runtime_tasks(state) + assert len(result) >= 1 diff --git a/tests/test_misc_coverage.py b/tests/test_misc_coverage.py deleted file mode 100644 index 1ff6e98..0000000 --- a/tests/test_misc_coverage.py +++ /dev/null @@ -1,416 +0,0 @@ -from __future__ import annotations - -import json -import os -import stat -import subprocess -import sys -import types -from pathlib import Path -from types import SimpleNamespace - -import pytest - -from enroll.cache import _safe_component, new_harvest_cache_dir -from enroll.ignore import IgnorePolicy -from enroll.sopsutil import ( - SopsError, - _pgp_arg, - decrypt_file_binary_to, - encrypt_file_binary, -) - - -def test_safe_component_sanitizes_and_bounds_length(): - assert _safe_component(" ") == "unknown" - assert _safe_component("a/b c") == "a_b_c" - assert _safe_component("x" * 200) == "x" * 64 - - -def test_new_harvest_cache_dir_uses_xdg_cache_home(tmp_path: Path, monkeypatch): - monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "xdg")) - hc = new_harvest_cache_dir(hint="my host/01") - assert hc.dir.exists() - assert "my_host_01" in hc.dir.name - assert str(hc.dir).startswith(str(tmp_path / "xdg")) - # best-effort: ensure directory is not world-readable on typical FS - try: - mode = stat.S_IMODE(hc.dir.stat().st_mode) - assert mode & 0o077 == 0 - except OSError: - pass - - -def test_ignore_policy_denies_binary_and_sensitive_content(tmp_path: Path): - p_bin = tmp_path / "binfile" - p_bin.write_bytes(b"abc\x00def") - assert IgnorePolicy().deny_reason(str(p_bin)) == "binary_like" - - p_secret = tmp_path / "secret.conf" - p_secret.write_text("password=foo\n", encoding="utf-8") - assert IgnorePolicy().deny_reason(str(p_secret)) == "sensitive_content" - - # dangerous mode disables heuristic scanning (but still checks file-ness/size) - assert IgnorePolicy(dangerous=True).deny_reason(str(p_secret)) is None - - -def test_ignore_policy_denies_usr_local_shadow_by_glob(): - # This should short-circuit before stat() (path doesn't need to exist). - assert IgnorePolicy().deny_reason("/usr/local/etc/shadow") == "denied_path" - - -def test_sops_pgp_arg_and_encrypt_decrypt_roundtrip(tmp_path: Path, monkeypatch): - assert _pgp_arg([" ABC ", "DEF"]) == "ABC,DEF" - with pytest.raises(SopsError): - _pgp_arg([]) - - # Stub out sops and subprocess. - import enroll.sopsutil as s - - monkeypatch.setattr(s, "require_sops_cmd", lambda: "sops") - - class R: - def __init__(self, rc: int, out: bytes, err: bytes = b""): - self.returncode = rc - self.stdout = out - self.stderr = err - - calls = [] - - def fake_run(cmd, capture_output, check): - calls.append(cmd) - # Return a deterministic payload so we can assert file writes. - if "--encrypt" in cmd: - return R(0, b"ENCRYPTED") - if "--decrypt" in cmd: - return R(0, b"PLAINTEXT") - return R(1, b"", b"bad") - - monkeypatch.setattr(s.subprocess, "run", fake_run) - - src = tmp_path / "src.bin" - src.write_bytes(b"x") - enc = tmp_path / "out.sops" - dec = tmp_path / "out.bin" - - encrypt_file_binary(src, enc, pgp_fingerprints=["ABC"], mode=0o600) - assert enc.read_bytes() == b"ENCRYPTED" - - decrypt_file_binary_to(enc, dec, mode=0o644) - assert dec.read_bytes() == b"PLAINTEXT" - - # Sanity: we invoked encrypt and decrypt. - assert any("--encrypt" in c for c in calls) - assert any("--decrypt" in c for c in calls) - - -def test_cache_dir_defaults_to_home_cache(monkeypatch, tmp_path: Path): - # Ensure default path uses ~/.cache when XDG_CACHE_HOME is unset. - from enroll.cache import enroll_cache_dir - - monkeypatch.delenv("XDG_CACHE_HOME", raising=False) - monkeypatch.setattr(Path, "home", lambda: tmp_path) - - p = enroll_cache_dir() - assert str(p).startswith(str(tmp_path)) - assert p.name == "enroll" - - -def test_harvest_cache_state_json_property(tmp_path: Path): - from enroll.cache import HarvestCache - - hc = HarvestCache(tmp_path / "h1") - assert hc.state_json == hc.dir / "state.json" - - -def test_cache_dir_security_rejects_symlink(tmp_path: Path): - from enroll.cache import _ensure_dir_secure - - real = tmp_path / "real" - real.mkdir() - link = tmp_path / "link" - link.symlink_to(real, target_is_directory=True) - - with pytest.raises(RuntimeError, match="Refusing to use symlink"): - _ensure_dir_secure(link) - - -def test_cache_dir_chmod_failures_are_ignored(monkeypatch, tmp_path: Path): - from enroll import cache - - # Make the cache base path deterministic and writable. - monkeypatch.setattr(cache, "enroll_cache_dir", lambda: tmp_path) - - # Force os.chmod to fail to cover the "except OSError: pass" paths. - monkeypatch.setattr( - os, "chmod", lambda *a, **k: (_ for _ in ()).throw(OSError("nope")) - ) - - hc = cache.new_harvest_cache_dir() - assert hc.dir.exists() - assert hc.dir.is_dir() - - -def test_stat_triplet_falls_back_to_numeric_ids(monkeypatch, tmp_path: Path): - from enroll.fsutil import stat_triplet - import pwd - import grp - - p = tmp_path / "x" - p.write_text("x", encoding="utf-8") - - # Force username/group resolution failures. - monkeypatch.setattr( - pwd, "getpwuid", lambda _uid: (_ for _ in ()).throw(KeyError("no user")) - ) - monkeypatch.setattr( - grp, "getgrgid", lambda _gid: (_ for _ in ()).throw(KeyError("no group")) - ) - - owner, group, mode = stat_triplet(str(p)) - assert owner.isdigit() - assert group.isdigit() - assert len(mode) == 4 - - -def test_ignore_policy_iter_effective_lines_removes_block_comments(): - from enroll.ignore import IgnorePolicy - - pol = IgnorePolicy() - data = b"""keep1 -/* -drop me -*/ -keep2 -""" - assert list(pol.iter_effective_lines(data)) == [b"keep1", b"keep2"] - - -def test_ignore_policy_deny_reason_dir_variants(tmp_path: Path): - from enroll.ignore import IgnorePolicy - - pol = IgnorePolicy() - - # denied by glob - assert pol.deny_reason_dir("/etc/shadow") == "denied_path" - - # symlink rejected - d = tmp_path / "d" - d.mkdir() - link = tmp_path / "l" - link.symlink_to(d, target_is_directory=True) - assert pol.deny_reason_dir(str(link)) == "symlink" - - # not a directory - f = tmp_path / "f" - f.write_text("x", encoding="utf-8") - assert pol.deny_reason_dir(str(f)) == "not_directory" - - # ok - assert pol.deny_reason_dir(str(d)) is None - - -def test_run_jinjaturtle_parses_outputs(monkeypatch, tmp_path: Path): - # Fully unit-test enroll.jinjaturtle.run_jinjaturtle by stubbing subprocess.run. - from enroll.jinjaturtle import run_jinjaturtle - - def fake_run(cmd, **kwargs): # noqa: ARG001 - # cmd includes "-d -t