0.1.6
All checks were successful
CI / test (push) Successful in 5m24s
Lint / test (push) Successful in 30s
Trivy / test (push) Successful in 16s

This commit is contained in:
Miguel Jacq 2025-12-28 15:32:40 +11:00
parent 3fc5aec5fc
commit 921801caa6
Signed by: mig5
GPG key ID: 59B3F0C24135C6A9
15 changed files with 1102 additions and 423 deletions

18
tests/test___main__.py Normal file
View file

@@ -0,0 +1,18 @@
from __future__ import annotations
import runpy
def test_module_main_invokes_cli_main(monkeypatch):
    """`python -m enroll` should delegate straight to enroll.cli.main()."""
    import enroll.cli

    seen = {"ok": False}

    def stub_main() -> None:
        seen["ok"] = True

    monkeypatch.setattr(enroll.cli, "main", stub_main)
    # Execute enroll.__main__ as if `python -m enroll`.
    runpy.run_module("enroll.__main__", run_name="__main__")
    assert seen["ok"] is True

143
tests/test_accounts.py Normal file
View file

@@ -0,0 +1,143 @@
from __future__ import annotations
import os
from pathlib import Path
def test_parse_login_defs_parses_known_keys(tmp_path: Path):
    """Only the four UID_* keys with integer values are retained."""
    from enroll.accounts import parse_login_defs

    defs_file = tmp_path / "login.defs"
    defs_file.write_text(
        """
# comment
UID_MIN 1000
UID_MAX 60000
SYS_UID_MIN 100
SYS_UID_MAX 999
UID_MIN not_an_int
OTHER 123
""",
        encoding="utf-8",
    )

    vals = parse_login_defs(str(defs_file))
    # Known keys come back as ints; the later non-integer UID_MIN line
    # must not clobber the valid value.
    expected = {
        "UID_MIN": 1000,
        "UID_MAX": 60000,
        "SYS_UID_MIN": 100,
        "SYS_UID_MAX": 999,
    }
    for key, want in expected.items():
        assert vals[key] == want
    # Unrecognized keys are dropped.
    assert "OTHER" not in vals
def test_parse_passwd_and_group_and_ssh_files(tmp_path: Path):
    """Exercise passwd/group parsing plus per-user ssh file discovery."""
    from enroll.accounts import find_user_ssh_files, parse_group, parse_passwd

    passwd_file = tmp_path / "passwd"
    passwd_lines = [
        "root:x:0:0:root:/root:/bin/bash",
        "# comment",
        "alice:x:1000:1000:Alice:/home/alice:/bin/bash",
        "bob:x:1001:1000:Bob:/home/bob:/usr/sbin/nologin",
        "badline",
        "cathy:x:notint:1000:Cathy:/home/cathy:/bin/bash",
        "",
    ]
    passwd_file.write_text("\n".join(passwd_lines), encoding="utf-8")

    group_file = tmp_path / "group"
    group_lines = [
        "root:x:0:",
        "users:x:1000:alice,bob",
        "admins:x:1002:alice",
        "badgroup:x:notint:alice",
        "",
    ]
    group_file.write_text("\n".join(group_lines), encoding="utf-8")

    rows = parse_passwd(str(passwd_file))
    assert ("alice", 1000, 1000, "Alice", "/home/alice", "/bin/bash") in rows
    # cathy has a non-numeric UID and must be skipped.
    assert all(row[0] != "cathy" for row in rows)

    gid_to_name, name_to_gid, members = parse_group(str(group_file))
    assert gid_to_name[1000] == "users"
    assert name_to_gid["admins"] == 1002
    assert "alice" in members["admins"]

    # ssh discovery: only authorized_keys, no symlinks
    home = tmp_path / "home" / "alice"
    ssh_dir = home / ".ssh"
    ssh_dir.mkdir(parents=True)
    auth_keys = ssh_dir / "authorized_keys"
    auth_keys.write_text("ssh-ed25519 AAA...", encoding="utf-8")
    # Neither the extra regular file nor the symlink pointing at it
    # should appear in the result.
    (ssh_dir / "authorized_keys2").write_text("x", encoding="utf-8")
    os.symlink(str(ssh_dir / "authorized_keys2"), str(ssh_dir / "authorized_keys_link"))
    assert find_user_ssh_files(str(home)) == [str(auth_keys)]
def test_collect_non_system_users(monkeypatch, tmp_path: Path):
    """Only non-system users should be collected; here that is alice alone."""
    import enroll.accounts as accounts

    real_login_defs = accounts.parse_login_defs
    real_passwd = accounts.parse_passwd
    real_group = accounts.parse_group

    # Provide controlled passwd/group/login.defs inputs via monkeypatch.
    passwd_file = tmp_path / "passwd"
    passwd_file.write_text(
        "\n".join(
            [
                "root:x:0:0:root:/root:/bin/bash",
                "nobody:x:65534:65534:nobody:/nonexistent:/usr/sbin/nologin",
                "alice:x:1000:1000:Alice:/home/alice:/bin/bash",
                "sysuser:x:200:200:Sys:/home/sys:/bin/bash",
                "bob:x:1001:1000:Bob:/home/bob:/bin/false",
                "",
            ]
        ),
        encoding="utf-8",
    )

    group_file = tmp_path / "group"
    group_file.write_text(
        "\n".join(
            [
                "users:x:1000:alice,bob",
                "admins:x:1002:alice",
                "",
            ]
        ),
        encoding="utf-8",
    )

    defs_file = tmp_path / "login.defs"
    defs_file.write_text("UID_MIN 1000\n", encoding="utf-8")

    # Redirect the parsers' default paths to the fixture files. The
    # default-argument trick binds each fixture path at lambda definition time.
    monkeypatch.setattr(
        accounts, "parse_login_defs", lambda path=str(defs_file): real_login_defs(path)
    )
    monkeypatch.setattr(
        accounts, "parse_passwd", lambda path=str(passwd_file): real_passwd(path)
    )
    monkeypatch.setattr(
        accounts, "parse_group", lambda path=str(group_file): real_group(path)
    )
    # Use a stable fake ssh discovery.
    monkeypatch.setattr(
        accounts, "find_user_ssh_files", lambda home: [f"{home}/.ssh/authorized_keys"]
    )

    users = accounts.collect_non_system_users()
    # Only alice survives the UID/shell filtering.
    assert [u.name for u in users] == ["alice"]
    alice = users[0]
    assert alice.primary_group == "users"
    assert alice.supplementary_groups == ["admins"]
    assert alice.ssh_files == ["/home/alice/.ssh/authorized_keys"]

154
tests/test_debian.py Normal file
View file

@@ -0,0 +1,154 @@
from __future__ import annotations
import hashlib
from pathlib import Path
def test_dpkg_owner_parses_output(monkeypatch):
    """dpkg_owner extracts the owning package from `dpkg -S` output."""
    import enroll.debian as d

    class FakeProc:
        def __init__(self, rc: int, out: str):
            self.returncode = rc
            self.stdout = out
            self.stderr = ""

    # Parameter names mirror subprocess.run's keyword arguments.
    def fake_run(cmd, text, capture_output):
        assert cmd[:2] == ["dpkg", "-S"]
        return FakeProc(
            0,
            """
diversion by foo from: /etc/something
nginx-common:amd64: /etc/nginx/nginx.conf
nginx-common, nginx: /etc/nginx/sites-enabled/default
""",
        )

    monkeypatch.setattr(d.subprocess, "run", fake_run)
    # Diversion lines are skipped and the :amd64 qualifier is stripped.
    assert d.dpkg_owner("/etc/nginx/nginx.conf") == "nginx-common"

    def fake_run_missing(cmd, text, capture_output):
        return FakeProc(1, "")

    monkeypatch.setattr(d.subprocess, "run", fake_run_missing)
    # A non-zero dpkg exit means no owner is known.
    assert d.dpkg_owner("/missing") is None
def test_list_manual_packages_parses_and_sorts(monkeypatch):
    """list_manual_packages drops comments/blanks, dedupes, and sorts."""
    import enroll.debian as d

    class FakeProc:
        def __init__(self, rc: int, out: str):
            self.returncode = rc
            self.stdout = out
            self.stderr = ""

    def fake_run(cmd, text, capture_output):
        assert cmd == ["apt-mark", "showmanual"]
        return FakeProc(0, "\n# comment\nnginx\nvim\nnginx\n")

    monkeypatch.setattr(d.subprocess, "run", fake_run)
    # Duplicate "nginx" collapses to one entry; output is sorted.
    assert d.list_manual_packages() == ["nginx", "vim"]
def test_build_dpkg_etc_index(tmp_path: Path):
    """Index dpkg .list manifests, keeping only /etc paths per package."""
    import enroll.debian as d

    info_dir = tmp_path / "info"
    info_dir.mkdir()
    (info_dir / "nginx.list").write_text(
        "/etc/nginx/nginx.conf\n/etc/nginx/sites-enabled/default\n/usr/bin/nginx\n",
        encoding="utf-8",
    )
    # An arch-qualified manifest name (vim:amd64) maps to the bare package.
    (info_dir / "vim:amd64.list").write_text(
        "/etc/vim/vimrc\n/usr/bin/vim\n",
        encoding="utf-8",
    )

    owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info_dir))
    assert "/etc/nginx/nginx.conf" in owned
    assert owner_map["/etc/nginx/nginx.conf"] == "nginx"
    assert "nginx" in topdir_to_pkgs
    assert topdir_to_pkgs["nginx"] == {"nginx"}
    # Non-/etc paths (e.g. /usr/bin/vim) are excluded from the listing.
    assert pkg_to_etc["vim"] == ["/etc/vim/vimrc"]
def test_parse_status_conffiles_handles_continuations(tmp_path: Path):
    """Conffiles continuation lines attach to the preceding package stanza."""
    import enroll.debian as d

    status_file = tmp_path / "status"
    stanza_lines = [
        "Package: nginx",
        "Version: 1",
        "Conffiles:",
        " /etc/nginx/nginx.conf abcdef",
        " /etc/nginx/mime.types 123456",
        "",
        "Package: other",
        "Version: 2",
        "",
    ]
    status_file.write_text("\n".join(stanza_lines), encoding="utf-8")

    conffiles = d.parse_status_conffiles(str(status_file))
    assert conffiles["nginx"]["/etc/nginx/nginx.conf"] == "abcdef"
    assert conffiles["nginx"]["/etc/nginx/mime.types"] == "123456"
    # Packages without a Conffiles section are absent from the map.
    assert "other" not in conffiles
def test_read_pkg_md5sums_and_file_md5(tmp_path: Path, monkeypatch):
    """read_pkg_md5sums parses dpkg md5sum manifests; file_md5 hashes bytes."""
    import enroll.debian as d

    # Patch /var/lib/dpkg/info/<pkg>.md5sums lookup to a tmp file.
    md5_manifest = tmp_path / "pkg.md5sums"
    md5_manifest.write_text("0123456789abcdef etc/foo.conf\n", encoding="utf-8")

    def fake_exists(path: str) -> bool:
        return path.endswith("/var/lib/dpkg/info/p1.md5sums")

    real_open = open

    def fake_open(path: str, *args, **kwargs):
        # Serve the fixture whenever the real dpkg manifest path is opened.
        if path.endswith("/var/lib/dpkg/info/p1.md5sums"):
            return real_open(md5_manifest, *args, **kwargs)
        return real_open(path, *args, **kwargs)

    monkeypatch.setattr(d.os.path, "exists", fake_exists)
    monkeypatch.setattr("builtins.open", fake_open)

    assert d.read_pkg_md5sums("p1") == {"etc/foo.conf": "0123456789abcdef"}

    payload = b"hello world\n"
    sample = tmp_path / "x"
    sample.write_bytes(payload)
    assert d.file_md5(str(sample)) == hashlib.md5(payload).hexdigest()
def test_stat_triplet_fallbacks(tmp_path: Path, monkeypatch):
    """When uid/gid name lookups fail, stat_triplet returns numeric strings."""
    import sys

    import enroll.debian as d

    target = tmp_path / "f"
    target.write_text("x", encoding="utf-8")

    class FakePwdMod:
        @staticmethod
        def getpwuid(_):  # pragma: no cover
            raise KeyError

    class FakeGrpMod:
        @staticmethod
        def getgrgid(_):  # pragma: no cover
            raise KeyError

    # stat_triplet imports pwd/grp inside the function, so patch sys.modules.
    monkeypatch.setitem(sys.modules, "pwd", FakePwdMod)
    monkeypatch.setitem(sys.modules, "grp", FakeGrpMod)

    owner, group, mode = d.stat_triplet(str(target))
    assert owner.isdigit()
    assert group.isdigit()
    # Mode is rendered as a 4-digit string of digits.
    assert mode.isdigit() and len(mode) == 4

89
tests/test_diff_bundle.py Normal file
View file

@@ -0,0 +1,89 @@
from __future__ import annotations
import os
import tarfile
from pathlib import Path
import pytest
def _make_bundle_dir(tmp_path: Path) -> Path:
b = tmp_path / "bundle"
(b / "artifacts").mkdir(parents=True)
(b / "state.json").write_text("{}\n", encoding="utf-8")
return b
def _tar_gz_of_dir(src: Path, out: Path) -> None:
with tarfile.open(out, mode="w:gz") as tf:
# tar -C src . semantics
for p in src.rglob("*"):
rel = p.relative_to(src)
tf.add(p, arcname=str(rel))
def test_bundle_from_directory_and_statejson_path(tmp_path: Path):
    """A bundle dir and its state.json path resolve to the same bundle."""
    import enroll.diff as d

    bundle_dir = _make_bundle_dir(tmp_path)
    from_dir = d._bundle_from_input(str(bundle_dir), sops_mode=False)
    assert from_dir.dir == bundle_dir
    assert from_dir.state_path.exists()
    # Pointing directly at state.json resolves to its containing directory.
    from_state = d._bundle_from_input(str(bundle_dir / "state.json"), sops_mode=False)
    assert from_state.dir == bundle_dir
def test_bundle_from_tarball_extracts(tmp_path: Path):
    """A .tgz input is extracted into a temporary bundle directory."""
    import enroll.diff as d

    bundle_dir = _make_bundle_dir(tmp_path)
    tarball = tmp_path / "bundle.tgz"
    _tar_gz_of_dir(bundle_dir, tarball)

    result = d._bundle_from_input(str(tarball), sops_mode=False)
    try:
        assert result.dir.is_dir()
        assert (result.dir / "state.json").exists()
    finally:
        # The extraction tempdir is the caller's to clean up.
        if result.tempdir:
            result.tempdir.cleanup()
def test_bundle_from_sops_like_file(monkeypatch, tmp_path: Path):
    """A .sops-suffixed bundle is decrypted (stubbed here) then extracted."""
    import enroll.diff as d

    bundle_dir = _make_bundle_dir(tmp_path)
    tarball = tmp_path / "bundle.tar.gz"
    _tar_gz_of_dir(bundle_dir, tarball)
    # Pretend the tarball is an encrypted bundle by giving it a .sops name.
    sops_path = tmp_path / "bundle.tar.gz.sops"
    sops_path.write_bytes(tarball.read_bytes())

    # Stub out sops machinery: "decrypt" just copies through.
    monkeypatch.setattr(d, "require_sops_cmd", lambda: "sops")

    def fake_decrypt(src: Path, dest: Path, mode: int):
        dest.write_bytes(Path(src).read_bytes())
        try:
            os.chmod(dest, mode)
        except OSError:
            pass

    monkeypatch.setattr(d, "decrypt_file_binary_to", fake_decrypt)

    result = d._bundle_from_input(str(sops_path), sops_mode=False)
    try:
        assert (result.dir / "state.json").exists()
    finally:
        if result.tempdir:
            result.tempdir.cleanup()
def test_bundle_from_input_missing_path(tmp_path: Path):
    """A nonexistent input path raises RuntimeError mentioning 'not found'."""
    import enroll.diff as d

    missing = tmp_path / "nope"
    with pytest.raises(RuntimeError, match="not found"):
        d._bundle_from_input(str(missing), sops_mode=False)

80
tests/test_pathfilter.py Normal file
View file

@@ -0,0 +1,80 @@
from __future__ import annotations
import os
from pathlib import Path
def test_compile_and_match_prefix_glob_and_regex(tmp_path: Path):
    """compile_path_pattern dispatches to prefix/glob/regex semantics."""
    from enroll.pathfilter import PathFilter, compile_path_pattern

    # prefix semantics: matches the exact path and subtree
    prefix_pat = compile_path_pattern("/etc/nginx")
    assert prefix_pat.kind == "prefix"
    assert prefix_pat.matches("/etc/nginx")
    assert prefix_pat.matches("/etc/nginx/nginx.conf")
    # A sibling whose name merely shares the prefix string is not a match.
    assert not prefix_pat.matches("/etc/nginx2/nginx.conf")

    # glob semantics
    glob_pat = compile_path_pattern("/etc/**/*.conf")
    assert glob_pat.kind == "glob"
    assert glob_pat.matches("/etc/nginx/nginx.conf")
    assert not glob_pat.matches("/var/etc/nginx.conf")

    # explicit glob
    explicit_glob = compile_path_pattern("glob:/home/*/.bashrc")
    assert explicit_glob.kind == "glob"
    assert explicit_glob.matches("/home/alice/.bashrc")

    # regex semantics (search, not match)
    regex_pat = compile_path_pattern(r"re:/home/[^/]+/\.ssh/authorized_keys$")
    assert regex_pat.kind == "regex"
    assert regex_pat.matches("/home/alice/.ssh/authorized_keys")
    assert not regex_pat.matches("/home/alice/.ssh/authorized_keys2")

    # invalid regex: never matches
    broken_pat = compile_path_pattern("re:[")
    assert broken_pat.kind == "regex"
    assert not broken_pat.matches("/etc/passwd")

    # exclude wins
    pf = PathFilter(exclude=["/etc/nginx"], include=["/etc/nginx/nginx.conf"])
    assert pf.is_excluded("/etc/nginx/nginx.conf")
def test_expand_includes_respects_exclude_symlinks_and_caps(tmp_path: Path):
    """expand_includes honours excludes, skips symlinks, and caps results."""
    from enroll.pathfilter import PathFilter, compile_path_pattern, expand_includes

    root = tmp_path / "root"
    (root / "a").mkdir(parents=True)
    (root / "a" / "one.txt").write_text("1", encoding="utf-8")
    (root / "a" / "two.txt").write_text("2", encoding="utf-8")
    (root / "b").mkdir()
    (root / "b" / "secret.txt").write_text("s", encoding="utf-8")
    # symlink file should be ignored
    os.symlink(str(root / "a" / "one.txt"), str(root / "a" / "link.txt"))

    exclude = PathFilter(exclude=[str(root / "b")])
    patterns = [
        compile_path_pattern(str(root / "a")),
        compile_path_pattern("glob:" + str(root / "**" / "*.txt")),
    ]
    paths, notes = expand_includes(patterns, exclude=exclude, max_files=2)

    # cap should limit to 2 files and be reported in the notes
    assert len(paths) == 2
    assert any("cap" in note.lower() for note in notes)
    # excluded dir should not contribute
    assert all("/b/" not in path for path in paths)
    # symlink ignored
    assert all(not path.endswith("link.txt") for path in paths)
def test_expand_includes_notes_on_no_matches(tmp_path: Path):
    """A pattern matching nothing yields no paths plus an explanatory note."""
    from enroll.pathfilter import compile_path_pattern, expand_includes

    patterns = [compile_path_pattern(str(tmp_path / "does_not_exist"))]
    paths, notes = expand_includes(patterns, max_files=10)
    assert paths == []
    assert any("matched no files" in note.lower() for note in notes)

175
tests/test_remote.py Normal file
View file

@@ -0,0 +1,175 @@
from __future__ import annotations
import io
import tarfile
from pathlib import Path
import pytest
def _make_tgz_bytes(files: dict[str, bytes]) -> bytes:
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
for name, content in files.items():
ti = tarfile.TarInfo(name=name)
ti.size = len(content)
tf.addfile(ti, io.BytesIO(content))
return bio.getvalue()
def test_safe_extract_tar_rejects_path_traversal(tmp_path: Path):
    """Members escaping the destination via ../ must be refused."""
    from enroll.remote import _safe_extract_tar

    # Build an unsafe tar with ../ traversal
    raw = io.BytesIO()
    with tarfile.open(fileobj=raw, mode="w:gz") as tf:
        evil = tarfile.TarInfo(name="../evil")
        evil.size = 1
        tf.addfile(evil, io.BytesIO(b"x"))
    raw.seek(0)

    with tarfile.open(fileobj=raw, mode="r:gz") as tf:
        with pytest.raises(RuntimeError, match="Unsafe tar member path"):
            _safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_rejects_symlinks(tmp_path: Path):
    """Symlink members must be refused outright."""
    from enroll.remote import _safe_extract_tar

    raw = io.BytesIO()
    with tarfile.open(fileobj=raw, mode="w:gz") as tf:
        link = tarfile.TarInfo(name="link")
        link.type = tarfile.SYMTYPE
        link.linkname = "/etc/passwd"
        tf.addfile(link)
    raw.seek(0)

    with tarfile.open(fileobj=raw, mode="r:gz") as tf:
        with pytest.raises(RuntimeError, match="Refusing to extract"):
            _safe_extract_tar(tf, tmp_path)
def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
    """End-to-end remote_harvest against a fully faked paramiko SSH stack.

    Checks that the harvested state.json lands in local_out_dir and that the
    recorded remote command stream contains sudo plus the include/exclude
    and --dangerous flags.
    """
    import sys
    import enroll.remote as r

    # Avoid building a real zipapp; just create a file.
    def fake_build(_td: Path) -> Path:
        p = _td / "enroll.pyz"
        p.write_bytes(b"PYZ")
        return p

    monkeypatch.setattr(r, "_build_enroll_pyz", fake_build)
    # Prepare a tiny harvest bundle tar stream from the "remote".
    tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
    calls: list[str] = []  # every exec_command string, in invocation order

    class _Chan:
        # Minimal stand-in for the channel object: only exit status is read.
        def __init__(self, rc: int = 0):
            self._rc = rc

        def recv_exit_status(self) -> int:
            return self._rc

    class _Stdout:
        # File-like stdout carrying a payload and an exit-status channel.
        def __init__(self, payload: bytes = b"", rc: int = 0):
            self._bio = io.BytesIO(payload)
            self.channel = _Chan(rc)

        def read(self, n: int = -1) -> bytes:
            return self._bio.read(n)

    class _Stderr:
        def __init__(self, payload: bytes = b""):
            self._bio = io.BytesIO(payload)

        def read(self, n: int = -1) -> bytes:
            return self._bio.read(n)

    class _SFTP:
        # Records uploads instead of transferring anything.
        def __init__(self):
            self.put_calls: list[tuple[str, str]] = []

        def put(self, local: str, remote: str) -> None:
            self.put_calls.append((local, remote))

        def close(self) -> None:
            return

    class FakeSSH:
        def __init__(self):
            self._sftp = _SFTP()

        def load_system_host_keys(self):
            return

        def set_missing_host_key_policy(self, _policy):
            return

        def connect(self, **kwargs):
            # Accept any connect parameters.
            return

        def open_sftp(self):
            return self._sftp

        def exec_command(self, cmd: str):
            # Dispatch canned responses keyed on the command text.
            calls.append(cmd)
            # The tar stream uses exec_command directly.
            if cmd.startswith("tar -cz -C"):
                return (None, _Stdout(tgz, rc=0), _Stderr(b""))
            # _ssh_run path: id -un, mktemp -d, chmod, sudo harvest, sudo chown, rm -rf
            if cmd == "id -un":
                return (None, _Stdout(b"alice\n"), _Stderr())
            if cmd == "mktemp -d":
                return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr())
            if cmd.startswith("chmod 700"):
                return (None, _Stdout(b""), _Stderr())
            if " harvest " in cmd:
                return (None, _Stdout(b""), _Stderr())
            if cmd.startswith("sudo chown -R"):
                return (None, _Stdout(b""), _Stderr())
            if cmd.startswith("rm -rf"):
                return (None, _Stdout(b""), _Stderr())
            return (None, _Stdout(b""), _Stderr(b"unknown"))

        def close(self):
            return

    import types

    class RejectPolicy:
        pass

    FakeParamiko = types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy)
    # Provide a fake paramiko module.
    monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)

    out_dir = tmp_path / "out"
    state_path = r.remote_harvest(
        local_out_dir=out_dir,
        remote_host="example.com",
        remote_port=2222,
        remote_user=None,
        include_paths=["/etc/nginx/nginx.conf"],
        exclude_paths=["/etc/shadow"],
        dangerous=True,
        no_sudo=False,
    )
    assert state_path == out_dir / "state.json"
    assert state_path.exists()
    assert b"ok" in state_path.read_bytes()
    # Ensure we attempted remote harvest with sudo and passed include/exclude and dangerous.
    joined = "\n".join(calls)
    assert "sudo" in joined
    assert "--dangerous" in joined
    assert "--include-path" in joined
    assert "--exclude-path" in joined

121
tests/test_systemd.py Normal file
View file

@@ -0,0 +1,121 @@
from __future__ import annotations
import pytest
def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
    """Bare template units (name@.service/.timer) are filtered from listings."""
    import enroll.systemd as s

    service_listing = (
        "nginx.service enabled\n"
        "getty@.service enabled\n"   # bare template: filtered out
        "foo@bar.service enabled\n"  # concrete instance: kept
        "ssh.service enabled"
    )
    timer_listing = "apt-daily.timer enabled\nfoo@.timer enabled"

    def fake_run(cmd: list[str]) -> str:
        if "--type=service" in cmd:
            return service_listing
        if "--type=timer" in cmd:
            return timer_listing
        raise AssertionError("unexpected")

    monkeypatch.setattr(s, "_run", fake_run)
    assert s.list_enabled_services() == [
        "foo@bar.service",
        "nginx.service",
        "ssh.service",
    ]
    assert s.list_enabled_timers() == ["apt-daily.timer"]
def test_get_unit_info_parses_fields(monkeypatch):
    """get_unit_info parses `systemctl show` key=value output."""
    import enroll.systemd as s

    class FakeProc:
        def __init__(self, rc: int, out: str, err: str = ""):
            self.returncode = rc
            self.stdout = out
            self.stderr = err

    show_output = "\n".join(
        [
            "FragmentPath=/lib/systemd/system/nginx.service",
            "DropInPaths=/etc/systemd/system/nginx.service.d/override.conf /etc/systemd/system/nginx.service.d/extra.conf",
            "EnvironmentFiles=-/etc/default/nginx /etc/nginx/env",
            "ExecStart={ path=/usr/sbin/nginx ; argv[]=/usr/sbin/nginx -g daemon off; }",
            "ActiveState=active",
            "SubState=running",
            "UnitFileState=enabled",
            "ConditionResult=yes",
        ]
    )

    def fake_run(cmd, check, text, capture_output):
        assert cmd[0:2] == ["systemctl", "show"]
        return FakeProc(0, show_output)

    monkeypatch.setattr(s.subprocess, "run", fake_run)
    unit = s.get_unit_info("nginx.service")
    assert unit.fragment_path == "/lib/systemd/system/nginx.service"
    # Both EnvironmentFiles entries are reported (the "-" optional marker
    # on the first entry notwithstanding).
    assert "/etc/default/nginx" in unit.env_files
    assert "/etc/nginx/env" in unit.env_files
    assert "/usr/sbin/nginx" in unit.exec_paths
    assert unit.active_state == "active"
def test_get_unit_info_raises_unit_query_error(monkeypatch):
    """A failing `systemctl show` surfaces as UnitQueryError naming the unit."""
    import enroll.systemd as s

    class FakeProc:
        def __init__(self, rc: int, out: str, err: str):
            self.returncode = rc
            self.stdout = out
            self.stderr = err

    def fake_run(cmd, check, text, capture_output):
        return FakeProc(1, "", "no such unit")

    monkeypatch.setattr(s.subprocess, "run", fake_run)
    with pytest.raises(s.UnitQueryError) as excinfo:
        s.get_unit_info("missing.service")
    assert "missing.service" in str(excinfo.value)
    assert excinfo.value.unit == "missing.service"
def test_get_timer_info_parses_fields(monkeypatch):
    """get_timer_info extracts the triggered unit and environment files."""
    import enroll.systemd as s

    class FakeProc:
        def __init__(self, rc: int, out: str, err: str = ""):
            self.returncode = rc
            self.stdout = out
            self.stderr = err

    timer_output = "\n".join(
        [
            "FragmentPath=/lib/systemd/system/apt-daily.timer",
            "DropInPaths=",
            "EnvironmentFiles=-/etc/default/apt",
            "Unit=apt-daily.service",
            "ActiveState=active",
            "SubState=waiting",
            "UnitFileState=enabled",
            "ConditionResult=yes",
        ]
    )

    def fake_run(cmd, text, capture_output):
        return FakeProc(0, timer_output)

    monkeypatch.setattr(s.subprocess, "run", fake_run)
    timer = s.get_timer_info("apt-daily.timer")
    # The Unit= field names the service the timer triggers.
    assert timer.trigger_unit == "apt-daily.service"
    assert "/etc/default/apt" in timer.env_files