diff --git a/tests/test_harvest.py b/tests/test_harvest.py index 35e2c5c..93dfd90 100644 --- a/tests/test_harvest.py +++ b/tests/test_harvest.py @@ -1,10 +1,28 @@ import json -import enroll.harvest as harvest +import os +import pytest + from pathlib import Path -import enroll.harvest as h +import enroll.harvest as harvest from enroll.platform import PlatformInfo from enroll.systemd import UnitInfo +from enroll.pathfilter import PathFilter +from enroll.harvest import ( + _is_confish, + _hint_names, + _topdirs_for_package, + _iter_matching_files, + _parse_apt_signed_by, + _capture_link, + _capture_file, + ManagedFile, + ManagedLink, + ExcludedFile, + IgnorePolicy, +) + +from unittest.mock import MagicMock class AllowAllPolicy: @@ -155,17 +173,17 @@ def test_harvest_dedup_manual_packages_and_builds_etc_custom( else: yield (root, [], []) - monkeypatch.setattr(h.os.path, "isfile", fake_isfile) - monkeypatch.setattr(h.os.path, "isdir", fake_isdir) - monkeypatch.setattr(h.os.path, "islink", fake_islink) - monkeypatch.setattr(h.os.path, "exists", fake_exists) - monkeypatch.setattr(h.os, "walk", fake_walk) + monkeypatch.setattr(harvest.os.path, "isfile", fake_isfile) + monkeypatch.setattr(harvest.os.path, "isdir", fake_isdir) + monkeypatch.setattr(harvest.os.path, "islink", fake_islink) + monkeypatch.setattr(harvest.os.path, "exists", fake_exists) + monkeypatch.setattr(harvest.os, "walk", fake_walk) # Avoid real system access - monkeypatch.setattr(h, "list_enabled_services", lambda: ["openvpn.service"]) - monkeypatch.setattr(h, "list_enabled_timers", lambda: []) + monkeypatch.setattr(harvest, "list_enabled_services", lambda: ["openvpn.service"]) + monkeypatch.setattr(harvest, "list_enabled_timers", lambda: []) monkeypatch.setattr( - h, + harvest, "get_unit_info", lambda unit: UnitInfo( name=unit, @@ -200,11 +218,11 @@ def test_harvest_dedup_manual_packages_and_builds_etc_custom( ) monkeypatch.setattr( - h, "detect_platform", lambda: PlatformInfo("debian", "dpkg", {}) + 
harvest, "detect_platform", lambda: PlatformInfo("debian", "dpkg", {}) ) - monkeypatch.setattr(h, "get_backend", lambda info=None: backend) + monkeypatch.setattr(harvest, "get_backend", lambda info=None: backend) - monkeypatch.setattr(h, "collect_non_system_users", lambda: []) + monkeypatch.setattr(harvest, "collect_non_system_users", lambda: []) def fake_stat_triplet(p: str): if p == "/usr/local/bin/myscript": @@ -212,7 +230,7 @@ def test_harvest_dedup_manual_packages_and_builds_etc_custom( # /usr/local/bin/readme.txt remains non-executable return ("root", "root", "0644") - monkeypatch.setattr(h, "stat_triplet", fake_stat_triplet) + monkeypatch.setattr(harvest, "stat_triplet", fake_stat_triplet) # Avoid needing source files on disk by implementing our own bundle copier def fake_copy(bundle_dir: str, role_name: str, abs_path: str, src_rel: str): @@ -220,9 +238,9 @@ def test_harvest_dedup_manual_packages_and_builds_etc_custom( dst.parent.mkdir(parents=True, exist_ok=True) dst.write_bytes(files.get(abs_path, b"")) - monkeypatch.setattr(h, "_copy_into_bundle", fake_copy) + monkeypatch.setattr(harvest, "_copy_into_bundle", fake_copy) - state_path = h.harvest(str(bundle), policy=AllowAllPolicy()) + state_path = harvest.harvest(str(bundle), policy=AllowAllPolicy()) st = json.loads(Path(state_path).read_text(encoding="utf-8")) inv = st["inventory"]["packages"] @@ -275,21 +293,25 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic( files = {"/etc/cron.d/ntpsec": b"# cron\n"} dirs = {"/etc", "/etc/cron.d"} - monkeypatch.setattr(h.os.path, "isfile", lambda p: p in files) - monkeypatch.setattr(h.os.path, "islink", lambda p: False) - monkeypatch.setattr(h.os.path, "isdir", lambda p: p in dirs) - monkeypatch.setattr(h.os.path, "exists", lambda p: p in files or p in dirs) - monkeypatch.setattr(h.os, "walk", lambda root: [("/etc/cron.d", [], ["ntpsec"])]) + monkeypatch.setattr(harvest.os.path, "isfile", lambda p: p in files) + 
monkeypatch.setattr(harvest.os.path, "islink", lambda p: False) + monkeypatch.setattr(harvest.os.path, "isdir", lambda p: p in dirs) + monkeypatch.setattr(harvest.os.path, "exists", lambda p: p in files or p in dirs) + monkeypatch.setattr( + harvest.os, "walk", lambda root: [("/etc/cron.d", [], ["ntpsec"])] + ) # Only include the cron snippet in the system capture set. monkeypatch.setattr( - h, "_iter_system_capture_paths", lambda: [("/etc/cron.d/ntpsec", "system_cron")] + harvest, + "_iter_system_capture_paths", + lambda: [("/etc/cron.d/ntpsec", "system_cron")], ) monkeypatch.setattr( - h, "list_enabled_services", lambda: ["apparmor.service", "ntpsec.service"] + harvest, "list_enabled_services", lambda: ["apparmor.service", "ntpsec.service"] ) - monkeypatch.setattr(h, "list_enabled_timers", lambda: []) + monkeypatch.setattr(harvest, "list_enabled_timers", lambda: []) def fake_unit_info(unit: str) -> UnitInfo: if unit == "apparmor.service": @@ -316,7 +338,7 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic( condition_result=None, ) - monkeypatch.setattr(h, "get_unit_info", fake_unit_info) + monkeypatch.setattr(harvest, "get_unit_info", fake_unit_info) # Make apparmor *also* claim the ntpsec package (simulates overly-broad # package inference). The snippet routing should still prefer role 'ntpsec'. 
@@ -341,21 +363,21 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic( ) monkeypatch.setattr( - h, "detect_platform", lambda: PlatformInfo("debian", "dpkg", {}) + harvest, "detect_platform", lambda: PlatformInfo("debian", "dpkg", {}) ) - monkeypatch.setattr(h, "get_backend", lambda info=None: backend) + monkeypatch.setattr(harvest, "get_backend", lambda info=None: backend) - monkeypatch.setattr(h, "stat_triplet", lambda p: ("root", "root", "0644")) - monkeypatch.setattr(h, "collect_non_system_users", lambda: []) + monkeypatch.setattr(harvest, "stat_triplet", lambda p: ("root", "root", "0644")) + monkeypatch.setattr(harvest, "collect_non_system_users", lambda: []) def fake_copy(bundle_dir: str, role_name: str, abs_path: str, src_rel: str): dst = Path(bundle_dir) / "artifacts" / role_name / src_rel dst.parent.mkdir(parents=True, exist_ok=True) dst.write_bytes(files[abs_path]) - monkeypatch.setattr(h, "_copy_into_bundle", fake_copy) + monkeypatch.setattr(harvest, "_copy_into_bundle", fake_copy) - state_path = h.harvest(str(bundle), policy=AllowAllPolicy()) + state_path = harvest.harvest(str(bundle), policy=AllowAllPolicy()) st = json.loads(Path(state_path).read_text(encoding="utf-8")) # Cron snippet should end up attached to the ntpsec role, not apparmor. 
@@ -607,3 +629,408 @@ def test_is_confish_not_config(tmp_path: Path):
 def test_is_confish_nonexistent():
     """Test _is_confish returns False for nonexistent files."""
     assert harvest._is_confish("/nonexistent/file.xyz") is False
+
+
+# Additional coverage tests for harvest.py
+
+
+class TestIsConfish:
+    """Tests for _is_confish function"""
+
+    def test_is_confish_true_extensions(self, tmp_path):
+        """Test files with config extensions are detected."""
+        for ext in [".conf", ".cfg", ".ini", ".yaml", ".json", ".cnf"]:
+            f = tmp_path / f"test{ext}"
+            f.write_text("test", encoding="utf-8")
+            assert _is_confish(str(f)) is True
+
+    def test_is_confish_false(self, tmp_path):
+        """Test non-config files are not detected."""
+        for name in ["data.txt", "script.sh"]:
+            f = tmp_path / name
+            f.write_text("test", encoding="utf-8")
+            assert _is_confish(str(f)) is False
+
+
+class TestHintNames:
+    """Tests for _hint_names function"""
+
+    def test_hint_names_simple(self):
+        """Test simple hint name extraction."""
+        result = _hint_names("nginx", {"nginx"})
+        assert "nginx" in result
+
+    def test_hint_names_multiple(self):
+        """Test multiple hint names."""
+        result = _hint_names("nginx", {"apache"})
+        assert "nginx" in result
+        assert "apache" in result
+
+    def test_hint_names_empty(self):
+        """Test empty hint names."""
+        result = _hint_names("", set())
+        assert result == set()
+
+    def test_hint_names_with_service(self):
+        """Test hint names with .service suffix."""
+        result = _hint_names("nginx.service", set())
+        assert "nginx" in result
+
+    def test_hint_names_with_template(self):
+        """Test hint names with template unit."""
+        result = _hint_names("nginx@.service", set())
+        assert "nginx" in result
+
+
+class TestTopdirsForPackage:
+    """Tests for _topdirs_for_package function"""
+
+    def test_topdirs_single_level(self):
+        """Test topdirs with single level paths."""
+        pkg_to_etc = {"nginx": ["/etc/nginx/nginx.conf"]}
+        result = _topdirs_for_package("nginx", pkg_to_etc)
+        assert result == {"nginx"}
+
+    def test_topdirs_multiple_paths(self):
+        """Test topdirs with multiple paths."""
+        pkg_to_etc = {"nginx": ["/etc/nginx/nginx.conf", "/etc/nginx/sites-enabled"]}
+        result = _topdirs_for_package("nginx", pkg_to_etc)
+        assert result == {"nginx"}
+
+    def test_topdirs_empty(self):
+        """Test topdirs with empty package."""
+        result = _topdirs_for_package("nonexistent", {})
+        assert result == set()
+
+
+class TestIterMatchingFiles:
+    """Tests for _iter_matching_files function"""
+
+    def test_iter_matching_files_glob(self, tmp_path, monkeypatch):
+        """Test glob pattern matching with a pattern relative to the cwd."""
+        (tmp_path / "a.txt").write_text("a", encoding="utf-8")
+        (tmp_path / "b.txt").write_text("b", encoding="utf-8")
+        (tmp_path / "c.py").write_text("c", encoding="utf-8")
+
+        monkeypatch.chdir(tmp_path)  # cwd is restored at test teardown
+        result = _iter_matching_files("*.txt")
+        assert len(result) == 2
+        assert any("a.txt" in p for p in result)
+        assert any("b.txt" in p for p in result)
+
+    def test_iter_matching_files_directory_walk(self, tmp_path, monkeypatch):
+        """Test directory walking."""
+        subdir = tmp_path / "sub"
+        subdir.mkdir()
+        (tmp_path / "a.txt").write_text("a", encoding="utf-8")
+        (subdir / "b.txt").write_text("b", encoding="utf-8")
+
+        monkeypatch.chdir(tmp_path)  # cwd is restored at test teardown
+        result = _iter_matching_files(str(tmp_path))
+        assert len(result) == 2
+
+    def test_iter_matching_files_cap(self, tmp_path, monkeypatch):
+        """Test file cap limit."""
+        for i in range(100):
+            (tmp_path / f"file{i}.txt").write_text(str(i), encoding="utf-8")
+
+        monkeypatch.chdir(tmp_path)  # cwd is restored at test teardown
+        result = _iter_matching_files("*.txt", cap=10)
+        assert len(result) == 10
+
+
+class TestParseAptSignedBy:
+    """Tests for _parse_apt_signed_by function"""
+
+    def test_parse_apt_signed_by_bracket(self, tmp_path):
+        """Test parsing signed-by from bracket notation."""
+        sources_list = tmp_path / "sources.list"
+        sources_list.write_text(
+            "deb [signed-by=/usr/share/keyrings/nginx.gpg] http://nginx.net stable main\n",
+            encoding="utf-8",
+        )
+        result = _parse_apt_signed_by([str(sources_list)])
+        assert "/usr/share/keyrings/nginx.gpg" in result
+
+    def test_parse_apt_signed_by_header(self, tmp_path):
+        """Test parsing signed-by from header."""
+        sources_file = tmp_path / "sources.list"
+        sources_file.write_text(
+            "Signed-By: /usr/share/keyrings/foo.gpg\n", encoding="utf-8"
+        )
+        result = _parse_apt_signed_by([str(sources_file)])
+        assert "/usr/share/keyrings/foo.gpg" in result
+
+    def test_parse_apt_signed_by_multiple(self, tmp_path):
+        """Test parsing multiple signed-by paths."""
+        sources_file = tmp_path / "sources.list"
+        sources_file.write_text(
+            "Signed-By: /usr/share/keyrings/a.gpg, /usr/share/keyrings/b.gpg\n",
+            encoding="utf-8",
+        )
+        result = _parse_apt_signed_by([str(sources_file)])
+        assert "/usr/share/keyrings/a.gpg" in result
+        assert "/usr/share/keyrings/b.gpg" in result
+
+    def test_parse_apt_signed_by_oserror(self):
+        """Test that a path which cannot be opened yields an empty result."""
+        result = _parse_apt_signed_by(["/nonexistent/file"])
+        assert result == set()
+
+
+class TestCaptureLink:
+    """Tests for _capture_link function"""
+
+    def test_capture_link_basic(self, tmp_path):
+        """Test basic link capture."""
+        target = tmp_path / "target.txt"
+        target.write_text("content", encoding="utf-8")
+        link = tmp_path / "link.txt"
+        link.symlink_to(target)
+
+        policy = MagicMock(spec=IgnorePolicy)
+        # Stub the policy: deny_reason() returns None, deny_reason_link is None.
+        policy.deny_reason = MagicMock(return_value=None)
+        policy.deny_reason_link = None  # No special link denial
+
+        managed: list[ManagedLink] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+
+        result = _capture_link(
+            role_name="test_role",
+            abs_path=str(link),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=set(),
+            seen_global=set(),
+        )
+        assert result is True
+        assert len(managed) == 1
+        assert managed[0].path == str(link)
+
+    def test_capture_link_deny(self, tmp_path):
+        """Test link capture with deny policy."""
+        target = tmp_path / "target.txt"
+        target.write_text("content", encoding="utf-8")
+        link = tmp_path / "link.txt"
+        link.symlink_to(target)
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value="policy_deny")
+
+        managed: list[ManagedLink] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+
+        result = _capture_link(
+            role_name="test_role",
+            abs_path=str(link),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=set(),
+            seen_global=set(),
+        )
+        assert result is False
+        assert len(excluded) == 1
+
+    def test_capture_link_not_symlink(self, tmp_path):
+        """Test that regular files are rejected."""
+        f = tmp_path / "file.txt"
+        f.write_text("content", encoding="utf-8")
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedLink] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+
+        result = _capture_link(
+            role_name="test_role",
+            abs_path=str(f),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=set(),
+            seen_global=set(),
+        )
+        assert result is False
+        assert len(excluded) == 1
+
+    def test_capture_link_seen_role(self, tmp_path):
+        """Test link capture with seen_role deduplication."""
+        target = tmp_path / "target.txt"
+        target.write_text("content", encoding="utf-8")
+        link = tmp_path / "link.txt"
+        link.symlink_to(target)
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedLink] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+        seen_role = {str(link)}
+
+        result = _capture_link(
+            role_name="test_role",
+            abs_path=str(link),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=seen_role,
+            seen_global=None,
+        )
+        assert result is False
+        assert len(managed) == 0
+
+    def test_capture_link_seen_global(self, tmp_path):
+        """Test link capture with seen_global deduplication."""
+        target = tmp_path / "target.txt"
+        target.write_text("content", encoding="utf-8")
+        link = tmp_path / "link.txt"
+        link.symlink_to(target)
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedLink] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+        seen_global = {str(link)}
+
+        result = _capture_link(
+            role_name="test_role",
+            abs_path=str(link),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=None,
+            seen_global=seen_global,
+        )
+        assert result is False
+        assert len(managed) == 0
+
+
+class TestCaptureFile:
+    """Tests for _capture_file function"""
+
+    def test_capture_file_basic(self, tmp_path):
+        """Test basic file capture."""
+        bundle = tmp_path / "bundle"
+        bundle.mkdir()
+        (bundle / "artifacts").mkdir()
+
+        source = tmp_path / "source.txt"
+        source.write_text("content", encoding="utf-8")
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedFile] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+
+        result = _capture_file(
+            bundle_dir=str(bundle),
+            role_name="test_role",
+            abs_path=str(source),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=set(),
+            seen_global=set(),
+            metadata=None,
+        )
+        assert result is True
+        assert len(managed) == 1
+
+    def test_capture_file_seen_role(self, tmp_path):
+        """Test file capture with seen_role deduplication."""
+        bundle = tmp_path / "bundle"
+        bundle.mkdir()
+
+        source = tmp_path / "source.txt"
+        source.write_text("content", encoding="utf-8")
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedFile] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+        seen_role = {str(source)}
+
+        result = _capture_file(
+            bundle_dir=str(bundle),
+            role_name="test_role",
+            abs_path=str(source),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=seen_role,
+            seen_global=None,
+            metadata=None,
+        )
+        assert result is False
+        assert len(managed) == 0
+
+    def test_capture_file_seen_global(self, tmp_path):
+        """Test file capture with seen_global deduplication."""
+        bundle = tmp_path / "bundle"
+        bundle.mkdir()
+
+        source = tmp_path / "source.txt"
+        source.write_text("content", encoding="utf-8")
+
+        policy = MagicMock(spec=IgnorePolicy)
+        policy.deny_reason_link = None
+        policy.deny_reason = MagicMock(return_value=None)
+
+        managed: list[ManagedFile] = []
+        excluded: list[ExcludedFile] = []
+        path_filter = PathFilter([], [])
+        seen_global = {str(source)}
+
+        result = _capture_file(
+            bundle_dir=str(bundle),
+            role_name="test_role",
+            abs_path=str(source),
+            reason="test",
+            policy=policy,
+            path_filter=path_filter,
+            managed_out=managed,
+            excluded_out=excluded,
+            seen_role=None,
+            seen_global=seen_global,
+            metadata=None,
+        )
+        assert result is False
+        assert len(managed) == 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__, "-v"]))  # propagate pytest's exit code