diff --git a/tests/test_cache_security.py b/tests/test_cache_security.py new file mode 100644 index 0000000..9f31587 --- /dev/null +++ b/tests/test_cache_security.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + + +def test_ensure_dir_secure_refuses_symlink(tmp_path: Path): + from enroll.cache import _ensure_dir_secure + + target = tmp_path / "target" + target.mkdir() + link = tmp_path / "link" + link.symlink_to(target, target_is_directory=True) + + with pytest.raises(RuntimeError): + _ensure_dir_secure(link) + + +def test_ensure_dir_secure_ignores_chmod_failures(tmp_path: Path, monkeypatch): + from enroll.cache import _ensure_dir_secure + + d = tmp_path / "d" + + def boom(_path: str, _mode: int): + raise OSError("no") + + monkeypatch.setattr(os, "chmod", boom) + + # Should not raise. + _ensure_dir_secure(d) + assert d.exists() and d.is_dir() diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py new file mode 100644 index 0000000..264ff85 --- /dev/null +++ b/tests/test_cli_helpers.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import argparse +import configparser +import types +import textwrap +from pathlib import Path + + +def test_discover_config_path_precedence(tmp_path: Path, monkeypatch): + """_discover_config_path: --config > ENROLL_CONFIG > ./enroll.ini > XDG.""" + from enroll.cli import _discover_config_path + + cfg1 = tmp_path / "one.ini" + cfg1.write_text("[enroll]\n", encoding="utf-8") + + # Explicit --config should win. + assert _discover_config_path(["--config", str(cfg1)]) == cfg1 + + # --no-config disables config loading. + assert _discover_config_path(["--no-config", "--config", str(cfg1)]) is None + + monkeypatch.chdir(tmp_path) + + cfg2 = tmp_path / "two.ini" + cfg2.write_text("[enroll]\n", encoding="utf-8") + monkeypatch.setenv("ENROLL_CONFIG", str(cfg2)) + assert _discover_config_path([]) == cfg2 + + # Local ./enroll.ini fallback. + monkeypatch.delenv("ENROLL_CONFIG", raising=False) + local = tmp_path / "enroll.ini" + local.write_text("[enroll]\n", encoding="utf-8") + assert _discover_config_path([]) == local + + # XDG fallback. + local.unlink() + xdg = tmp_path / "xdg" + cfg3 = xdg / "enroll" / "enroll.ini" + cfg3.parent.mkdir(parents=True) + cfg3.write_text("[enroll]\n", encoding="utf-8") + monkeypatch.setenv("XDG_CONFIG_HOME", str(xdg)) + assert _discover_config_path([]) == cfg3 + + +def test_config_value_parsing_and_list_splitting(): + from enroll.cli import _parse_bool, _split_list_value + + assert _parse_bool("1") is True + assert _parse_bool("yes") is True + assert _parse_bool("false") is False + + assert _parse_bool("maybe") is None + + assert _split_list_value("a,b , c") == ["a", "b", "c"] + # When newlines are present, we split on lines (not commas within a line). + assert _split_list_value("a,b\nc") == ["a,b", "c"] + assert _split_list_value("a\n\n b\n") == ["a", "b"] + assert _split_list_value(" ") == [] + + +def test_section_to_argv_handles_types_and_unknown_keys(capsys): + from enroll.cli import _section_to_argv + + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--dangerous", action="store_true") + p.add_argument("--no-color", dest="color", action="store_false") + p.add_argument("--include-path", dest="include_path", action="append") + p.add_argument("-v", action="count", default=0) + p.add_argument("--out") + + cfg = configparser.ConfigParser() + cfg.read_dict( + { + "harvest": { + "dangerous": "true", + # Keys are matched by argparse dest; store_false actions still use dest. + "color": "false", + "include-path": "a,b,c", + "v": "2", + "out": "/tmp/bundle", + "unknown": "ignored", + } + } + ) + + argv = _section_to_argv(p, cfg, "harvest") + + # Boolean store_true. + assert "--dangerous" in argv + + # Boolean store_false: include the flag only when config wants False. + assert "--no-color" in argv + + # Append: split lists and add one flag per item. + assert argv.count("--include-path") == 3 + assert "a" in argv and "b" in argv and "c" in argv + + # Count: repeats. + assert argv.count("-v") == 2 + + # Scalar. + assert "--out" in argv and "/tmp/bundle" in argv + + err = capsys.readouterr().err + assert "unknown option" in err + + +def test_inject_config_argv_inserts_global_and_subcommand(tmp_path: Path, capsys): + from enroll.cli import _inject_config_argv + + cfg = tmp_path / "enroll.ini" + cfg.write_text( + textwrap.dedent( + """ + [enroll] + dangerous = true + + [harvest] + include-path = /etc/foo + unknown = 1 + """ + ).strip() + + "\n", + encoding="utf-8", + ) + + root = argparse.ArgumentParser(add_help=False) + root.add_argument("--dangerous", action="store_true") + + harvest_p = argparse.ArgumentParser(add_help=False) + harvest_p.add_argument("--include-path", dest="include_path", action="append") + + argv = _inject_config_argv( + ["harvest", "--out", "x"], + cfg_path=cfg, + root_parser=root, + subparsers={"harvest": harvest_p}, + ) + + # Global tokens should appear before the subcommand. + assert argv[0] == "--dangerous" + assert argv[1] == "harvest" + + # Subcommand tokens should appear right after the subcommand. + assert argv[2:4] == ["--include-path", "/etc/foo"] + + # Unknown option should have produced a warning. + assert "unknown option" in capsys.readouterr().err + + +def test_resolve_sops_out_file(tmp_path: Path, monkeypatch): + from enroll import cli + + # Make a predictable cache dir for the default case. + fake_cache = types.SimpleNamespace(dir=tmp_path / "cache") + fake_cache.dir.mkdir(parents=True) + monkeypatch.setattr(cli, "new_harvest_cache_dir", lambda hint=None: fake_cache) + + # If out is a directory, use it directly. + out_dir = tmp_path / "out" + out_dir.mkdir() + # The output filename is fixed; hint is only used when creating a cache dir. + assert ( + cli._resolve_sops_out_file(out=out_dir, hint="bundle.tar.gz") + == out_dir / "harvest.tar.gz.sops" + ) + + # If out is a file path, keep it. + out_file = tmp_path / "x.sops" + assert cli._resolve_sops_out_file(out=out_file, hint="bundle.tar.gz") == out_file + + # None uses the cache dir, and the name is fixed. + assert ( + cli._resolve_sops_out_file(out=None, hint="bundle.tar.gz") + == fake_cache.dir / "harvest.tar.gz.sops" + ) diff --git a/tests/test_diff_notifications.py b/tests/test_diff_notifications.py new file mode 100644 index 0000000..53f6b57 --- /dev/null +++ b/tests/test_diff_notifications.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import types + +import pytest + + +def test_post_webhook_success(monkeypatch): + from enroll.diff import post_webhook + + class FakeResp: + status = 204 + + def read(self): + return b"OK" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr( + "enroll.diff.urllib.request.urlopen", + lambda req, timeout=30: FakeResp(), + ) + + status, body = post_webhook("https://example.com", b"x") + assert status == 204 + assert body == "OK" + + +def test_post_webhook_raises_on_error(monkeypatch): + from enroll.diff import post_webhook + + def boom(_req, timeout=30): + raise OSError("nope") + + monkeypatch.setattr("enroll.diff.urllib.request.urlopen", boom) + + with pytest.raises(RuntimeError): + post_webhook("https://example.com", b"x") + + +def test_send_email_uses_sendmail_when_present(monkeypatch): + from enroll.diff import send_email + + calls = {} + + monkeypatch.setattr("enroll.diff.shutil.which", lambda name: "/usr/sbin/sendmail") + + def fake_run(argv, input=None, check=False, **_kwargs): + calls["argv"] = argv + calls["input"] = input + return types.SimpleNamespace(returncode=0) + + monkeypatch.setattr("enroll.diff.subprocess.run", fake_run) + + send_email( + subject="Subj", + body="Body", + from_addr="a@example.com", + to_addrs=["b@example.com"], + ) + + assert calls["argv"][0].endswith("sendmail") + msg = (calls["input"] or b"").decode("utf-8", errors="replace") + assert "Subject: Subj" in msg + assert "To: b@example.com" in msg + + +def test_send_email_raises_when_no_delivery_method(monkeypatch): + from enroll.diff import send_email + + monkeypatch.setattr("enroll.diff.shutil.which", lambda name: None) + + with pytest.raises(RuntimeError): + send_email( + subject="Subj", + body="Body", + from_addr="a@example.com", + to_addrs=["b@example.com"], + ) diff --git a/tests/test_fsutil_extra.py b/tests/test_fsutil_extra.py new file mode 100644 index 0000000..9b70a67 --- /dev/null +++ b/tests/test_fsutil_extra.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import os +from pathlib import Path + + +def test_stat_triplet_falls_back_to_numeric_ids(tmp_path: Path, monkeypatch): + """If uid/gid cannot be resolved, stat_triplet should return numeric strings.""" + from enroll.fsutil import stat_triplet + + p = tmp_path / "f" + p.write_text("x", encoding="utf-8") + os.chmod(p, 0o644) + + import grp + import pwd + + def _no_user(_uid): # pragma: no cover - executed via monkeypatch + raise KeyError + + def _no_group(_gid): # pragma: no cover - executed via monkeypatch + raise KeyError + + monkeypatch.setattr(pwd, "getpwuid", _no_user) + monkeypatch.setattr(grp, "getgrgid", _no_group) + + owner, group, mode = stat_triplet(str(p)) + + assert owner.isdigit() + assert group.isdigit() + assert mode == "0644" diff --git a/tests/test_ignore_dir.py b/tests/test_ignore_dir.py new file mode 100644 index 0000000..3066c92 --- /dev/null +++ b/tests/test_ignore_dir.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from pathlib import Path + + +def test_iter_effective_lines_skips_comments_and_block_comments(): + from enroll.ignore import IgnorePolicy + + policy = IgnorePolicy(deny_globs=[]) + + content = b""" +# comment +; semi +// slash +* c-star + +valid=1 +/* block +ignored=1 +*/ +valid=2 +""" + + lines = [l.strip() for l in policy.iter_effective_lines(content)] + assert lines == [b"valid=1", b"valid=2"] + + +def test_deny_reason_dir_behaviour(tmp_path: Path): + from enroll.ignore import IgnorePolicy + + # Use an absolute pattern matching our temporary path. + deny_glob = str(tmp_path / "deny") + "/*" + pol = IgnorePolicy(deny_globs=[deny_glob], dangerous=False) + + d = tmp_path / "dir" + d.mkdir() + f = tmp_path / "file" + f.write_text("x", encoding="utf-8") + link = tmp_path / "link" + link.symlink_to(d) + + assert pol.deny_reason_dir(str(d)) is None + assert pol.deny_reason_dir(str(link)) == "symlink" + assert pol.deny_reason_dir(str(f)) == "not_directory" + + # Denied by glob. + deny_path = tmp_path / "deny" / "x" + deny_path.mkdir(parents=True) + assert pol.deny_reason_dir(str(deny_path)) == "denied_path" + + # Missing/unreadable. + assert pol.deny_reason_dir(str(tmp_path / "missing")) == "unreadable" + + # Dangerous disables deny_globs. + pol2 = IgnorePolicy(deny_globs=[deny_glob], dangerous=True) + assert pol2.deny_reason_dir(str(deny_path)) is None diff --git a/tests/test_platform_backends.py b/tests/test_platform_backends.py new file mode 100644 index 0000000..6716d53 --- /dev/null +++ b/tests/test_platform_backends.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from collections import defaultdict + + +def test_dpkg_backend_modified_paths_marks_conffiles_and_packaged(monkeypatch): + from enroll.platform import DpkgBackend + + # Provide fake conffiles md5sums. + monkeypatch.setattr( + "enroll.debian.parse_status_conffiles", + lambda: {"mypkg": {"/etc/mypkg.conf": "aaaa"}}, + ) + monkeypatch.setattr( + "enroll.debian.read_pkg_md5sums", + lambda _pkg: {"etc/other.conf": "bbbb"}, + ) + + # Fake file_md5 values (avoids touching /etc). + def fake_md5(p: str): + if p == "/etc/mypkg.conf": + return "zzzz" # differs from conffile baseline + if p == "/etc/other.conf": + return "cccc" # differs from packaged baseline + if p == "/etc/apt/sources.list": + return "bbbb" + return None + + monkeypatch.setattr("enroll.platform.file_md5", fake_md5) + + b = DpkgBackend() + out = b.modified_paths( + "mypkg", + ["/etc/mypkg.conf", "/etc/other.conf", "/etc/apt/sources.list"], + ) + + assert out["/etc/mypkg.conf"] == "modified_conffile" + assert out["/etc/other.conf"] == "modified_packaged_file" + # pkg config paths (like /etc/apt/...) are excluded. + assert "/etc/apt/sources.list" not in out + + +def test_rpm_backend_modified_paths_caches_queries(monkeypatch): + from enroll.platform import RpmBackend + + calls = defaultdict(int) + + def fake_modified(_pkg=None): + calls["modified"] += 1 + return {"/etc/foo.conf", "/etc/bar.conf"} + + def fake_config(_pkg=None): + calls["config"] += 1 + return {"/etc/foo.conf"} + + monkeypatch.setattr("enroll.rpm.rpm_modified_files", fake_modified) + monkeypatch.setattr("enroll.rpm.rpm_config_files", fake_config) + + b = RpmBackend() + etc = ["/etc/foo.conf", "/etc/bar.conf", "/etc/baz.conf"] + + out1 = b.modified_paths("ignored", etc) + out2 = b.modified_paths("ignored", etc) + + assert out1 == out2 + assert out1["/etc/foo.conf"] == "modified_conffile" + assert out1["/etc/bar.conf"] == "modified_packaged_file" + assert "/etc/baz.conf" not in out1 + + # Caches should mean we only queried rpm once. + assert calls["modified"] == 1 + assert calls["config"] == 1 diff --git a/tests/test_remote.py b/tests/test_remote.py index 387a397..1f9c89b 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -49,7 +49,7 @@ def test_safe_extract_tar_rejects_symlinks(tmp_path: Path): _safe_extract_tar(tf, tmp_path) -def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeypatch): +def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): import sys import enroll.remote as r @@ -65,6 +65,7 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp # Prepare a tiny harvest bundle tar stream from the "remote". tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'}) + # Track each SSH exec_command call with whether a PTY was requested. calls: list[tuple[str, bool]] = [] class _Chan: @@ -116,9 +117,8 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp def open_sftp(self): return self._sftp - def exec_command(self, cmd: str, get_pty: bool = False): + def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs): calls.append((cmd, bool(get_pty))) - # The tar stream uses exec_command directly. if cmd.startswith("tar -cz -C"): return (None, _Stdout(tgz, rc=0), _Stderr(b"")) @@ -169,15 +169,122 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp assert b"ok" in state_path.read_bytes() # Ensure we attempted remote harvest with sudo and passed include/exclude and dangerous. - joined = "\n".join([c for c, _ in calls]) + joined = "\n".join([c for c, _pty in calls]) assert "sudo" in joined assert "--dangerous" in joined assert "--include-path" in joined assert "--exclude-path" in joined - # Assert PTY is requested for sudo commands (harvest & chown), not for tar streaming. - sudo_cmds = [(c, pty) for c, pty in calls if c.startswith("sudo ")] - assert sudo_cmds, "expected at least one sudo command" - assert all(pty for _, pty in sudo_cmds) - tar_cmds = [(c, pty) for c, pty in calls if c.startswith("tar -cz -C")] - assert tar_cmds and all(not pty for _, pty in tar_cmds) + # Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar. + pty_by_cmd = {c: pty for c, pty in calls} + assert pty_by_cmd.get("id -un") is False + assert any( + c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls + ) + assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls) + assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls) + + +def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( + tmp_path: Path, monkeypatch +): + """When --no-sudo is used we should not request a PTY nor run sudo chown.""" + import sys + + import enroll.remote as r + + monkeypatch.setattr( + r, + "_build_enroll_pyz", + lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ") + or (Path(td) / "enroll.pyz"), + ) + + tgz = _make_tgz_bytes({"state.json": b"{}"}) + calls: list[tuple[str, bool]] = [] + + class _Chan: + def __init__(self, rc: int = 0): + self._rc = rc + + def recv_exit_status(self) -> int: + return self._rc + + class _Stdout: + def __init__(self, payload: bytes = b"", rc: int = 0): + self._bio = io.BytesIO(payload) + self.channel = _Chan(rc) + + def read(self, n: int = -1) -> bytes: + return self._bio.read(n) + + class _Stderr: + def __init__(self, payload: bytes = b""): + self._bio = io.BytesIO(payload) + + def read(self, n: int = -1) -> bytes: + return self._bio.read(n) + + class _SFTP: + def put(self, _local: str, _remote: str) -> None: + return + + def close(self) -> None: + return + + class FakeSSH: + def __init__(self): + self._sftp = _SFTP() + + def load_system_host_keys(self): + return + + def set_missing_host_key_policy(self, _policy): + return + + def connect(self, **_kwargs): + return + + def open_sftp(self): + return self._sftp + + def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs): + calls.append((cmd, bool(get_pty))) + if cmd == "mktemp -d": + return (None, _Stdout(b"/tmp/enroll-remote-456\n"), _Stderr()) + if cmd.startswith("chmod 700"): + return (None, _Stdout(b""), _Stderr()) + if cmd.startswith("tar -cz -C"): + return (None, _Stdout(tgz, rc=0), _Stderr()) + if " harvest " in cmd: + return (None, _Stdout(b""), _Stderr()) + if cmd.startswith("rm -rf"): + return (None, _Stdout(b""), _Stderr()) + return (None, _Stdout(b""), _Stderr()) + + def close(self): + return + + import types + + class RejectPolicy: + pass + + monkeypatch.setitem( + sys.modules, + "paramiko", + types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy), + ) + + out_dir = tmp_path / "out" + r.remote_harvest( + local_out_dir=out_dir, + remote_host="example.com", + remote_user="alice", + no_sudo=True, + ) + + joined = "\n".join([c for c, _pty in calls]) + assert "sudo" not in joined + assert "sudo chown" not in joined + assert any(" harvest " in c and pty is False for c, pty in calls) diff --git a/tests/test_version_extra.py b/tests/test_version_extra.py new file mode 100644 index 0000000..a5adc1a --- /dev/null +++ b/tests/test_version_extra.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import sys +import types + + +def test_get_enroll_version_returns_unknown_when_import_fails(monkeypatch): + from enroll.version import get_enroll_version + + # Ensure both the module cache and the parent package attribute are redirected. + import importlib + + dummy = types.ModuleType("importlib.metadata") + # Missing attributes will cause ImportError when importing names. + monkeypatch.setitem(sys.modules, "importlib.metadata", dummy) + monkeypatch.setattr(importlib, "metadata", dummy, raising=False) + + assert get_enroll_version() == "unknown" + + +def test_get_enroll_version_uses_packages_distributions(monkeypatch): + # Restore the real module for this test. + monkeypatch.delitem(sys.modules, "importlib.metadata", raising=False) + + import importlib.metadata + + from enroll.version import get_enroll_version + + monkeypatch.setattr( + importlib.metadata, + "packages_distributions", + lambda: {"enroll": ["enroll-dist"]}, + ) + monkeypatch.setattr(importlib.metadata, "version", lambda dist: "9.9.9") + + assert get_enroll_version() == "9.9.9"