From 921801caa632c894ac4228efb390061b64fd668b Mon Sep 17 00:00:00 2001
From: Miguel Jacq
Date: Sun, 28 Dec 2025 15:32:40 +1100
Subject: [PATCH] 0.1.6

---
 CHANGELOG.md              |   5 +
 debian/changelog          |   7 +
 enroll/cli.py             |   2 +-
 enroll/debian.py          |   4 +-
 enroll/harvest.py         | 718 ++++++++++++++++----------------
 enroll/pathfilter.py      |   2 +-
 pyproject.toml            |   2 +-
 rpm/enroll.spec           |   5 +-
 tests/test___main__.py    |  18 +
 tests/test_accounts.py    | 143 ++++++++
 tests/test_debian.py      | 154 ++++++++
 tests/test_diff_bundle.py |  89 +++++
 tests/test_pathfilter.py  |  80 +++++
 tests/test_remote.py      | 175 ++++++++++
 tests/test_systemd.py     | 121 +++++++
 15 files changed, 1102 insertions(+), 423 deletions(-)
 create mode 100644 tests/test___main__.py
 create mode 100644 tests/test_accounts.py
 create mode 100644 tests/test_debian.py
 create mode 100644 tests/test_diff_bundle.py
 create mode 100644 tests/test_pathfilter.py
 create mode 100644 tests/test_remote.py
 create mode 100644 tests/test_systemd.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79e45cd..2a4c39d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+# 0.1.6
+
+ * DRY up some code logic
+ * More test coverage
+
 # 0.1.5
 
  * Consolidate logrotate and cron files into their main service/package roles if they exist.
diff --git a/debian/changelog b/debian/changelog
index 5f3be58..a15c38a 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,10 @@
+enroll (0.1.6) unstable; urgency=medium
+
+  * DRY up some code logic
+  * More test coverage
+
+ -- Miguel Jacq Sun, 28 Dec 2025 15:30:00 +1100
+
 enroll (0.1.5) unstable; urgency=medium
 
   * Consolidate logrotate and cron files into their main service/package roles if they exist.
diff --git a/enroll/cli.py b/enroll/cli.py
index e5f729d..ae9aba0 100644
--- a/enroll/cli.py
+++ b/enroll/cli.py
@@ -482,7 +482,7 @@ def main() -> None:
         metavar="GPG_FINGERPRINT",
         help=(
             "Encrypt the harvest as a SOPS-encrypted tarball, and bundle+encrypt the manifest output in --out "
-            "(same behavior as `harvest --sops` and `manifest --sops`)."
+            "(same behaviour as `harvest --sops` and `manifest --sops`)."
         ),
     )
     s.add_argument(
diff --git a/enroll/debian.py b/enroll/debian.py
index 58569e5..0ddc1f3 100644
--- a/enroll/debian.py
+++ b/enroll/debian.py
@@ -154,7 +154,9 @@ def parse_status_conffiles(
         if ":" in line:
             k, v = line.split(":", 1)
             key = k
-            cur[key] = v.lstrip()
+            # Preserve leading spaces in continuation lines, but strip
+            # the trailing newline from the initial key line value.
+            cur[key] = v.lstrip().rstrip("\n")
     if cur:
         flush()
diff --git a/enroll/harvest.py b/enroll/harvest.py
index c1a1986..56e5aed 100644
--- a/enroll/harvest.py
+++ b/enroll/harvest.py
@@ -112,9 +112,9 @@ class ExtraPathsSnapshot:
 
 ALLOWED_UNOWNED_EXTS = {
+    ".cfg",
     ".cnf",
     ".conf",
-    ".cfg",
     ".ini",
     ".json",
     ".link",
@@ -136,7 +136,9 @@ ALLOWED_UNOWNED_EXTS = {
 MAX_FILES_CAP = 4000
 MAX_UNOWNED_FILES_PER_ROLE = 500
 
-# Directories that are shared across many packages; never attribute unowned files in these trees to a single package.
+# Directories that are shared across many packages.
+# Never attribute all unowned files in these trees
+# to a single package.
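+# (For example, /etc/apt is touched by apt itself plus tools like
+# unattended-upgrades, so an unowned file there should not be claimed
+# by whichever single package happens to match first.)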
 
 SHARED_ETC_TOPDIRS = {
     "apparmor.d",
     "apt",
@@ -195,6 +197,82 @@ def _copy_into_bundle(
     shutil.copy2(abs_path, dst)
 
 
+def _capture_file(
+    *,
+    bundle_dir: str,
+    role_name: str,
+    abs_path: str,
+    reason: str,
+    policy: IgnorePolicy,
+    path_filter: PathFilter,
+    managed_out: List[ManagedFile],
+    excluded_out: List[ExcludedFile],
+    seen_role: Optional[Set[str]] = None,
+    seen_global: Optional[Set[str]] = None,
+    metadata: Optional[tuple[str, str, str]] = None,
+) -> bool:
+    """Try to capture a single file into the bundle.
+
+    Returns True if the file was copied (managed), False otherwise.
+
+    * seen_role: de-dupe within a role (prevents duplicate tasks/records)
+    * seen_global: de-dupe across roles/stages (prevents multiple roles copying the same path)
+    * metadata: optional (owner, group, mode) tuple to avoid re-statting
+    """
+
+    if seen_global is not None and abs_path in seen_global:
+        return False
+    if seen_role is not None and abs_path in seen_role:
+        return False
+
+    def _mark_seen() -> None:
+        if seen_role is not None:
+            seen_role.add(abs_path)
+        if seen_global is not None:
+            seen_global.add(abs_path)
+
+    if path_filter.is_excluded(abs_path):
+        excluded_out.append(ExcludedFile(path=abs_path, reason="user_excluded"))
+        _mark_seen()
+        return False
+
+    deny = policy.deny_reason(abs_path)
+    if deny:
+        excluded_out.append(ExcludedFile(path=abs_path, reason=deny))
+        _mark_seen()
+        return False
+
+    try:
+        owner, group, mode = (
+            metadata if metadata is not None else stat_triplet(abs_path)
+        )
+    except OSError:
+        excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable"))
+        _mark_seen()
+        return False
+
+    src_rel = abs_path.lstrip("/")
+    try:
+        _copy_into_bundle(bundle_dir, role_name, abs_path, src_rel)
+    except OSError:
+        excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable"))
+        _mark_seen()
+        return False
+
+    managed_out.append(
+        ManagedFile(
+            path=abs_path,
+            src_rel=src_rel,
+            owner=owner,
+            group=group,
+            mode=mode,
+            reason=reason,
+        )
+    )
+    _mark_seen()
+    return True
+
+
 def _is_confish(path: str) -> bool:
     base = os.path.basename(path)
     _, ext = os.path.splitext(base)
@@ -227,7 +305,6 @@ def _maybe_add_specific_paths(hints: Set[str]) -> List[str]:
             f"/etc/default/{h}",
             f"/etc/init.d/{h}",
             f"/etc/sysctl.d/{h}.conf",
-            f"/etc/logrotate.d/{h}",
         ]
     )
     return paths
@@ -492,7 +569,7 @@ def harvest(
         policy = IgnorePolicy(dangerous=dangerous)
     elif dangerous:
         # If callers explicitly provided a policy but also requested
-        # dangerous behavior, honour the CLI intent.
+        # dangerous behaviour, honour the CLI intent.
         policy.dangerous = True
 
     os.makedirs(bundle_dir, exist_ok=True)
@@ -513,12 +590,21 @@ def harvest(
     # Service roles
     # -------------------------
     service_snaps: List[ServiceSnapshot] = []
+    # Track alias strings (service names, package names, stems) that should map
+    # back to the service role for shared snippet attribution (cron.d/logrotate.d).
+    service_role_aliases: Dict[str, Set[str]] = {}
+    # De-dupe per-role captures (avoids duplicate tasks in manifest generation).
+    seen_by_role: Dict[str, Set[str]] = {}
     for unit in list_enabled_services():
        role = _role_name_from_unit(unit)
        try:
            ui = get_unit_info(unit)
        except UnitQueryError as e:
+            # Even when we can't query the unit, keep a minimal alias mapping so
+            # shared snippets can still be attributed to this role by name.
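+            # (e.g. for a broken "nginx.service", the unit stem should still
+            # yield an "nginx" alias, so a snippet such as
+            # /etc/logrotate.d/nginx can attach to this role below.)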
+            service_role_aliases.setdefault(role, _hint_names(unit, set()) | {role})
+            seen_by_role.setdefault(role, set())
             service_snaps.append(
                 ServiceSnapshot(
                     unit=unit,
@@ -567,6 +653,10 @@ def harvest(
 
         hints = _hint_names(unit, pkgs)
         _add_pkgs_from_etc_topdirs(hints, topdir_to_pkgs, pkgs)
+        # Keep a stable set of aliases for this service role. Include current
+        # packages as well, so that package-named snippets (e.g. cron.d or
+        # logrotate.d entries) can still be attributed back to this service.
+        service_role_aliases[role] = set(hints) | set(pkgs) | {role}
 
         for sp in _maybe_add_specific_paths(hints):
             if not os.path.exists(sp):
@@ -610,7 +700,7 @@ def harvest(
         # key material under service directories (e.g. /etc/openvpn/*.crt).
         #
         # To avoid exploding output for shared trees (e.g. /etc/systemd), keep
-        # the older "config-ish only" behavior for known shared topdirs.
+        # the older "config-ish only" behaviour for known shared topdirs.
         any_roots: List[str] = []
         confish_roots: List[str] = []
         for h in hints:
@@ -646,34 +736,20 @@ def harvest(
                 "No packages or /etc candidates detected (unexpected for enabled service)."
             )
 
+        # De-dupe within this role while capturing. This also avoids emitting
+        # duplicate Ansible tasks for the same destination path.
+        role_seen = seen_by_role.setdefault(role, set())
         for path, reason in sorted(candidates.items()):
-            if path_filter.is_excluded(path):
-                excluded.append(ExcludedFile(path=path, reason="user_excluded"))
-                continue
-            deny = policy.deny_reason(path)
-            if deny:
-                excluded.append(ExcludedFile(path=path, reason=deny))
-                continue
-            try:
-                owner, group, mode = stat_triplet(path)
-            except OSError:
-                excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                continue
-            src_rel = path.lstrip("/")
-            try:
-                _copy_into_bundle(bundle_dir, role, path, src_rel)
-            except OSError:
-                excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                continue
-            managed.append(
-                ManagedFile(
-                    path=path,
-                    src_rel=src_rel,
-                    owner=owner,
-                    group=group,
-                    mode=mode,
-                    reason=reason,
-                )
+            _capture_file(
+                bundle_dir=bundle_dir,
+                role_name=role,
+                abs_path=path,
+                reason=reason,
+                policy=policy,
+                path_filter=path_filter,
+                managed_out=managed,
+                excluded_out=excluded,
+                seen_role=role_seen,
             )
 
         service_snaps.append(
@@ -735,36 +811,18 @@ def harvest(
 
         snap = service_snap_by_unit.get(ti.trigger_unit)
         if snap is not None:
+            role_seen = seen_by_role.setdefault(snap.role_name, set())
             for path in timer_paths:
-                if path_filter.is_excluded(path):
-                    snap.excluded.append(
-                        ExcludedFile(path=path, reason="user_excluded")
-                    )
-                    continue
-                deny = policy.deny_reason(path)
-                if deny:
-                    snap.excluded.append(ExcludedFile(path=path, reason=deny))
-                    continue
-                try:
-                    owner, group, mode = stat_triplet(path)
-                except OSError:
-                    snap.excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                    continue
-                src_rel = path.lstrip("/")
-                try:
-                    _copy_into_bundle(bundle_dir, snap.role_name, path, src_rel)
-                except OSError:
-                    snap.excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                    continue
-                snap.managed_files.append(
-                    ManagedFile(
-                        path=path,
-                        src_rel=src_rel,
-                        owner=owner,
-                        group=group,
-                        mode=mode,
-                        reason="related_timer",
-                    )
+                _capture_file(
+                    bundle_dir=bundle_dir,
+                    role_name=snap.role_name,
+                    abs_path=path,
+                    reason="related_timer",
+                    policy=policy,
+                    path_filter=path_filter,
+                    managed_out=snap.managed_files,
+                    excluded_out=snap.excluded,
+                    seen_role=role_seen,
                 )
             continue
 
@@ -852,7 +910,6 @@ def harvest(
             roots.extend([f"/etc/{td}", f"/etc/{td}.d"])
roots.extend([f"/etc/default/{td}"]) roots.extend([f"/etc/init.d/{td}"]) - roots.extend([f"/etc/logrotate.d/{td}"]) roots.extend([f"/etc/sysctl.d/{td}.conf"]) # Capture any custom/unowned files under /etc/ for this @@ -871,34 +928,18 @@ def harvest( if r not in owned_etc and _is_confish(r): candidates.setdefault(r, "custom_specific_path") + role_seen = seen_by_role.setdefault(role, set()) for path, reason in sorted(candidates.items()): - if path_filter.is_excluded(path): - excluded.append(ExcludedFile(path=path, reason="user_excluded")) - continue - deny = policy.deny_reason(path) - if deny: - excluded.append(ExcludedFile(path=path, reason=deny)) - continue - try: - owner, group, mode = stat_triplet(path) - except OSError: - excluded.append(ExcludedFile(path=path, reason="unreadable")) - continue - src_rel = path.lstrip("/") - try: - _copy_into_bundle(bundle_dir, role, path, src_rel) - except OSError: - excluded.append(ExcludedFile(path=path, reason="unreadable")) - continue - managed.append( - ManagedFile( - path=path, - src_rel=src_rel, - owner=owner, - group=group, - mode=mode, - reason=reason, - ) + _capture_file( + bundle_dir=bundle_dir, + role_name=role, + abs_path=path, + reason=reason, + policy=policy, + path_filter=path_filter, + managed_out=managed, + excluded_out=excluded, + seen_role=role_seen, ) if not pkg_to_etc_paths.get(pkg, []) and not managed: @@ -929,6 +970,7 @@ def harvest( users_notes.append(f"Failed to enumerate users: {e!r}") users_role_name = "users" + users_role_seen = seen_by_role.setdefault(users_role_name, set()) for u in user_records: users_list.append( @@ -946,38 +988,21 @@ def harvest( # Copy only safe SSH public material: authorized_keys + *.pub for sf in u.ssh_files: - if path_filter.is_excluded(sf): - users_excluded.append(ExcludedFile(path=sf, reason="user_excluded")) - continue - deny = policy.deny_reason(sf) - if deny: - users_excluded.append(ExcludedFile(path=sf, reason=deny)) - continue - try: - owner, group, mode = stat_triplet(sf) - except OSError: - users_excluded.append(ExcludedFile(path=sf, reason="unreadable")) - continue - src_rel = sf.lstrip("/") - try: - _copy_into_bundle(bundle_dir, users_role_name, sf, src_rel) - except OSError: - users_excluded.append(ExcludedFile(path=sf, reason="unreadable")) - continue reason = ( "authorized_keys" if sf.endswith("/authorized_keys") else "ssh_public_key" ) - users_managed.append( - ManagedFile( - path=sf, - src_rel=src_rel, - owner=owner, - group=group, - mode=mode, - reason=reason, - ) + _capture_file( + bundle_dir=bundle_dir, + role_name=users_role_name, + abs_path=sf, + reason=reason, + policy=policy, + path_filter=path_filter, + managed_out=users_managed, + excluded_out=users_excluded, + seen_role=users_role_seen, ) users_snapshot = UsersSnapshot( @@ -995,39 +1020,19 @@ def harvest( apt_excluded: List[ExcludedFile] = [] apt_managed: List[ManagedFile] = [] apt_role_name = "apt_config" + apt_role_seen = seen_by_role.setdefault(apt_role_name, set()) for path, reason in _iter_apt_capture_paths(): - if path_filter.is_excluded(path): - apt_excluded.append(ExcludedFile(path=path, reason="user_excluded")) - continue - - deny = policy.deny_reason(path) - if deny: - apt_excluded.append(ExcludedFile(path=path, reason=deny)) - continue - - try: - owner, group, mode = stat_triplet(path) - except OSError: - apt_excluded.append(ExcludedFile(path=path, reason="unreadable")) - continue - - src_rel = path.lstrip("/") - try: - _copy_into_bundle(bundle_dir, apt_role_name, path, src_rel) - except OSError: - 
-            apt_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-            continue
-
-        apt_managed.append(
-            ManagedFile(
-                path=path,
-                src_rel=src_rel,
-                owner=owner,
-                group=group,
-                mode=mode,
-                reason=reason,
-            )
+        _capture_file(
+            bundle_dir=bundle_dir,
+            role_name=apt_role_name,
+            abs_path=path,
+            reason=reason,
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=apt_managed,
+            excluded_out=apt_excluded,
+            seen_role=apt_role_seen,
         )
 
     apt_config_snapshot = AptConfigSnapshot(
@@ -1062,11 +1067,58 @@ def harvest(
     svc_by_role: Dict[str, ServiceSnapshot] = {s.role_name: s for s in service_snaps}
     pkg_by_role: Dict[str, PackageSnapshot] = {p.role_name: p for p in pkg_snaps}
 
-    def _target_role_for_shared_snippet(path: str) -> Optional[tuple[str, str]]:
-        """If `path` is a shared snippet, return (role_name, reason) to attach to."""
-        base = os.path.basename(path)
+    # Package name -> role_name for manually-installed package roles.
+    pkg_name_to_role: Dict[str, str] = {p.package: p.role_name for p in pkg_snaps}
 
-        # Try full filename and stem (before first dot).
+    # Package name -> list of service role names that reference it.
+    pkg_to_service_roles: Dict[str, List[str]] = {}
+    for s in service_snaps:
+        for pkg in s.packages:
+            pkg_to_service_roles.setdefault(pkg, []).append(s.role_name)
+
+    # Alias -> role mapping used as a fallback when dpkg ownership is missing.
+    # Prefer service roles over package roles when both would match.
+    alias_ranked: Dict[str, tuple[int, str]] = {}
+
+    def _add_alias(alias: str, role_name: str, *, priority: int) -> None:
+        key = _safe_name(alias)
+        if not key:
+            return
+        cur = alias_ranked.get(key)
+        if (
+            cur is None
+            or priority < cur[0]
+            or (priority == cur[0] and role_name < cur[1])
+        ):
+            alias_ranked[key] = (priority, role_name)
+
+    for role_name, aliases in service_role_aliases.items():
+        for a in aliases:
+            _add_alias(a, role_name, priority=0)
+
+    for p in pkg_snaps:
+        _add_alias(p.package, p.role_name, priority=1)
+
+    def _target_role_for_shared_snippet(path: str) -> Optional[tuple[str, str]]:
+        """If `path` is a shared snippet, return (role_name, reason) to attach to.
+
+        This is used primarily for /etc/logrotate.d/* and /etc/cron.d/* where
+        files are "owned" by many packages but people tend to reason about them
+        per service.
+
+        Resolution order:
+          1) dpkg owner -> service role (if any service references the package)
+          2) dpkg owner -> package role (manual package role exists)
+          3) basename/stem alias match -> preferred role
+        """
+        if path.startswith("/etc/logrotate.d/"):
+            tag = "logrotate_snippet"
+        elif path.startswith("/etc/cron.d/"):
+            tag = "cron_snippet"
+        else:
+            return None
+
+        base = os.path.basename(path)
         candidates: List[str] = [base]
         if "." in base:
             candidates.append(base.split(".", 1)[0])
@@ -1078,122 +1130,62 @@ def harvest(
             seen.add(c)
             uniq.append(c)
 
-        if path.startswith("/etc/logrotate.d/"):
-            for c in uniq:
-                rn = _safe_name(c)
-                if rn in svc_by_role or rn in pkg_by_role:
-                    return (rn, "logrotate_snippet")
-            return None
+        pkg = dpkg_owner(path)
+        if pkg:
+            svc_roles = pkg_to_service_roles.get(pkg)
+            if svc_roles:
+                # Deterministic tie-break: lowest role name.
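+                # (a single package can back several enabled services, so
+                # svc_roles may hold more than one role; sorting keeps
+                # repeated harvests deterministic.)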
+                return (sorted(set(svc_roles))[0], tag)
+            pkg_role = pkg_name_to_role.get(pkg)
+            if pkg_role:
+                return (pkg_role, tag)
 
-        if path.startswith("/etc/cron.d/"):
-            for c in uniq:
-                rn = _safe_name(c)
-                if rn in svc_by_role or rn in pkg_by_role:
-                    return (rn, "cron_snippet")
-            return None
+        for c in uniq:
+            key = _safe_name(c)
+            hit = alias_ranked.get(key)
+            if hit is not None:
+                return (hit[1], tag)
 
         return None
 
+    def _lists_for_role(role_name: str) -> tuple[List[ManagedFile], List[ExcludedFile]]:
+        if role_name in svc_by_role:
+            snap = svc_by_role[role_name]
+            return (snap.managed_files, snap.excluded)
+        if role_name in pkg_by_role:
+            snap = pkg_by_role[role_name]
+            return (snap.managed_files, snap.excluded)
+        # Fallback (shouldn't normally happen): attribute to etc_custom.
+        return (etc_managed, etc_excluded)
+
     # Capture essential system config/state (even if package-owned).
+    etc_role_seen = seen_by_role.setdefault(etc_role_name, set())
     for path, reason in _iter_system_capture_paths():
         if path in already:
             continue
 
         target = _target_role_for_shared_snippet(path)
-
-        if path_filter.is_excluded(path):
-            if target:
-                rn, _ = target
-                if rn in svc_by_role:
-                    svc_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="user_excluded")
-                    )
-                elif rn in pkg_by_role:
-                    pkg_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="user_excluded")
-                    )
-            else:
-                etc_excluded.append(ExcludedFile(path=path, reason="user_excluded"))
-            already.add(path)
-            continue
-
-        deny = policy.deny_reason(path)
-        if deny:
-            if target:
-                rn, _ = target
-                if rn in svc_by_role:
-                    svc_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason=deny)
-                    )
-                elif rn in pkg_by_role:
-                    pkg_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason=deny)
-                    )
-            else:
-                etc_excluded.append(ExcludedFile(path=path, reason=deny))
-            already.add(path)
-            continue
-
-        try:
-            owner, group, mode = stat_triplet(path)
-        except OSError:
-            if target:
-                rn, _ = target
-                if rn in svc_by_role:
-                    svc_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="unreadable")
-                    )
-                elif rn in pkg_by_role:
-                    pkg_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="unreadable")
-                    )
-            else:
-                etc_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-            already.add(path)
-            continue
-
-        src_rel = path.lstrip("/")
-        role_for_copy = etc_role_name
-        reason_for_role = reason
-        if target:
+        if target is not None:
             role_for_copy, reason_for_role = target
-
-        try:
-            _copy_into_bundle(bundle_dir, role_for_copy, path, src_rel)
-        except OSError:
-            if target:
-                rn, _ = target
-                if rn in svc_by_role:
-                    svc_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="unreadable")
-                    )
-                elif rn in pkg_by_role:
-                    pkg_by_role[rn].excluded.append(
-                        ExcludedFile(path=path, reason="unreadable")
-                    )
-            else:
-                etc_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-            already.add(path)
-            continue
-
-        mf = ManagedFile(
-            path=path,
-            src_rel=src_rel,
-            owner=owner,
-            group=group,
-            mode=mode,
-            reason=reason_for_role,
-        )
-        if target:
-            rn, _ = target
-            if rn in svc_by_role:
-                svc_by_role[rn].managed_files.append(mf)
-            elif rn in pkg_by_role:
-                pkg_by_role[rn].managed_files.append(mf)
+            managed_out, excluded_out = _lists_for_role(role_for_copy)
+            role_seen = seen_by_role.setdefault(role_for_copy, set())
         else:
-            etc_managed.append(mf)
+            role_for_copy, reason_for_role = (etc_role_name, reason)
+            managed_out, excluded_out = (etc_managed, etc_excluded)
+            role_seen = etc_role_seen
 
-        already.add(path)
+        _capture_file(
+            bundle_dir=bundle_dir,
+            role_name=role_for_copy,
+            abs_path=path,
+            reason=reason_for_role,
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed_out,
+            excluded_out=excluded_out,
+            seen_role=role_seen,
+            seen_global=already,
+        )
 
     # Walk /etc for remaining unowned config-ish files
     scanned = 0
@@ -1212,99 +1204,28 @@ def harvest(
                 continue
 
             target = _target_role_for_shared_snippet(path)
-
-            if path_filter.is_excluded(path):
-                if target:
-                    rn, _ = target
-                    if rn in svc_by_role:
-                        svc_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="user_excluded")
-                        )
-                    elif rn in pkg_by_role:
-                        pkg_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="user_excluded")
-                        )
-                else:
-                    etc_excluded.append(ExcludedFile(path=path, reason="user_excluded"))
-                already.add(path)
-                continue
-
-            deny = policy.deny_reason(path)
-            if deny:
-                if target:
-                    rn, _ = target
-                    if rn in svc_by_role:
-                        svc_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason=deny)
-                        )
-                    elif rn in pkg_by_role:
-                        pkg_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason=deny)
-                        )
-                else:
-                    etc_excluded.append(ExcludedFile(path=path, reason=deny))
-                already.add(path)
-                continue
-
-            try:
-                owner, group, mode = stat_triplet(path)
-            except OSError:
-                if target:
-                    rn, _ = target
-                    if rn in svc_by_role:
-                        svc_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="unreadable")
-                        )
-                    elif rn in pkg_by_role:
-                        pkg_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="unreadable")
-                        )
-                else:
-                    etc_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                already.add(path)
-                continue
-
-            src_rel = path.lstrip("/")
-            role_for_copy = etc_role_name
-            reason_for_role = "custom_unowned"
-            if target:
+            if target is not None:
                 role_for_copy, reason_for_role = target
-
-            try:
-                _copy_into_bundle(bundle_dir, role_for_copy, path, src_rel)
-            except OSError:
-                if target:
-                    rn, _ = target
-                    if rn in svc_by_role:
-                        svc_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="unreadable")
-                        )
-                    elif rn in pkg_by_role:
-                        pkg_by_role[rn].excluded.append(
-                            ExcludedFile(path=path, reason="unreadable")
-                        )
-                else:
-                    etc_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                already.add(path)
-                continue
-
-            mf = ManagedFile(
-                path=path,
-                src_rel=src_rel,
-                owner=owner,
-                group=group,
-                mode=mode,
-                reason=reason_for_role,
-            )
-            if target:
-                rn, _ = target
-                if rn in svc_by_role:
-                    svc_by_role[rn].managed_files.append(mf)
-                elif rn in pkg_by_role:
-                    pkg_by_role[rn].managed_files.append(mf)
+                managed_out, excluded_out = _lists_for_role(role_for_copy)
+                role_seen = seen_by_role.setdefault(role_for_copy, set())
             else:
-                etc_managed.append(mf)
-            scanned += 1
+                role_for_copy, reason_for_role = (etc_role_name, "custom_unowned")
+                managed_out, excluded_out = (etc_managed, etc_excluded)
+                role_seen = etc_role_seen
+
+            if _capture_file(
+                bundle_dir=bundle_dir,
+                role_name=role_for_copy,
+                abs_path=path,
+                reason=reason_for_role,
+                policy=policy,
+                path_filter=path_filter,
+                managed_out=managed_out,
+                excluded_out=excluded_out,
+                seen_role=role_seen,
+                seen_global=already,
+            ):
+                scanned += 1
 
             if scanned >= MAX_FILES_CAP:
                 etc_notes.append(
                     f"Reached file cap ({MAX_FILES_CAP}) while scanning /etc for unowned files."
@@ -1339,6 +1260,7 @@ def harvest(
         scanned = 0
         if not os.path.isdir(root):
             return
+        role_seen = seen_by_role.setdefault(ul_role_name, set())
         for dirpath, _, filenames in os.walk(root):
             for fn in filenames:
                 path = os.path.join(dirpath, fn)
@@ -1346,54 +1268,34 @@ def harvest(
                     continue
                 if not os.path.isfile(path) or os.path.islink(path):
                     continue
+                try:
+                    owner, group, mode = stat_triplet(path)
+                except OSError:
+                    ul_excluded.append(ExcludedFile(path=path, reason="unreadable"))
+                    continue
+
                 if require_executable:
-                    try:
-                        owner, group, mode = stat_triplet(path)
-                    except OSError:
-                        ul_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                        continue
                     try:
                         if (int(mode, 8) & 0o111) == 0:
                             continue
                     except ValueError:
                         # If mode parsing fails, be conservative and skip.
                         continue
-                else:
-                    try:
-                        owner, group, mode = stat_triplet(path)
-                    except OSError:
-                        ul_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                        continue
 
-                if path_filter.is_excluded(path):
-                    ul_excluded.append(ExcludedFile(path=path, reason="user_excluded"))
-                    continue
-
-                deny = policy.deny_reason(path)
-                if deny:
-                    ul_excluded.append(ExcludedFile(path=path, reason=deny))
-                    continue
-
-                src_rel = path.lstrip("/")
-                try:
-                    _copy_into_bundle(bundle_dir, ul_role_name, path, src_rel)
-                except OSError:
-                    ul_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-                    continue
-
-                ul_managed.append(
-                    ManagedFile(
-                        path=path,
-                        src_rel=src_rel,
-                        owner=owner,
-                        group=group,
-                        mode=mode,
-                        reason=reason,
-                    )
-                )
-
-                already_all.add(path)
-                scanned += 1
+                if _capture_file(
+                    bundle_dir=bundle_dir,
+                    role_name=ul_role_name,
+                    abs_path=path,
+                    reason=reason,
+                    policy=policy,
+                    path_filter=path_filter,
+                    managed_out=ul_managed,
+                    excluded_out=ul_excluded,
+                    seen_role=role_seen,
+                    metadata=(owner, group, mode),
+                ):
+                    already_all.add(path)
+                    scanned += 1
 
                 if scanned >= cap:
                     ul_notes.append(f"Reached file cap ({cap}) while scanning {root}.")
                     return
@@ -1428,6 +1330,7 @@ def harvest(
     extra_excluded: List[ExcludedFile] = []
     extra_managed: List[ManagedFile] = []
     extra_role_name = "extra_paths"
+    extra_role_seen = seen_by_role.setdefault(extra_role_name, set())
 
     include_specs = list(include_paths or [])
     exclude_specs = list(exclude_paths or [])
@@ -1453,39 +1356,18 @@ def harvest(
         if path in already_all:
             continue
 
-        if path_filter.is_excluded(path):
-            extra_excluded.append(ExcludedFile(path=path, reason="user_excluded"))
-            continue
-
-        deny = policy.deny_reason(path)
-        if deny:
-            extra_excluded.append(ExcludedFile(path=path, reason=deny))
-            continue
-
-        try:
-            owner, group, mode = stat_triplet(path)
-        except OSError:
-            extra_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-            continue
-
-        src_rel = path.lstrip("/")
-        try:
-            _copy_into_bundle(bundle_dir, extra_role_name, path, src_rel)
-        except OSError:
-            extra_excluded.append(ExcludedFile(path=path, reason="unreadable"))
-            continue
-
-        extra_managed.append(
-            ManagedFile(
-                path=path,
-                src_rel=src_rel,
-                owner=owner,
-                group=group,
-                mode=mode,
-                reason="user_include",
-            )
-        )
-        already_all.add(path)
+        if _capture_file(
+            bundle_dir=bundle_dir,
+            role_name=extra_role_name,
+            abs_path=path,
+            reason="user_include",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=extra_managed,
+            excluded_out=extra_excluded,
+            seen_role=extra_role_seen,
+        ):
+            already_all.add(path)
 
     extra_paths_snapshot = ExtraPathsSnapshot(
         role_name=extra_role_name,
diff --git a/enroll/pathfilter.py b/enroll/pathfilter.py
index 6541ca9..680d390 100644
--- a/enroll/pathfilter.py
+++ b/enroll/pathfilter.py
@@ -141,7 +141,7 @@ class PathFilter:
     - Regex: prefix with 're:' or 'regex:'
     - Force glob: prefix with 'glob:'
     - A plain path without wildcards matches that path and everything under it
-      (directory-prefix behavior).
+      (directory-prefix behaviour).
 
     Examples:
       --exclude-path /usr/local/bin/docker-*
diff --git a/pyproject.toml b/pyproject.toml
index 3aa01d0..c7356bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "enroll"
-version = "0.1.5"
+version = "0.1.6"
 description = "Enroll a server's running state retrospectively into Ansible"
 authors = ["Miguel Jacq "]
 license = "GPL-3.0-or-later"
diff --git a/rpm/enroll.spec b/rpm/enroll.spec
index ed0a3c9..637dee1 100644
--- a/rpm/enroll.spec
+++ b/rpm/enroll.spec
@@ -1,4 +1,4 @@
-%global upstream_version 0.1.5
+%global upstream_version 0.1.6
 
 Name:           enroll
 Version:        %{upstream_version}
@@ -44,6 +44,9 @@ Enroll a server's running state retrospectively into Ansible.
 
 %changelog
 * Sun Dec 28 2025 Miguel Jacq - %{version}-%{release}
+- DRY up some code logic
+- More test coverage
+* Sun Dec 28 2025 Miguel Jacq - %{version}-%{release}
 - Consolidate logrotate and cron files into their main service/package roles if they exist.
 - Standardise on MAX_FILES_CAP in one place
 - Manage apt stuff in its own role, not in etc_custom
diff --git a/tests/test___main__.py b/tests/test___main__.py
new file mode 100644
index 0000000..2e83ac1
--- /dev/null
+++ b/tests/test___main__.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+import runpy
+
+
+def test_module_main_invokes_cli_main(monkeypatch):
+    import enroll.cli
+
+    called = {"ok": False}
+
+    def fake_main() -> None:
+        called["ok"] = True
+
+    monkeypatch.setattr(enroll.cli, "main", fake_main)
+
+    # Execute enroll.__main__ as if `python -m enroll`.
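+    # run_name="__main__" makes the module's `if __name__ == "__main__":`
+    # guard fire, which should invoke the patched cli.main.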
+    runpy.run_module("enroll.__main__", run_name="__main__")
+    assert called["ok"] is True
diff --git a/tests/test_accounts.py b/tests/test_accounts.py
new file mode 100644
index 0000000..d5cc267
--- /dev/null
+++ b/tests/test_accounts.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+
+def test_parse_login_defs_parses_known_keys(tmp_path: Path):
+    from enroll.accounts import parse_login_defs
+
+    p = tmp_path / "login.defs"
+    p.write_text(
+        """
+        # comment
+        UID_MIN 1000
+        UID_MAX 60000
+        SYS_UID_MIN 100
+        SYS_UID_MAX 999
+        UID_MIN not_an_int
+        OTHER 123
+        """,
+        encoding="utf-8",
+    )
+
+    vals = parse_login_defs(str(p))
+    assert vals["UID_MIN"] == 1000
+    assert vals["UID_MAX"] == 60000
+    assert vals["SYS_UID_MIN"] == 100
+    assert vals["SYS_UID_MAX"] == 999
+    assert "OTHER" not in vals
+
+
+def test_parse_passwd_and_group_and_ssh_files(tmp_path: Path):
+    from enroll.accounts import find_user_ssh_files, parse_group, parse_passwd
+
+    passwd = tmp_path / "passwd"
+    passwd.write_text(
+        "\n".join(
+            [
+                "root:x:0:0:root:/root:/bin/bash",
+                "# comment",
+                "alice:x:1000:1000:Alice:/home/alice:/bin/bash",
+                "bob:x:1001:1000:Bob:/home/bob:/usr/sbin/nologin",
+                "badline",
+                "cathy:x:notint:1000:Cathy:/home/cathy:/bin/bash",
+                "",
+            ]
+        ),
+        encoding="utf-8",
+    )
+
+    group = tmp_path / "group"
+    group.write_text(
+        "\n".join(
+            [
+                "root:x:0:",
+                "users:x:1000:alice,bob",
+                "admins:x:1002:alice",
+                "badgroup:x:notint:alice",
+                "",
+            ]
+        ),
+        encoding="utf-8",
+    )
+
+    rows = parse_passwd(str(passwd))
+    assert ("alice", 1000, 1000, "Alice", "/home/alice", "/bin/bash") in rows
+    assert all(r[0] != "cathy" for r in rows)  # skipped invalid UID
+
+    gid_to_name, name_to_gid, members = parse_group(str(group))
+    assert gid_to_name[1000] == "users"
+    assert name_to_gid["admins"] == 1002
+    assert "alice" in members["admins"]
+
+    # ssh discovery: only authorized_keys, no symlinks
+    home = tmp_path / "home" / "alice"
+    sshdir = home / ".ssh"
+    sshdir.mkdir(parents=True)
+    ak = sshdir / "authorized_keys"
+    ak.write_text("ssh-ed25519 AAA...", encoding="utf-8")
+    # a symlink should be ignored
+    (sshdir / "authorized_keys2").write_text("x", encoding="utf-8")
+    os.symlink(str(sshdir / "authorized_keys2"), str(sshdir / "authorized_keys_link"))
+    assert find_user_ssh_files(str(home)) == [str(ak)]
+
+
+def test_collect_non_system_users(monkeypatch, tmp_path: Path):
+    import enroll.accounts as a
+
+    orig_parse_login_defs = a.parse_login_defs
+    orig_parse_passwd = a.parse_passwd
+    orig_parse_group = a.parse_group
+
+    # Provide controlled passwd/group/login.defs inputs via monkeypatch.
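+    # (collect_non_system_users reads them through the module-level parse_*
+    # helpers, so re-binding those helpers with default-argument lambdas
+    # below points every read at the files created under tmp_path.)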
+    passwd = tmp_path / "passwd"
+    passwd.write_text(
+        "\n".join(
+            [
+                "root:x:0:0:root:/root:/bin/bash",
+                "nobody:x:65534:65534:nobody:/nonexistent:/usr/sbin/nologin",
+                "alice:x:1000:1000:Alice:/home/alice:/bin/bash",
+                "sysuser:x:200:200:Sys:/home/sys:/bin/bash",
+                "bob:x:1001:1000:Bob:/home/bob:/bin/false",
+                "",
+            ]
+        ),
+        encoding="utf-8",
+    )
+    group = tmp_path / "group"
+    group.write_text(
+        "\n".join(
+            [
+                "users:x:1000:alice,bob",
+                "admins:x:1002:alice",
+                "",
+            ]
+        ),
+        encoding="utf-8",
+    )
+
+    defs = tmp_path / "login.defs"
+    defs.write_text("UID_MIN 1000\n", encoding="utf-8")
+
+    monkeypatch.setattr(
+        a, "parse_login_defs", lambda path=str(defs): orig_parse_login_defs(path)
+    )
+    monkeypatch.setattr(
+        a, "parse_passwd", lambda path=str(passwd): orig_parse_passwd(path)
+    )
+    monkeypatch.setattr(
+        a, "parse_group", lambda path=str(group): orig_parse_group(path)
+    )
+
+    # Use a stable fake ssh discovery.
+    monkeypatch.setattr(
+        a, "find_user_ssh_files", lambda home: [f"{home}/.ssh/authorized_keys"]
+    )
+
+    users = a.collect_non_system_users()
+    assert [u.name for u in users] == ["alice"]
+    u = users[0]
+    assert u.primary_group == "users"
+    assert u.supplementary_groups == ["admins"]
+    assert u.ssh_files == ["/home/alice/.ssh/authorized_keys"]
diff --git a/tests/test_debian.py b/tests/test_debian.py
new file mode 100644
index 0000000..333afc1
--- /dev/null
+++ b/tests/test_debian.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+import hashlib
+from pathlib import Path
+
+
+def test_dpkg_owner_parses_output(monkeypatch):
+    import enroll.debian as d
+
+    class P:
+        def __init__(self, rc: int, out: str):
+            self.returncode = rc
+            self.stdout = out
+            self.stderr = ""
+
+    def fake_run(cmd, text, capture_output):
+        assert cmd[:2] == ["dpkg", "-S"]
+        return P(
+            0,
+            """
+            diversion by foo from: /etc/something
+            nginx-common:amd64: /etc/nginx/nginx.conf
+            nginx-common, nginx: /etc/nginx/sites-enabled/default
+            """,
+        )
+
+    monkeypatch.setattr(d.subprocess, "run", fake_run)
+    assert d.dpkg_owner("/etc/nginx/nginx.conf") == "nginx-common"
+
+    def fake_run_none(cmd, text, capture_output):
+        return P(1, "")
+
+    monkeypatch.setattr(d.subprocess, "run", fake_run_none)
+    assert d.dpkg_owner("/missing") is None
+
+
+def test_list_manual_packages_parses_and_sorts(monkeypatch):
+    import enroll.debian as d
+
+    class P:
+        def __init__(self, rc: int, out: str):
+            self.returncode = rc
+            self.stdout = out
+            self.stderr = ""
+
+    def fake_run(cmd, text, capture_output):
+        assert cmd == ["apt-mark", "showmanual"]
+        return P(0, "\n# comment\nnginx\nvim\nnginx\n")
+
+    monkeypatch.setattr(d.subprocess, "run", fake_run)
+    assert d.list_manual_packages() == ["nginx", "vim"]
+
+
+def test_build_dpkg_etc_index(tmp_path: Path):
+    import enroll.debian as d
+
+    info = tmp_path / "info"
+    info.mkdir()
+    (info / "nginx.list").write_text(
+        "/etc/nginx/nginx.conf\n/etc/nginx/sites-enabled/default\n/usr/bin/nginx\n",
+        encoding="utf-8",
+    )
+    (info / "vim:amd64.list").write_text(
+        "/etc/vim/vimrc\n/usr/bin/vim\n",
+        encoding="utf-8",
+    )
+
+    owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
+    assert "/etc/nginx/nginx.conf" in owned
+    assert owner_map["/etc/nginx/nginx.conf"] == "nginx"
+    assert "nginx" in topdir_to_pkgs
+    assert topdir_to_pkgs["nginx"] == {"nginx"}
+    assert pkg_to_etc["vim"] == ["/etc/vim/vimrc"]
+
+
+def test_parse_status_conffiles_handles_continuations(tmp_path: Path):
+    import enroll.debian as d
+
+    status = tmp_path / "status"
+    status.write_text(
"\n".join( + [ + "Package: nginx", + "Version: 1", + "Conffiles:", + " /etc/nginx/nginx.conf abcdef", + " /etc/nginx/mime.types 123456", + "", + "Package: other", + "Version: 2", + "", + ] + ), + encoding="utf-8", + ) + m = d.parse_status_conffiles(str(status)) + assert m["nginx"]["/etc/nginx/nginx.conf"] == "abcdef" + assert m["nginx"]["/etc/nginx/mime.types"] == "123456" + assert "other" not in m + + +def test_read_pkg_md5sums_and_file_md5(tmp_path: Path, monkeypatch): + import enroll.debian as d + + # Patch /var/lib/dpkg/info/.md5sums lookup to a tmp file. + md5_file = tmp_path / "pkg.md5sums" + md5_file.write_text("0123456789abcdef etc/foo.conf\n", encoding="utf-8") + + def fake_exists(path: str) -> bool: + return path.endswith("/var/lib/dpkg/info/p1.md5sums") + + real_open = open + + def fake_open(path: str, *args, **kwargs): + if path.endswith("/var/lib/dpkg/info/p1.md5sums"): + return real_open(md5_file, *args, **kwargs) + return real_open(path, *args, **kwargs) + + monkeypatch.setattr(d.os.path, "exists", fake_exists) + monkeypatch.setattr("builtins.open", fake_open) + + m = d.read_pkg_md5sums("p1") + assert m == {"etc/foo.conf": "0123456789abcdef"} + + content = b"hello world\n" + p = tmp_path / "x" + p.write_bytes(content) + assert d.file_md5(str(p)) == hashlib.md5(content).hexdigest() + + +def test_stat_triplet_fallbacks(tmp_path: Path, monkeypatch): + import enroll.debian as d + import sys + + p = tmp_path / "f" + p.write_text("x", encoding="utf-8") + + class FakePwdMod: + @staticmethod + def getpwuid(_): # pragma: no cover + raise KeyError + + class FakeGrpMod: + @staticmethod + def getgrgid(_): # pragma: no cover + raise KeyError + + # stat_triplet imports pwd/grp inside the function, so patch sys.modules. + monkeypatch.setitem(sys.modules, "pwd", FakePwdMod) + monkeypatch.setitem(sys.modules, "grp", FakeGrpMod) + owner, group, mode = d.stat_triplet(str(p)) + assert owner.isdigit() + assert group.isdigit() + assert mode.isdigit() and len(mode) == 4 diff --git a/tests/test_diff_bundle.py b/tests/test_diff_bundle.py new file mode 100644 index 0000000..66ef094 --- /dev/null +++ b/tests/test_diff_bundle.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import os +import tarfile +from pathlib import Path + +import pytest + + +def _make_bundle_dir(tmp_path: Path) -> Path: + b = tmp_path / "bundle" + (b / "artifacts").mkdir(parents=True) + (b / "state.json").write_text("{}\n", encoding="utf-8") + return b + + +def _tar_gz_of_dir(src: Path, out: Path) -> None: + with tarfile.open(out, mode="w:gz") as tf: + # tar -C src . 
+        for p in src.rglob("*"):
+            rel = p.relative_to(src)
+            tf.add(p, arcname=str(rel))
+
+
+def test_bundle_from_directory_and_statejson_path(tmp_path: Path):
+    import enroll.diff as d
+
+    b = _make_bundle_dir(tmp_path)
+
+    br1 = d._bundle_from_input(str(b), sops_mode=False)
+    assert br1.dir == b
+    assert br1.state_path.exists()
+
+    br2 = d._bundle_from_input(str(b / "state.json"), sops_mode=False)
+    assert br2.dir == b
+
+
+def test_bundle_from_tarball_extracts(tmp_path: Path):
+    import enroll.diff as d
+
+    b = _make_bundle_dir(tmp_path)
+    tgz = tmp_path / "bundle.tgz"
+    _tar_gz_of_dir(b, tgz)
+
+    br = d._bundle_from_input(str(tgz), sops_mode=False)
+    try:
+        assert br.dir.is_dir()
+        assert (br.dir / "state.json").exists()
+    finally:
+        if br.tempdir:
+            br.tempdir.cleanup()
+
+
+def test_bundle_from_sops_like_file(monkeypatch, tmp_path: Path):
+    import enroll.diff as d
+
+    b = _make_bundle_dir(tmp_path)
+    tgz = tmp_path / "bundle.tar.gz"
+    _tar_gz_of_dir(b, tgz)
+
+    # Pretend the tarball is an encrypted bundle by giving it a .sops name.
+    sops_path = tmp_path / "bundle.tar.gz.sops"
+    sops_path.write_bytes(tgz.read_bytes())
+
+    # Stub out sops machinery: "decrypt" just copies through.
+    monkeypatch.setattr(d, "require_sops_cmd", lambda: "sops")
+
+    def fake_decrypt(src: Path, dest: Path, mode: int):
+        dest.write_bytes(Path(src).read_bytes())
+        try:
+            os.chmod(dest, mode)
+        except OSError:
+            pass
+
+    monkeypatch.setattr(d, "decrypt_file_binary_to", fake_decrypt)
+
+    br = d._bundle_from_input(str(sops_path), sops_mode=False)
+    try:
+        assert (br.dir / "state.json").exists()
+    finally:
+        if br.tempdir:
+            br.tempdir.cleanup()
+
+
+def test_bundle_from_input_missing_path(tmp_path: Path):
+    import enroll.diff as d
+
+    with pytest.raises(RuntimeError, match="not found"):
+        d._bundle_from_input(str(tmp_path / "nope"), sops_mode=False)
diff --git a/tests/test_pathfilter.py b/tests/test_pathfilter.py
new file mode 100644
index 0000000..406b7e7
--- /dev/null
+++ b/tests/test_pathfilter.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+
+def test_compile_and_match_prefix_glob_and_regex(tmp_path: Path):
+    from enroll.pathfilter import PathFilter, compile_path_pattern
+
+    # prefix semantics: matches the exact path and subtree
+    p = compile_path_pattern("/etc/nginx")
+    assert p.kind == "prefix"
+    assert p.matches("/etc/nginx")
+    assert p.matches("/etc/nginx/nginx.conf")
+    assert not p.matches("/etc/nginx2/nginx.conf")
+
+    # glob semantics
+    g = compile_path_pattern("/etc/**/*.conf")
+    assert g.kind == "glob"
+    assert g.matches("/etc/nginx/nginx.conf")
+    assert not g.matches("/var/etc/nginx.conf")
+
+    # explicit glob
+    g2 = compile_path_pattern("glob:/home/*/.bashrc")
+    assert g2.kind == "glob"
+    assert g2.matches("/home/alice/.bashrc")
+
+    # regex semantics (search, not match)
+    r = compile_path_pattern(r"re:/home/[^/]+/\.ssh/authorized_keys$")
+    assert r.kind == "regex"
+    assert r.matches("/home/alice/.ssh/authorized_keys")
+    assert not r.matches("/home/alice/.ssh/authorized_keys2")
+
+    # invalid regex: never matches
+    bad = compile_path_pattern("re:[")
+    assert bad.kind == "regex"
+    assert not bad.matches("/etc/passwd")
+
+    # exclude wins
+    pf = PathFilter(exclude=["/etc/nginx"], include=["/etc/nginx/nginx.conf"])
+    assert pf.is_excluded("/etc/nginx/nginx.conf")
+
+
+def test_expand_includes_respects_exclude_symlinks_and_caps(tmp_path: Path):
+    from enroll.pathfilter import PathFilter, compile_path_pattern, expand_includes
+
+    root = tmp_path / "root"
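+    # Layout: root/a holds two real files plus a symlink; root/b is excluded.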
+ (root / "a").mkdir(parents=True) + (root / "a" / "one.txt").write_text("1", encoding="utf-8") + (root / "a" / "two.txt").write_text("2", encoding="utf-8") + (root / "b").mkdir() + (root / "b" / "secret.txt").write_text("s", encoding="utf-8") + + # symlink file should be ignored + os.symlink(str(root / "a" / "one.txt"), str(root / "a" / "link.txt")) + + exclude = PathFilter(exclude=[str(root / "b")]) + + pats = [ + compile_path_pattern(str(root / "a")), + compile_path_pattern("glob:" + str(root / "**" / "*.txt")), + ] + + paths, notes = expand_includes(pats, exclude=exclude, max_files=2) + # cap should limit to 2 files + assert len(paths) == 2 + assert any("cap" in n.lower() for n in notes) + # excluded dir should not contribute + assert all("/b/" not in p for p in paths) + # symlink ignored + assert all(not p.endswith("link.txt") for p in paths) + + +def test_expand_includes_notes_on_no_matches(tmp_path: Path): + from enroll.pathfilter import compile_path_pattern, expand_includes + + pats = [compile_path_pattern(str(tmp_path / "does_not_exist"))] + paths, notes = expand_includes(pats, max_files=10) + assert paths == [] + assert any("matched no files" in n.lower() for n in notes) diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100644 index 0000000..576c0b1 --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import io +import tarfile +from pathlib import Path + +import pytest + + +def _make_tgz_bytes(files: dict[str, bytes]) -> bytes: + bio = io.BytesIO() + with tarfile.open(fileobj=bio, mode="w:gz") as tf: + for name, content in files.items(): + ti = tarfile.TarInfo(name=name) + ti.size = len(content) + tf.addfile(ti, io.BytesIO(content)) + return bio.getvalue() + + +def test_safe_extract_tar_rejects_path_traversal(tmp_path: Path): + from enroll.remote import _safe_extract_tar + + # Build an unsafe tar with ../ traversal + bio = io.BytesIO() + with tarfile.open(fileobj=bio, mode="w:gz") as tf: + ti = tarfile.TarInfo(name="../evil") + ti.size = 1 + tf.addfile(ti, io.BytesIO(b"x")) + + bio.seek(0) + with tarfile.open(fileobj=bio, mode="r:gz") as tf: + with pytest.raises(RuntimeError, match="Unsafe tar member path"): + _safe_extract_tar(tf, tmp_path) + + +def test_safe_extract_tar_rejects_symlinks(tmp_path: Path): + from enroll.remote import _safe_extract_tar + + bio = io.BytesIO() + with tarfile.open(fileobj=bio, mode="w:gz") as tf: + ti = tarfile.TarInfo(name="link") + ti.type = tarfile.SYMTYPE + ti.linkname = "/etc/passwd" + tf.addfile(ti) + + bio.seek(0) + with tarfile.open(fileobj=bio, mode="r:gz") as tf: + with pytest.raises(RuntimeError, match="Refusing to extract"): + _safe_extract_tar(tf, tmp_path) + + +def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): + import sys + + import enroll.remote as r + + # Avoid building a real zipapp; just create a file. + def fake_build(_td: Path) -> Path: + p = _td / "enroll.pyz" + p.write_bytes(b"PYZ") + return p + + monkeypatch.setattr(r, "_build_enroll_pyz", fake_build) + + # Prepare a tiny harvest bundle tar stream from the "remote". 
+    tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
+
+    calls: list[str] = []
+
+    class _Chan:
+        def __init__(self, rc: int = 0):
+            self._rc = rc
+
+        def recv_exit_status(self) -> int:
+            return self._rc
+
+    class _Stdout:
+        def __init__(self, payload: bytes = b"", rc: int = 0):
+            self._bio = io.BytesIO(payload)
+            self.channel = _Chan(rc)
+
+        def read(self, n: int = -1) -> bytes:
+            return self._bio.read(n)
+
+    class _Stderr:
+        def __init__(self, payload: bytes = b""):
+            self._bio = io.BytesIO(payload)
+
+        def read(self, n: int = -1) -> bytes:
+            return self._bio.read(n)
+
+    class _SFTP:
+        def __init__(self):
+            self.put_calls: list[tuple[str, str]] = []
+
+        def put(self, local: str, remote: str) -> None:
+            self.put_calls.append((local, remote))
+
+        def close(self) -> None:
+            return
+
+    class FakeSSH:
+        def __init__(self):
+            self._sftp = _SFTP()
+
+        def load_system_host_keys(self):
+            return
+
+        def set_missing_host_key_policy(self, _policy):
+            return
+
+        def connect(self, **kwargs):
+            # Accept any connect parameters.
+            return
+
+        def open_sftp(self):
+            return self._sftp
+
+        def exec_command(self, cmd: str):
+            calls.append(cmd)
+            # The tar stream uses exec_command directly.
+            if cmd.startswith("tar -cz -C"):
+                return (None, _Stdout(tgz, rc=0), _Stderr(b""))
+
+            # _ssh_run path: id -un, mktemp -d, chmod, sudo harvest, sudo chown, rm -rf
+            if cmd == "id -un":
+                return (None, _Stdout(b"alice\n"), _Stderr())
+            if cmd == "mktemp -d":
+                return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr())
+            if cmd.startswith("chmod 700"):
+                return (None, _Stdout(b""), _Stderr())
+            if " harvest " in cmd:
+                return (None, _Stdout(b""), _Stderr())
+            if cmd.startswith("sudo chown -R"):
+                return (None, _Stdout(b""), _Stderr())
+            if cmd.startswith("rm -rf"):
+                return (None, _Stdout(b""), _Stderr())
+
+            return (None, _Stdout(b""), _Stderr(b"unknown"))
+
+        def close(self):
+            return
+
+    import types
+
+    class RejectPolicy:
+        pass
+
+    FakeParamiko = types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy)
+
+    # Provide a fake paramiko module.
+    monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
+
+    out_dir = tmp_path / "out"
+    state_path = r.remote_harvest(
+        local_out_dir=out_dir,
+        remote_host="example.com",
+        remote_port=2222,
+        remote_user=None,
+        include_paths=["/etc/nginx/nginx.conf"],
+        exclude_paths=["/etc/shadow"],
+        dangerous=True,
+        no_sudo=False,
+    )
+
+    assert state_path == out_dir / "state.json"
+    assert state_path.exists()
+    assert b"ok" in state_path.read_bytes()
+
+    # Ensure we attempted remote harvest with sudo and passed include/exclude and dangerous.
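+    # (every command FakeSSH saw was appended to `calls` above.)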
+    joined = "\n".join(calls)
+    assert "sudo" in joined
+    assert "--dangerous" in joined
+    assert "--include-path" in joined
+    assert "--exclude-path" in joined
diff --git a/tests/test_systemd.py b/tests/test_systemd.py
new file mode 100644
index 0000000..f351159
--- /dev/null
+++ b/tests/test_systemd.py
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+import pytest
+
+
+def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
+    import enroll.systemd as s
+
+    def fake_run(cmd: list[str]) -> str:
+        if "--type=service" in cmd:
+            return "\n".join(
+                [
+                    "nginx.service enabled",
+                    "getty@.service enabled",  # template
+                    "foo@bar.service enabled",  # instance units are included
+                    "ssh.service enabled",
+                ]
+            )
+        if "--type=timer" in cmd:
+            return "\n".join(
+                [
+                    "apt-daily.timer enabled",
+                    "foo@.timer enabled",  # template
+                ]
+            )
+        raise AssertionError("unexpected")
+
+    monkeypatch.setattr(s, "_run", fake_run)
+    assert s.list_enabled_services() == [
+        "foo@bar.service",
+        "nginx.service",
+        "ssh.service",
+    ]
+    assert s.list_enabled_timers() == ["apt-daily.timer"]
+
+
+def test_get_unit_info_parses_fields(monkeypatch):
+    import enroll.systemd as s
+
+    class P:
+        def __init__(self, rc: int, out: str, err: str = ""):
+            self.returncode = rc
+            self.stdout = out
+            self.stderr = err
+
+    def fake_run(cmd, check, text, capture_output):
+        assert cmd[0:2] == ["systemctl", "show"]
+        return P(
+            0,
+            "\n".join(
+                [
+                    "FragmentPath=/lib/systemd/system/nginx.service",
+                    "DropInPaths=/etc/systemd/system/nginx.service.d/override.conf /etc/systemd/system/nginx.service.d/extra.conf",
+                    "EnvironmentFiles=-/etc/default/nginx /etc/nginx/env",
+                    "ExecStart={ path=/usr/sbin/nginx ; argv[]=/usr/sbin/nginx -g daemon off; }",
+                    "ActiveState=active",
+                    "SubState=running",
+                    "UnitFileState=enabled",
+                    "ConditionResult=yes",
+                ]
+            ),
+        )
+
+    monkeypatch.setattr(s.subprocess, "run", fake_run)
+    ui = s.get_unit_info("nginx.service")
+    assert ui.fragment_path == "/lib/systemd/system/nginx.service"
+    assert "/etc/default/nginx" in ui.env_files
+    assert "/etc/nginx/env" in ui.env_files
+    assert "/usr/sbin/nginx" in ui.exec_paths
+    assert ui.active_state == "active"
+
+
+def test_get_unit_info_raises_unit_query_error(monkeypatch):
+    import enroll.systemd as s
+
+    class P:
+        def __init__(self, rc: int, out: str, err: str):
+            self.returncode = rc
+            self.stdout = out
+            self.stderr = err
+
+    def fake_run(cmd, check, text, capture_output):
+        return P(1, "", "no such unit")
+
+    monkeypatch.setattr(s.subprocess, "run", fake_run)
+    with pytest.raises(s.UnitQueryError) as ei:
+        s.get_unit_info("missing.service")
+    assert "missing.service" in str(ei.value)
+    assert ei.value.unit == "missing.service"
+
+
+def test_get_timer_info_parses_fields(monkeypatch):
+    import enroll.systemd as s
+
+    class P:
+        def __init__(self, rc: int, out: str, err: str = ""):
+            self.returncode = rc
+            self.stdout = out
+            self.stderr = err
+
+    def fake_run(cmd, text, capture_output):
+        return P(
+            0,
+            "\n".join(
+                [
+                    "FragmentPath=/lib/systemd/system/apt-daily.timer",
+                    "DropInPaths=",
+                    "EnvironmentFiles=-/etc/default/apt",
+                    "Unit=apt-daily.service",
+                    "ActiveState=active",
+                    "SubState=waiting",
+                    "UnitFileState=enabled",
+                    "ConditionResult=yes",
+                ]
+            ),
+        )
+
+    monkeypatch.setattr(s.subprocess, "run", fake_run)
+    ti = s.get_timer_info("apt-daily.timer")
+    assert ti.trigger_unit == "apt-daily.service"
+    assert "/etc/default/apt" in ti.env_files