More test coverage (71%)

Miguel Jacq 2026-01-03 12:34:39 +11:00
parent 9a2516d858
commit f82fd894ca
Signed by: mig5
GPG key ID: 59B3F0C24135C6A9
8 changed files with 605 additions and 10 deletions

@@ -0,0 +1,33 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
def test_ensure_dir_secure_refuses_symlink(tmp_path: Path):
from enroll.cache import _ensure_dir_secure
target = tmp_path / "target"
target.mkdir()
link = tmp_path / "link"
link.symlink_to(target, target_is_directory=True)
with pytest.raises(RuntimeError):
_ensure_dir_secure(link)
def test_ensure_dir_secure_ignores_chmod_failures(tmp_path: Path, monkeypatch):
from enroll.cache import _ensure_dir_secure
d = tmp_path / "d"
def boom(_path: str, _mode: int):
raise OSError("no")
monkeypatch.setattr(os, "chmod", boom)
# Should not raise.
_ensure_dir_secure(d)
assert d.exists() and d.is_dir()
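
For context, roughly what enroll.cache._ensure_dir_secure has to do to satisfy these two tests: refuse symlinked paths, create the directory with tight permissions, and treat chmod failures as best-effort. A minimal sketch, not the implementation in this commit:

import os
from pathlib import Path

def _ensure_dir_secure(path: Path) -> None:
    # Refuse to operate through a symlink that could redirect the cache dir.
    if path.is_symlink():
        raise RuntimeError(f"refusing symlinked directory: {path}")
    path.mkdir(mode=0o700, parents=True, exist_ok=True)
    try:
        os.chmod(path, 0o700)  # best effort; failures are ignored
    except OSError:
        pass
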

tests/test_cli_helpers.py (new file, 177 lines added)
@@ -0,0 +1,177 @@
from __future__ import annotations
import argparse
import configparser
import types
import textwrap
from pathlib import Path
def test_discover_config_path_precedence(tmp_path: Path, monkeypatch):
"""_discover_config_path: --config > ENROLL_CONFIG > ./enroll.ini > XDG."""
from enroll.cli import _discover_config_path
cfg1 = tmp_path / "one.ini"
cfg1.write_text("[enroll]\n", encoding="utf-8")
# Explicit --config should win.
assert _discover_config_path(["--config", str(cfg1)]) == cfg1
# --no-config disables config loading.
assert _discover_config_path(["--no-config", "--config", str(cfg1)]) is None
monkeypatch.chdir(tmp_path)
cfg2 = tmp_path / "two.ini"
cfg2.write_text("[enroll]\n", encoding="utf-8")
monkeypatch.setenv("ENROLL_CONFIG", str(cfg2))
assert _discover_config_path([]) == cfg2
# Local ./enroll.ini fallback.
monkeypatch.delenv("ENROLL_CONFIG", raising=False)
local = tmp_path / "enroll.ini"
local.write_text("[enroll]\n", encoding="utf-8")
assert _discover_config_path([]) == local
# XDG fallback.
local.unlink()
xdg = tmp_path / "xdg"
cfg3 = xdg / "enroll" / "enroll.ini"
cfg3.parent.mkdir(parents=True)
cfg3.write_text("[enroll]\n", encoding="utf-8")
monkeypatch.setenv("XDG_CONFIG_HOME", str(xdg))
assert _discover_config_path([]) == cfg3
def test_config_value_parsing_and_list_splitting():
from enroll.cli import _parse_bool, _split_list_value
assert _parse_bool("1") is True
assert _parse_bool("yes") is True
assert _parse_bool("false") is False
assert _parse_bool("maybe") is None
assert _split_list_value("a,b , c") == ["a", "b", "c"]
# When newlines are present, we split on lines (not commas within a line).
assert _split_list_value("a,b\nc") == ["a,b", "c"]
assert _split_list_value("a\n\n b\n") == ["a", "b"]
assert _split_list_value(" ") == []
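
The semantics these assertions pin down could be implemented roughly as below; a sketch under those assumptions, not the actual enroll.cli helpers:

def _parse_bool(value: str) -> bool | None:
    v = value.strip().lower()
    if v in ("1", "true", "yes", "on"):
        return True
    if v in ("0", "false", "no", "off"):
        return False
    return None  # unrecognised values map to "no opinion"

def _split_list_value(value: str) -> list[str]:
    # Newlines take precedence over commas: one item per non-empty line.
    if "\n" in value:
        return [line.strip() for line in value.splitlines() if line.strip()]
    return [item.strip() for item in value.split(",") if item.strip()]
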
def test_section_to_argv_handles_types_and_unknown_keys(capsys):
from enroll.cli import _section_to_argv
p = argparse.ArgumentParser(add_help=False)
p.add_argument("--dangerous", action="store_true")
p.add_argument("--no-color", dest="color", action="store_false")
p.add_argument("--include-path", dest="include_path", action="append")
p.add_argument("-v", action="count", default=0)
p.add_argument("--out")
cfg = configparser.ConfigParser()
cfg.read_dict(
{
"harvest": {
"dangerous": "true",
# Keys are matched by argparse dest; store_false actions still use dest.
"color": "false",
"include-path": "a,b,c",
"v": "2",
"out": "/tmp/bundle",
"unknown": "ignored",
}
}
)
argv = _section_to_argv(p, cfg, "harvest")
# Boolean store_true.
assert "--dangerous" in argv
# Boolean store_false: include the flag only when config wants False.
assert "--no-color" in argv
# Append: split lists and add one flag per item.
assert argv.count("--include-path") == 3
assert "a" in argv and "b" in argv and "c" in argv
# Count: repeats.
assert argv.count("-v") == 2
# Scalar.
assert "--out" in argv and "/tmp/bundle" in argv
err = capsys.readouterr().err
assert "unknown option" in err
def test_inject_config_argv_inserts_global_and_subcommand(tmp_path: Path, capsys):
from enroll.cli import _inject_config_argv
cfg = tmp_path / "enroll.ini"
cfg.write_text(
textwrap.dedent(
"""
[enroll]
dangerous = true
[harvest]
include-path = /etc/foo
unknown = 1
"""
).strip()
+ "\n",
encoding="utf-8",
)
root = argparse.ArgumentParser(add_help=False)
root.add_argument("--dangerous", action="store_true")
harvest_p = argparse.ArgumentParser(add_help=False)
harvest_p.add_argument("--include-path", dest="include_path", action="append")
argv = _inject_config_argv(
["harvest", "--out", "x"],
cfg_path=cfg,
root_parser=root,
subparsers={"harvest": harvest_p},
)
# Global tokens should appear before the subcommand.
assert argv[0] == "--dangerous"
assert argv[1] == "harvest"
# Subcommand tokens should appear right after the subcommand.
assert argv[2:4] == ["--include-path", "/etc/foo"]
# Unknown option should have produced a warning.
assert "unknown option" in capsys.readouterr().err
def test_resolve_sops_out_file(tmp_path: Path, monkeypatch):
from enroll import cli
# Make a predictable cache dir for the default case.
fake_cache = types.SimpleNamespace(dir=tmp_path / "cache")
fake_cache.dir.mkdir(parents=True)
monkeypatch.setattr(cli, "new_harvest_cache_dir", lambda hint=None: fake_cache)
# If out is a directory, use it directly.
out_dir = tmp_path / "out"
out_dir.mkdir()
# The output filename is fixed; hint is only used when creating a cache dir.
assert (
cli._resolve_sops_out_file(out=out_dir, hint="bundle.tar.gz")
== out_dir / "harvest.tar.gz.sops"
)
# If out is a file path, keep it.
out_file = tmp_path / "x.sops"
assert cli._resolve_sops_out_file(out=out_file, hint="bundle.tar.gz") == out_file
# None uses the cache dir, and the name is fixed.
assert (
cli._resolve_sops_out_file(out=None, hint="bundle.tar.gz")
== fake_cache.dir / "harvest.tar.gz.sops"
)
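
The branching these assertions describe amounts to something like this sketch; new_harvest_cache_dir is assumed to live in enroll.cache, since the test patches it on the cli module:

from pathlib import Path

from enroll.cache import new_harvest_cache_dir  # assumed import location

def _resolve_sops_out_file(out: Path | None, hint: str | None = None) -> Path:
    name = "harvest.tar.gz.sops"  # output filename is fixed
    if out is None:
        # hint only influences the freshly created cache directory
        return new_harvest_cache_dir(hint=hint).dir / name
    if out.is_dir():
        return out / name
    return out  # anything else is treated as an explicit file path
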

@@ -0,0 +1,83 @@
from __future__ import annotations
import types
import pytest
def test_post_webhook_success(monkeypatch):
from enroll.diff import post_webhook
class FakeResp:
status = 204
def read(self):
return b"OK"
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(
"enroll.diff.urllib.request.urlopen",
lambda req, timeout=30: FakeResp(),
)
status, body = post_webhook("https://example.com", b"x")
assert status == 204
assert body == "OK"
def test_post_webhook_raises_on_error(monkeypatch):
from enroll.diff import post_webhook
def boom(_req, timeout=30):
raise OSError("nope")
monkeypatch.setattr("enroll.diff.urllib.request.urlopen", boom)
with pytest.raises(RuntimeError):
post_webhook("https://example.com", b"x")
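
A plausible shape for enroll.diff.post_webhook given these two tests: POST the payload, return the status and decoded body, and wrap transport errors in RuntimeError. A sketch only; the header and error message are assumptions:

import urllib.request

def post_webhook(url: str, payload: bytes, timeout: int = 30) -> tuple[int, str]:
    req = urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},  # assumed content type
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, resp.read().decode("utf-8", errors="replace")
    except OSError as exc:
        raise RuntimeError(f"webhook POST failed: {exc}") from exc
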
def test_send_email_uses_sendmail_when_present(monkeypatch):
from enroll.diff import send_email
calls = {}
monkeypatch.setattr("enroll.diff.shutil.which", lambda name: "/usr/sbin/sendmail")
def fake_run(argv, input=None, check=False, **_kwargs):
calls["argv"] = argv
calls["input"] = input
return types.SimpleNamespace(returncode=0)
monkeypatch.setattr("enroll.diff.subprocess.run", fake_run)
send_email(
subject="Subj",
body="Body",
from_addr="a@example.com",
to_addrs=["b@example.com"],
)
assert calls["argv"][0].endswith("sendmail")
msg = (calls["input"] or b"").decode("utf-8", errors="replace")
assert "Subject: Subj" in msg
assert "To: b@example.com" in msg
def test_send_email_raises_when_no_delivery_method(monkeypatch):
from enroll.diff import send_email
monkeypatch.setattr("enroll.diff.shutil.which", lambda name: None)
with pytest.raises(RuntimeError):
send_email(
subject="Subj",
body="Body",
from_addr="a@example.com",
to_addrs=["b@example.com"],
)
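
send_email could follow the same pattern: require a local sendmail binary and hand it a headers-plus-body message on stdin. A sketch under those assumptions; the real helper may support other transports:

import shutil
import subprocess

def send_email(subject: str, body: str, from_addr: str, to_addrs: list[str]) -> None:
    sendmail = shutil.which("sendmail")
    if not sendmail:
        raise RuntimeError("no mail delivery method available (sendmail not found)")
    msg = (
        f"From: {from_addr}\n"
        f"To: {', '.join(to_addrs)}\n"
        f"Subject: {subject}\n"
        "\n"
        f"{body}\n"
    )
    # -t makes sendmail read the recipients from the message headers.
    subprocess.run([sendmail, "-t"], input=msg.encode("utf-8"), check=True)
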

@@ -0,0 +1,31 @@
from __future__ import annotations
import os
from pathlib import Path
def test_stat_triplet_falls_back_to_numeric_ids(tmp_path: Path, monkeypatch):
"""If uid/gid cannot be resolved, stat_triplet should return numeric strings."""
from enroll.fsutil import stat_triplet
p = tmp_path / "f"
p.write_text("x", encoding="utf-8")
os.chmod(p, 0o644)
import grp
import pwd
def _no_user(_uid): # pragma: no cover - executed via monkeypatch
raise KeyError
def _no_group(_gid): # pragma: no cover - executed via monkeypatch
raise KeyError
monkeypatch.setattr(pwd, "getpwuid", _no_user)
monkeypatch.setattr(grp, "getgrgid", _no_group)
owner, group, mode = stat_triplet(str(p))
assert owner.isdigit()
assert group.isdigit()
assert mode == "0644"
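
The fallback this test asserts could be implemented along these lines; an illustrative sketch of stat_triplet, not the shipped code:

import grp
import os
import pwd
import stat

def stat_triplet(path: str) -> tuple[str, str, str]:
    st = os.stat(path)
    try:
        owner = pwd.getpwuid(st.st_uid).pw_name
    except KeyError:
        owner = str(st.st_uid)  # unknown uid: fall back to the numeric id
    try:
        group = grp.getgrgid(st.st_gid).gr_name
    except KeyError:
        group = str(st.st_gid)  # unknown gid: fall back to the numeric id
    mode = format(stat.S_IMODE(st.st_mode), "04o")  # e.g. "0644"
    return owner, group, mode
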

tests/test_ignore_dir.py (new file, 56 lines added)
@@ -0,0 +1,56 @@
from __future__ import annotations
from pathlib import Path
def test_iter_effective_lines_skips_comments_and_block_comments():
from enroll.ignore import IgnorePolicy
policy = IgnorePolicy(deny_globs=[])
content = b"""
# comment
; semi
// slash
* c-star
valid=1
/* block
ignored=1
*/
valid=2
"""
lines = [l.strip() for l in policy.iter_effective_lines(content)]
assert lines == [b"valid=1", b"valid=2"]
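
A sketch of the comment handling this test pins down (line comments starting with #, ;, //, or *, plus /* ... */ blocks); in enroll this is a method on IgnorePolicy, shown here unbound and purely as an illustration:

def iter_effective_lines(self, content: bytes):
    in_block = False
    for raw in content.splitlines():
        line = raw.strip()
        if in_block:
            if line.endswith(b"*/"):
                in_block = False  # end of a /* ... */ block
            continue
        if not line:
            continue
        if line.startswith((b"#", b";", b"//", b"*")):
            continue  # single-line comment styles
        if line.startswith(b"/*"):
            in_block = True
            continue
        yield raw
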
def test_deny_reason_dir_behaviour(tmp_path: Path):
from enroll.ignore import IgnorePolicy
# Use an absolute pattern matching our temporary path.
deny_glob = str(tmp_path / "deny") + "/*"
pol = IgnorePolicy(deny_globs=[deny_glob], dangerous=False)
d = tmp_path / "dir"
d.mkdir()
f = tmp_path / "file"
f.write_text("x", encoding="utf-8")
link = tmp_path / "link"
link.symlink_to(d)
assert pol.deny_reason_dir(str(d)) is None
assert pol.deny_reason_dir(str(link)) == "symlink"
assert pol.deny_reason_dir(str(f)) == "not_directory"
# Denied by glob.
deny_path = tmp_path / "deny" / "x"
deny_path.mkdir(parents=True)
assert pol.deny_reason_dir(str(deny_path)) == "denied_path"
# Missing/unreadable.
assert pol.deny_reason_dir(str(tmp_path / "missing")) == "unreadable"
# Dangerous disables deny_globs.
pol2 = IgnorePolicy(deny_globs=[deny_glob], dangerous=True)
assert pol2.deny_reason_dir(str(deny_path)) is None
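
The decision order implied by these assertions could look roughly like the following unbound IgnorePolicy method; a sketch, not the real implementation:

import fnmatch
import os
import stat

def deny_reason_dir(self, path: str) -> str | None:
    try:
        st = os.lstat(path)
    except OSError:
        return "unreadable"  # missing or inaccessible
    if stat.S_ISLNK(st.st_mode):
        return "symlink"
    if not stat.S_ISDIR(st.st_mode):
        return "not_directory"
    if not self.dangerous and any(
        fnmatch.fnmatch(path, pattern) for pattern in self.deny_globs
    ):
        return "denied_path"
    return None
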

@@ -0,0 +1,72 @@
from __future__ import annotations
from collections import defaultdict
def test_dpkg_backend_modified_paths_marks_conffiles_and_packaged(monkeypatch):
from enroll.platform import DpkgBackend
# Provide fake conffiles md5sums.
monkeypatch.setattr(
"enroll.debian.parse_status_conffiles",
lambda: {"mypkg": {"/etc/mypkg.conf": "aaaa"}},
)
monkeypatch.setattr(
"enroll.debian.read_pkg_md5sums",
lambda _pkg: {"etc/other.conf": "bbbb"},
)
# Fake file_md5 values (avoids touching /etc).
def fake_md5(p: str):
if p == "/etc/mypkg.conf":
return "zzzz" # differs from conffile baseline
if p == "/etc/other.conf":
return "cccc" # differs from packaged baseline
if p == "/etc/apt/sources.list":
return "bbbb"
return None
monkeypatch.setattr("enroll.platform.file_md5", fake_md5)
b = DpkgBackend()
out = b.modified_paths(
"mypkg",
["/etc/mypkg.conf", "/etc/other.conf", "/etc/apt/sources.list"],
)
assert out["/etc/mypkg.conf"] == "modified_conffile"
assert out["/etc/other.conf"] == "modified_packaged_file"
# pkg config paths (like /etc/apt/...) are excluded.
assert "/etc/apt/sources.list" not in out
def test_rpm_backend_modified_paths_caches_queries(monkeypatch):
from enroll.platform import RpmBackend
calls = defaultdict(int)
def fake_modified(_pkg=None):
calls["modified"] += 1
return {"/etc/foo.conf", "/etc/bar.conf"}
def fake_config(_pkg=None):
calls["config"] += 1
return {"/etc/foo.conf"}
monkeypatch.setattr("enroll.rpm.rpm_modified_files", fake_modified)
monkeypatch.setattr("enroll.rpm.rpm_config_files", fake_config)
b = RpmBackend()
etc = ["/etc/foo.conf", "/etc/bar.conf", "/etc/baz.conf"]
out1 = b.modified_paths("ignored", etc)
out2 = b.modified_paths("ignored", etc)
assert out1 == out2
assert out1["/etc/foo.conf"] == "modified_conffile"
assert out1["/etc/bar.conf"] == "modified_packaged_file"
assert "/etc/baz.conf" not in out1
# Caches should mean we only queried rpm once.
assert calls["modified"] == 1
assert calls["config"] == 1
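
One way the caching asserted here could work: run each rpm query at most once, memoise the result sets, then classify paths against them. A sketch of RpmBackend, not the committed code:

from enroll import rpm  # the query helpers the test monkeypatches

class RpmBackend:
    def __init__(self) -> None:
        self._modified: set[str] | None = None
        self._config: set[str] | None = None

    def modified_paths(self, pkg: str, etc_paths: list[str]) -> dict[str, str]:
        # The underlying rpm queries are expensive, so cache them per backend.
        if self._modified is None:
            self._modified = set(rpm.rpm_modified_files())
        if self._config is None:
            self._config = set(rpm.rpm_config_files())
        out: dict[str, str] = {}
        for path in etc_paths:
            if path not in self._modified:
                continue
            out[path] = (
                "modified_conffile" if path in self._config else "modified_packaged_file"
            )
        return out
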

@@ -49,7 +49,7 @@ def test_safe_extract_tar_rejects_symlinks(tmp_path: Path):
_safe_extract_tar(tf, tmp_path)
-def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeypatch):
+def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
import sys
import enroll.remote as r
@@ -65,6 +65,7 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
# Prepare a tiny harvest bundle tar stream from the "remote".
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
+# Track each SSH exec_command call with whether a PTY was requested.
calls: list[tuple[str, bool]] = []
class _Chan:
@@ -116,9 +117,8 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
def open_sftp(self):
return self._sftp
-def exec_command(self, cmd: str, get_pty: bool = False):
+def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
calls.append((cmd, bool(get_pty)))
# The tar stream uses exec_command directly.
if cmd.startswith("tar -cz -C"):
return (None, _Stdout(tgz, rc=0), _Stderr(b""))
@@ -169,15 +169,122 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
assert b"ok" in state_path.read_bytes()
# Ensure we attempted remote harvest with sudo and passed include/exclude and dangerous.
-joined = "\n".join([c for c, _ in calls])
+joined = "\n".join([c for c, _pty in calls])
assert "sudo" in joined
assert "--dangerous" in joined
assert "--include-path" in joined
assert "--exclude-path" in joined
-# Assert PTY is requested for sudo commands (harvest & chown), not for tar streaming.
-sudo_cmds = [(c, pty) for c, pty in calls if c.startswith("sudo ")]
-assert sudo_cmds, "expected at least one sudo command"
-assert all(pty for _, pty in sudo_cmds)
-tar_cmds = [(c, pty) for c, pty in calls if c.startswith("tar -cz -C")]
-assert tar_cmds and all(not pty for _, pty in tar_cmds)
+# Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar.
+pty_by_cmd = {c: pty for c, pty in calls}
+assert pty_by_cmd.get("id -un") is False
+assert any(
+c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls
+)
+assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls)
+assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls)
def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
tmp_path: Path, monkeypatch
):
"""When --no-sudo is used we should not request a PTY nor run sudo chown."""
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
and (Path(td) / "enroll.pyz"),  # "and" so the stub returns the Path, not the byte count
)
tgz = _make_tgz_bytes({"state.json": b"{}"})
calls: list[tuple[str, bool]] = []
class _Chan:
def __init__(self, rc: int = 0):
self._rc = rc
def recv_exit_status(self) -> int:
return self._rc
class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0):
self._bio = io.BytesIO(payload)
self.channel = _Chan(rc)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stderr:
def __init__(self, payload: bytes = b""):
self._bio = io.BytesIO(payload)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
return
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
calls.append((cmd, bool(get_pty)))
if cmd == "mktemp -d":
return (None, _Stdout(b"/tmp/enroll-remote-456\n"), _Stderr())
if cmd.startswith("chmod 700"):
return (None, _Stdout(b""), _Stderr())
if cmd.startswith("tar -cz -C"):
return (None, _Stdout(tgz, rc=0), _Stderr())
if " harvest " in cmd:
return (None, _Stdout(b""), _Stderr())
if cmd.startswith("rm -rf"):
return (None, _Stdout(b""), _Stderr())
return (None, _Stdout(b""), _Stderr())
def close(self):
return
import types
class RejectPolicy:
pass
monkeypatch.setitem(
sys.modules,
"paramiko",
types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy),
)
out_dir = tmp_path / "out"
r.remote_harvest(
local_out_dir=out_dir,
remote_host="example.com",
remote_user="alice",
no_sudo=True,
)
joined = "\n".join([c for c, _pty in calls])
assert "sudo" not in joined
assert "sudo chown" not in joined
assert any(" harvest " in c and pty is False for c, pty in calls)

@@ -0,0 +1,36 @@
from __future__ import annotations
import sys
import types
def test_get_enroll_version_returns_unknown_when_import_fails(monkeypatch):
from enroll.version import get_enroll_version
# Ensure both the module cache and the parent package attribute are redirected.
import importlib
dummy = types.ModuleType("importlib.metadata")
# Missing attributes will cause ImportError when importing names.
monkeypatch.setitem(sys.modules, "importlib.metadata", dummy)
monkeypatch.setattr(importlib, "metadata", dummy, raising=False)
assert get_enroll_version() == "unknown"
def test_get_enroll_version_uses_packages_distributions(monkeypatch):
# Restore the real module for this test.
monkeypatch.delitem(sys.modules, "importlib.metadata", raising=False)
import importlib.metadata
from enroll.version import get_enroll_version
monkeypatch.setattr(
importlib.metadata,
"packages_distributions",
lambda: {"enroll": ["enroll-dist"]},
)
monkeypatch.setattr(importlib.metadata, "version", lambda dist: "9.9.9")
assert get_enroll_version() == "9.9.9"
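
Both tests suggest an implementation shaped roughly like this sketch of get_enroll_version; the error handling details are assumptions:

def get_enroll_version() -> str:
    try:
        from importlib.metadata import packages_distributions, version

        dists = packages_distributions().get("enroll", [])
        if dists:
            return version(dists[0])
    except Exception:
        pass  # metadata unavailable or package not installed
    return "unknown"
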