more test coverage

This commit is contained in:
Miguel Jacq 2026-05-31 16:50:57 +10:00
parent b25dd1e314
commit 1544dc0295
Signed by: mig5
GPG key ID: 03906B4110AAD3B8
15 changed files with 3150 additions and 424 deletions

1
.gitignore vendored
View file

@ -8,3 +8,4 @@ dist
*.pdf *.pdf
*.csv *.csv
*.html *.html
coverage.xml

View file

@ -141,3 +141,147 @@ def test_collect_non_system_users(monkeypatch, tmp_path: Path):
assert u.primary_group == "users" assert u.primary_group == "users"
assert u.supplementary_groups == ["admins"] assert u.supplementary_groups == ["admins"]
assert u.ssh_files == ["/home/alice/.ssh/authorized_keys"] assert u.ssh_files == ["/home/alice/.ssh/authorized_keys"]
def test_parse_login_defs_file_not_found(tmp_path: Path):
from enroll.accounts import parse_login_defs
nonexistent = tmp_path / "nonexistent" / "login.defs"
vals = parse_login_defs(str(nonexistent))
assert vals == {}
def test_parse_login_defs_handles_invalid_numbers(tmp_path: Path):
from enroll.accounts import parse_login_defs
p = tmp_path / "login.defs"
p.write_text("UID_MIN not_a_number\nUID_MAX 60000\n", encoding="utf-8")
vals = parse_login_defs(str(p))
assert "UID_MIN" not in vals
assert vals["UID_MAX"] == 60000
def test_parse_group_handles_invalid_gid(tmp_path: Path):
from enroll.accounts import parse_group
p = tmp_path / "group"
p.write_text(
"valid:x:1000:user1\n" "invalid_gid:x:notanint:user2\n",
encoding="utf-8",
)
gid_to_name, name_to_gid, members = parse_group(str(p))
assert 1000 in gid_to_name
assert gid_to_name[1000] == "valid"
assert "invalid_gid" not in name_to_gid
def test_parse_group_line_too_short(tmp_path: Path):
from enroll.accounts import parse_group
p = tmp_path / "group"
p.write_text(
"valid:x:1000:user1\n" "shortline:x:1001\n",
encoding="utf-8",
)
gid_to_name, name_to_gid, members = parse_group(str(p))
assert 1000 in gid_to_name
assert 1001 not in gid_to_name
def test_is_human_user_filters_by_uid_and_shell():
from enroll.accounts import is_human_user
assert is_human_user(1000, "/bin/bash", 1000) is True
assert is_human_user(999, "/bin/bash", 1000) is False
assert is_human_user(1000, "/usr/sbin/nologin", 1000) is False
assert is_human_user(1000, "/usr/bin/nologin", 1000) is False
assert is_human_user(1000, "/bin/false", 1000) is False
assert is_human_user(1000, "", 1000) is True
def test_find_user_ssh_files_no_ssh_dir(tmp_path: Path):
from enroll.accounts import find_user_ssh_files
home = tmp_path / "home" / "user"
home.mkdir(parents=True)
assert find_user_ssh_files(str(home)) == []
def test_find_user_ssh_files_ignores_symlink(tmp_path: Path):
from enroll.accounts import find_user_ssh_files
home = tmp_path / "home" / "user"
sshdir = home / ".ssh"
sshdir.mkdir(parents=True)
target = sshdir / "real_file"
target.write_text("x", encoding="utf-8")
os.symlink(str(target), str(sshdir / "authorized_keys"))
result = find_user_ssh_files(str(home))
assert result == []
def test_find_user_ssh_files_handles_home_not_starting_with_slash():
from enroll.accounts import find_user_ssh_files
assert find_user_ssh_files("relative/path") == []
assert find_user_ssh_files("") == []
def test_collect_non_system_users_skips_nologin_users(tmp_path: Path):
import enroll.accounts as a
orig_parse_login_defs = a.parse_login_defs
orig_parse_passwd = a.parse_passwd
orig_parse_group = a.parse_group
passwd = tmp_path / "passwd"
passwd.write_text(
"root:x:0:0:root:/root:/bin/bash\n"
"alice:x:1000:1000:Alice:/home/alice:/bin/bash\n"
"nobody:x:65534:65534:nobody:/nonexistent:/usr/sbin/nologin\n"
"sysuser:x:100:100:Sys:/home/sys:/bin/bash\n",
encoding="utf-8",
)
group = tmp_path / "group"
group.write_text("users:x:1000:alice\n", encoding="utf-8")
defs = tmp_path / "login.defs"
defs.write_text("UID_MIN 1000\n", encoding="utf-8")
monkeypatch_wrapper = lambda fn, p: lambda path=str(p): fn(path)
a.parse_login_defs = monkeypatch_wrapper(orig_parse_login_defs, defs)
a.parse_passwd = monkeypatch_wrapper(orig_parse_passwd, passwd)
a.parse_group = monkeypatch_wrapper(orig_parse_group, group)
a.find_user_ssh_files = lambda home: []
users = a.collect_non_system_users()
assert [u.name for u in users] == ["alice"]
def test_collect_non_system_users_skips_below_uid_min(tmp_path: Path):
import enroll.accounts as a
orig_parse_login_defs = a.parse_login_defs
orig_parse_passwd = a.parse_passwd
orig_parse_group = a.parse_group
passwd = tmp_path / "passwd"
passwd.write_text(
"root:x:0:0:root:/root:/bin/bash\n"
"sysuser:x:999:999:Sys:/home/sys:/bin/bash\n"
"alice:x:1000:1000:Alice:/home/alice:/bin/bash\n",
encoding="utf-8",
)
group = tmp_path / "group"
group.write_text("users:x:1000:alice\n", encoding="utf-8")
defs = tmp_path / "login.defs"
defs.write_text("UID_MIN 1000\n", encoding="utf-8")
a.parse_login_defs = lambda path=str(defs): orig_parse_login_defs(path)
a.parse_passwd = lambda path=str(passwd): orig_parse_passwd(path)
a.parse_group = lambda path=str(group): orig_parse_group(path)
a.find_user_ssh_files = lambda home: []
users = a.collect_non_system_users()
assert [u.name for u in users] == ["alice"]

View file

@ -96,3 +96,244 @@ def test_parse_status_conffiles_handles_continuations(tmp_path: Path):
assert m["nginx"]["/etc/nginx/nginx.conf"] == "abcdef" assert m["nginx"]["/etc/nginx/nginx.conf"] == "abcdef"
assert m["nginx"]["/etc/nginx/mime.types"] == "123456" assert m["nginx"]["/etc/nginx/mime.types"] == "123456"
assert "other" not in m assert "other" not in m
def test_dpkg_owner_returns_none_on_diversion_only(monkeypatch):
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
def fake_run(cmd, text, capture_output):
return P(0, "diversion by foo from: /etc/something\n")
monkeypatch.setattr(d.subprocess, "run", fake_run)
assert d.dpkg_owner("/etc/something") is None
def test_dpkg_owner_handles_line_without_colon(monkeypatch):
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
def fake_run(cmd, text, capture_output):
return P(0, "invalid line without colon\n")
monkeypatch.setattr(d.subprocess, "run", fake_run)
assert d.dpkg_owner("/etc/foo") is None
def test_list_manual_packages_returns_empty_on_error(monkeypatch):
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
def fake_run(cmd, text, capture_output):
return P(1, "error")
monkeypatch.setattr(d.subprocess, "run", fake_run)
assert d.list_manual_packages() == []
def test_list_installed_packages_handles_exception(monkeypatch):
import enroll.debian as d
def fake_run(*args, **kwargs):
raise Exception("simulated error")
monkeypatch.setattr(d.subprocess, "run", fake_run)
assert d.list_installed_packages() == {}
def test_list_installed_packages_parses_output():
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
original_run = d.subprocess.run
def fake_run(cmd, text, capture_output, check):
return P(0, "nginx\t1.18.0\tamd64\nvim\t8.2\tamd64\n")
d.subprocess.run = fake_run
try:
result = d.list_installed_packages()
assert "nginx" in result
assert result["nginx"][0]["version"] == "1.18.0"
assert result["nginx"][0]["arch"] == "amd64"
assert "vim" in result
finally:
d.subprocess.run = original_run
def test_list_installed_packages_skips_invalid_lines():
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
original_run = d.subprocess.run
def fake_run(cmd, text, capture_output, check):
return P(0, "nginx\t1.18.0\tamd64\ninvalid_line\n\t1.0\tamd64\n")
d.subprocess.run = fake_run
try:
result = d.list_installed_packages()
assert "nginx" in result
assert "invalid_line" not in result
finally:
d.subprocess.run = original_run
def test_list_installed_packages_handles_empty_name():
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
original_run = d.subprocess.run
def fake_run(cmd, text, capture_output, check):
return P(0, "\t1.0\tamd64\nnginx\t1.18.0\tamd64\n")
d.subprocess.run = fake_run
try:
result = d.list_installed_packages()
assert "" not in result
assert "nginx" in result
finally:
d.subprocess.run = original_run
def test_list_installed_packages_sorts_output():
import enroll.debian as d
class P:
def __init__(self, rc: int, out: str):
self.returncode = rc
self.stdout = out
self.stderr = ""
original_run = d.subprocess.run
def fake_run(cmd, text, capture_output, check):
return P(0, "nginx\t1.18.0\tamd64\nnginx\t1.19.0\tarm64\n")
d.subprocess.run = fake_run
try:
result = d.list_installed_packages()
assert len(result["nginx"]) == 2
assert result["nginx"][0]["arch"] == "amd64"
assert result["nginx"][1]["arch"] == "arm64"
finally:
d.subprocess.run = original_run
def test_build_dpkg_etc_index_handles_missing_file(tmp_path: Path):
import enroll.debian as d
info = tmp_path / "info"
info.mkdir()
# Don't create any .list files
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
assert owned == set()
assert owner_map == {}
assert topdir_to_pkgs == {}
assert pkg_to_etc == {}
def test_build_dpkg_etc_index_skips_non_etc_paths(tmp_path: Path):
import enroll.debian as d
info = tmp_path / "info"
info.mkdir()
(info / "foo.list").write_text("/usr/bin/foo\n/etc/bar\n", encoding="utf-8")
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
assert "/usr/bin/foo" not in owned
assert "/etc/bar" in owned
assert "foo" not in topdir_to_pkgs
def test_parse_status_conffiles_handles_empty_status(tmp_path: Path):
import enroll.debian as d
status = tmp_path / "status"
status.write_text("", encoding="utf-8")
m = d.parse_status_conffiles(str(status))
assert m == {}
def test_parse_status_conffiles_handles_package_without_conffiles(tmp_path: Path):
import enroll.debian as d
status = tmp_path / "status"
status.write_text(
"Package: nginx\nVersion: 1\nStatus: install ok installed\n",
encoding="utf-8",
)
m = d.parse_status_conffiles(str(status))
assert m == {}
def test_read_pkg_md5sums_returns_empty_if_file_not_exists(tmp_path: Path):
import enroll.debian as d
result = d.read_pkg_md5sums("nonexistent_package")
assert result == {}
def test_read_pkg_md5sums_parses_md5sums_file(tmp_path: Path, monkeypatch):
import enroll.debian as d
info_dir = tmp_path / "info"
info_dir.mkdir()
md5_file = info_dir / "nginx.md5sums"
md5_file.write_text(
"abcdef1234567890abcdef1234567890 etc/nginx/nginx.conf\n"
"1234567890abcdef1234567890abcdef etc/nginx/sites-enabled/default\n",
encoding="utf-8",
)
def fake_exists(path):
return str(path).endswith("nginx.md5sums")
monkeypatch.setattr(d.os.path, "exists", fake_exists)
original_open = open
def fake_open(path, *args, **kwargs):
if "nginx.md5sums" in str(path):
return original_open(md5_file, *args, **kwargs)
return original_open(path, *args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open, raising=False)
result = d.read_pkg_md5sums("nginx")
assert result["etc/nginx/nginx.conf"] == "abcdef1234567890abcdef1234567890"
assert (
result["etc/nginx/sites-enabled/default"] == "1234567890abcdef1234567890abcdef"
)

View file

@ -87,3 +87,995 @@ def test_bundle_from_input_missing_path(tmp_path: Path):
with pytest.raises(RuntimeError, match="not found"): with pytest.raises(RuntimeError, match="not found"):
d._bundle_from_input(str(tmp_path / "nope"), sops_mode=False) d._bundle_from_input(str(tmp_path / "nope"), sops_mode=False)
import json
import sys
from enroll.diff import (
_bundle_from_input,
_file_index,
_iter_managed_files,
_load_state,
_pkg_version_display,
_pkg_version_key,
_progress_enabled,
_roles,
_service_units,
_sha256,
_users_by_name,
compare_harvests,
)
from enroll.sopsutil import SopsError
def test_progress_enabled_when_tty(monkeypatch):
monkeypatch.setattr(sys.stderr, "isatty", lambda: True)
monkeypatch.delenv("ENROLL_NO_PROGRESS", raising=False)
assert _progress_enabled() is True
def test_progress_enabled_when_not_tty(monkeypatch):
monkeypatch.setattr(sys.stderr, "isatty", lambda: False)
monkeypatch.delenv("ENROLL_NO_PROGRESS", raising=False)
assert _progress_enabled() is False
def test_progress_enabled_with_env_var(monkeypatch):
monkeypatch.setattr(sys.stderr, "isatty", lambda: True)
monkeypatch.setenv("ENROLL_NO_PROGRESS", "1")
assert _progress_enabled() is False
monkeypatch.setenv("ENROLL_NO_PROGRESS", "true")
assert _progress_enabled() is False
monkeypatch.setenv("ENROLL_NO_PROGRESS", "yes")
assert _progress_enabled() is False
def test_sha256(tmp_path: Path):
test_file = tmp_path / "test.txt"
test_file.write_text("hello world", encoding="utf-8")
hash_result = _sha256(test_file)
assert len(hash_result) == 64
def test_sha256_empty_file(tmp_path: Path):
test_file = tmp_path / "empty.txt"
test_file.write_bytes(b"")
hash_result = _sha256(test_file)
assert (
hash_result
== "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
def test_bundle_from_input_directory(tmp_path: Path):
result = _bundle_from_input(str(tmp_path), sops_mode=False)
assert result.dir == tmp_path
assert result.tempdir is None
def test_bundle_from_input_state_json_path(tmp_path: Path):
state_file = tmp_path / "state.json"
state_file.write_text("{}", encoding="utf-8")
result = _bundle_from_input(str(state_file), sops_mode=False)
assert result.dir == tmp_path
assert result.tempdir is None
def test_bundle_from_input_not_found():
with pytest.raises(RuntimeError) as exc_info:
_bundle_from_input("/nonexistent/path", sops_mode=False)
assert "not found" in str(exc_info.value).lower()
def test_bundle_from_input_tarball(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state_file = bundle_dir / "state.json"
state_file.write_text("{}", encoding="utf-8")
tar_path = tmp_path / "bundle.tar.gz"
with tarfile.open(tar_path, "w:gz") as tf:
tf.add(bundle_dir, arcname="bundle")
result = _bundle_from_input(str(tar_path), sops_mode=False)
assert result.dir.exists()
assert result.tempdir is not None
result.tempdir.cleanup()
def test_bundle_from_input_invalid_type(tmp_path: Path):
test_file = tmp_path / "test.txt"
test_file.write_text("not a bundle", encoding="utf-8")
with pytest.raises(RuntimeError) as exc_info:
_bundle_from_input(str(test_file), sops_mode=False)
assert "not a directory" in str(exc_info.value).lower()
def test_load_state(tmp_path: Path):
state_file = tmp_path / "state.json"
state_file.write_text('{"host": {"hostname": "test"}}', encoding="utf-8")
result = _load_state(tmp_path)
assert result["host"]["hostname"] == "test"
def test_roles_empty_state():
assert _roles({}) == {}
def test_roles_with_roles():
state = {"roles": {"users": {}, "services": []}}
result = _roles(state)
assert "users" in result
def test_service_units_empty():
assert _service_units({}) == {}
def test_service_units_with_services():
state = {
"roles": {
"services": [
{"unit": "nginx.service", "active_state": "active"},
{"unit": "ssh.service", "active_state": "inactive"},
]
}
}
result = _service_units(state)
assert "nginx.service" in result
assert "ssh.service" in result
assert result["nginx.service"]["active_state"] == "active"
def test_users_by_name_empty():
assert _users_by_name({}) == {}
def test_users_by_name_with_users():
state = {
"roles": {
"users": {
"users": [
{"name": "alice", "uid": 1000},
{"name": "bob", "uid": 1001},
]
}
}
}
result = _users_by_name(state)
assert "alice" in result
assert "bob" in result
assert result["alice"]["uid"] == 1000
def test_pkg_version_key_with_version():
entry = {"version": "1.2.3"}
assert _pkg_version_key(entry) == "1.2.3"
def test_pkg_version_key_with_installations():
entry = {
"installations": [
{"arch": "x86_64", "version": "1.2.3"},
{"arch": "aarch64", "version": "1.2.3"},
]
}
result = _pkg_version_key(entry)
assert "x86_64:1.2.3" in result
assert "aarch64:1.2.3" in result
def test_pkg_version_key_with_empty_version():
entry = {"version": None}
assert _pkg_version_key(entry) is None
def test_pkg_version_key_with_invalid_installations():
entry = {"installations": ["not_a_dict", {"arch": "x86_64", "version": "1.0"}]}
result = _pkg_version_key(entry)
assert "x86_64:1.0" in result
def test_pkg_version_display_with_version():
entry = {"version": "1.2.3"}
assert _pkg_version_display(entry) == "1.2.3"
def test_pkg_version_display_with_installations():
entry = {
"installations": [
{"arch": "x86_64", "version": "1.2.3"},
]
}
assert _pkg_version_display(entry) == "1.2.3 (x86_64)"
def test_pkg_version_display_empty():
assert _pkg_version_display({}) is None
def test_iter_managed_files_empty():
state = {"roles": {}}
files = list(_iter_managed_files(state))
assert files == []
def test_iter_managed_files_services():
state = {
"roles": {
"services": [
{
"role_name": "nginx",
"managed_files": [
{"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"}
],
}
]
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0] == (
"nginx",
{"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"},
)
def test_iter_managed_files_packages():
state = {
"roles": {
"packages": [
{
"role_name": "vim",
"managed_files": [{"path": "/usr/bin/vim", "src_rel": "bin/vim"}],
}
]
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "vim"
def test_iter_managed_files_users():
state = {
"roles": {
"users": {
"role_name": "users",
"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}],
}
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "users"
def test_iter_managed_files_apt_config():
state = {
"roles": {
"apt_config": {
"role_name": "apt_config",
"managed_files": [
{"path": "/etc/apt/sources.list", "src_rel": "sources.list"}
],
}
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "apt_config"
def test_iter_managed_files_etc_custom():
state = {
"roles": {
"etc_custom": {
"role_name": "etc_custom",
"managed_files": [
{"path": "/etc/custom.conf", "src_rel": "custom.conf"}
],
}
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "etc_custom"
def test_iter_managed_files_usr_local_custom():
state = {
"roles": {
"usr_local_custom": {
"role_name": "usr_local_custom",
"managed_files": [
{"path": "/usr/local/bin/script", "src_rel": "bin/script"}
],
}
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "usr_local_custom"
def test_iter_managed_files_extra_paths():
state = {
"roles": {
"extra_paths": {
"role_name": "extra_paths",
"managed_files": [{"path": "/opt/app/config", "src_rel": "config"}],
}
}
}
files = list(_iter_managed_files(state))
assert len(files) == 1
assert files[0][0] == "extra_paths"
def test_file_index_empty():
state = {"roles": {}}
index = _file_index(Path("/tmp"), state)
assert index == {}
def test_file_index_with_files(tmp_path: Path):
state = {
"roles": {
"users": {
"managed_files": [
{"path": "/etc/passwd", "src_rel": "passwd", "owner": "root"},
]
}
}
}
index = _file_index(tmp_path, state)
assert "/etc/passwd" in index
assert index["/etc/passwd"].role == "users"
assert index["/etc/passwd"].owner == "root"
def test_file_index_duplicates_first_wins(tmp_path: Path):
state = {
"roles": {
"users": {
"managed_files": [
{"path": "/etc/passwd", "src_rel": "passwd"},
]
},
"etc_custom": {
"managed_files": [
{"path": "/etc/passwd", "src_rel": "custom_passwd"},
]
},
}
}
index = _file_index(tmp_path, state)
assert "/etc/passwd" in index
assert index["/etc/passwd"].src_rel == "passwd"
def test_file_index_skips_missing_path_or_src_rel(tmp_path: Path):
state = {
"roles": {
"users": {
"managed_files": [
{"path": "/etc/passwd"}, # missing src_rel
{"src_rel": "passwd"}, # missing path
]
}
}
}
index = _file_index(tmp_path, state)
assert index == {}
def test_compare_harvests_no_changes(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {"vim": {"version": "1.0"}}},
"roles": {},
}
),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {"vim": {"version": "1.0"}}},
"roles": {},
}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is False
assert report["packages"]["added"] == []
assert report["packages"]["removed"] == []
def test_compare_harvests_package_added(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps({"inventory": {"packages": {}}, "roles": {}}),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is True
assert "vim" in report["packages"]["added"]
def test_compare_harvests_package_removed(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}}
),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps({"inventory": {"packages": {}}, "roles": {}}),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is True
assert "vim" in report["packages"]["removed"]
def test_compare_harvests_package_version_changed(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}}
),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "2.0"}}}, "roles": {}}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is True
assert len(report["packages"]["version_changed"]) == 1
def test_compare_harvests_ignore_package_versions(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "1.0"}}}, "roles": {}}
),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{"inventory": {"packages": {"vim": {"version": "2.0"}}}, "roles": {}}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(
str(old_bundle), str(new_bundle), ignore_package_versions=True
)
assert report["packages"]["version_changed_ignored_count"] == 1
def test_compare_harvests_service_added(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps({"inventory": {"packages": {}}, "roles": {"services": []}}),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {}},
"roles": {"services": [{"unit": "nginx.service"}]},
}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is True
assert "nginx.service" in report["services"]["enabled_added"]
def test_compare_harvests_user_added(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
(old_bundle / "state.json").write_text(
json.dumps({"inventory": {"packages": {}}, "roles": {"users": {"users": []}}}),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
(new_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {}},
"roles": {"users": {"users": [{"name": "alice", "uid": 1000}]}},
}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(str(old_bundle), str(new_bundle))
assert has_changes is True
assert "alice" in report["users"]["added"]
def test_compare_harvests_with_exclude_paths(tmp_path: Path):
old_bundle = tmp_path / "old"
old_bundle.mkdir()
old_artifacts = old_bundle / "artifacts" / "users"
old_artifacts.mkdir(parents=True)
(old_artifacts / "passwd").write_text("old", encoding="utf-8")
(old_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {}},
"roles": {
"users": {
"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]
}
},
}
),
encoding="utf-8",
)
new_bundle = tmp_path / "new"
new_bundle.mkdir()
new_artifacts = new_bundle / "artifacts" / "users"
new_artifacts.mkdir(parents=True)
(new_artifacts / "passwd").write_text("new", encoding="utf-8")
(new_bundle / "state.json").write_text(
json.dumps(
{
"inventory": {"packages": {}},
"roles": {
"users": {
"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]
}
},
}
),
encoding="utf-8",
)
report, has_changes = compare_harvests(
str(old_bundle), str(new_bundle), exclude_paths=["/etc/passwd"]
)
assert "/etc/passwd" not in [f["path"] for f in report["files"]["added"]]
assert "/etc/passwd" not in [f["path"] for f in report["files"]["removed"]]
assert "/etc/passwd" not in [f["path"] for f in report["files"]["changed"]]
from enroll.diff import (
_Spinner,
_enforcement_plan,
has_enforceable_drift,
_role_tag,
_utc_now_iso,
_report_markdown,
)
def test_utc_now_iso():
result = _utc_now_iso()
assert "T" in result
assert "+" in result or "Z" in result
def test_spinner_start_stop(monkeypatch):
# Mock sys.stderr to avoid actual writes
class FakeStderr:
def write(self, s):
pass
def flush(self):
pass
def isatty(self):
return True
monkeypatch.setattr(sys, "stderr", FakeStderr())
spinner = _Spinner("Test")
spinner.start()
spinner.stop(final_line="Done")
# Should not raise
def test_spinner_stop_without_start():
spinner = _Spinner("Test")
spinner.stop(final_line="Done")
# Should not raise
def test_spinner_run_exception(monkeypatch):
class FakeStderr:
def write(self, s):
raise Exception("Write error")
def flush(self):
pass
monkeypatch.setattr(sys, "stderr", FakeStderr())
spinner = _Spinner("Test")
spinner.start()
spinner.stop()
def test_spinner_double_start():
spinner = _Spinner("Test")
spinner.start()
spinner.start() # Should not raise or spawn another thread
spinner.stop()
def test_role_tag_normal():
assert _role_tag("nginx") == "role_nginx"
assert _role_tag("my-app") == "role_my-app"
def test_role_tag_with_special_chars():
assert _role_tag("my.app") == "role_my_app"
assert _role_tag("my app") == "role_my_app"
def test_role_tag_empty():
assert _role_tag("") == "role_other"
assert _role_tag(" ") == "role_other"
def test_has_enforceable_drift_packages_removed():
report = {"packages": {"removed": ["vim"]}}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_services_removed():
report = {"services": {"enabled_removed": ["nginx.service"]}}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_service_changed():
report = {
"services": {
"changed": [
{
"unit": "nginx.service",
"changes": {"active_state": {"old": "active", "new": "inactive"}},
}
]
}
}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_service_package_only_changed():
# Service changed only in packages - should NOT be enforceable
report = {
"services": {
"changed": [
{
"unit": "nginx.service",
"changes": {"packages": {"added": ["nginx-extra"]}},
}
]
}
}
assert has_enforceable_drift(report) is False
def test_has_enforceable_drift_users_removed():
report = {"users": {"removed": ["alice"]}}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_users_changed():
report = {
"users": {
"changed": [
{"name": "alice", "changes": {"uid": {"old": 1000, "new": 1001}}}
]
}
}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_files_removed():
report = {
"files": {
"removed": [{"path": "/etc/passwd", "role": "users", "reason": "conffile"}]
}
}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_files_changed():
report = {
"files": {
"changed": [
{
"path": "/etc/passwd",
"changes": {"content": {"old": "sha1", "new": "sha2"}},
}
]
}
}
assert has_enforceable_drift(report) is True
def test_has_enforceable_drift_no_drift():
report = {
"packages": {"added": ["newpkg"]},
"services": {"enabled_added": ["new.service"]},
"users": {"added": ["bob"]},
"files": {"added": ["/opt/newfile"]},
}
assert has_enforceable_drift(report) is False
def test_enforcement_plan_packages_removed(monkeypatch, tmp_path: Path):
old_state = {
"roles": {
"services": [{"role_name": "nginx", "packages": ["nginx"]}],
"packages": [{"role_name": "vim", "package": "vim"}],
}
}
report = {"packages": {"removed": ["nginx", "vim"]}}
result = _enforcement_plan(report, old_state, tmp_path)
assert "nginx" in result.get("roles", [])
assert "vim" in result.get("roles", [])
assert "role_nginx" in result.get("tags", [])
def test_enforcement_plan_users_changed():
old_state = {
"roles": {"users": {"role_name": "users", "users": [{"name": "alice"}]}}
}
report = {"users": {"changed": [{"name": "alice", "changes": {"uid": {}}}]}}
result = _enforcement_plan(report, old_state, Path("/tmp"))
assert "users" in result.get("roles", [])
def test_enforcement_plan_files_removed(tmp_path: Path):
# Create the artifacts directory structure that _file_index expects
artifacts_dir = tmp_path / "artifacts" / "etc_custom"
artifacts_dir.mkdir(parents=True)
old_state = {
"roles": {
"etc_custom": {
"role_name": "etc_custom",
"managed_files": [
{"path": "/etc/custom.conf", "src_rel": "custom.conf"}
],
}
}
}
report = {
"files": {"removed": [{"path": "/etc/custom.conf", "role": "etc_custom"}]}
}
result = _enforcement_plan(report, old_state, tmp_path)
assert "etc_custom" in result.get("roles", [])
def test_enforcement_plan_no_drift():
old_state = {"roles": {}}
report = {"packages": {"added": ["newpkg"]}}
result = _enforcement_plan(report, old_state, Path("/tmp"))
assert result.get("roles", []) == []
def test_bundle_from_input_tgz(monkeypatch, tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state_file = bundle_dir / "state.json"
state_file.write_text("{}", encoding="utf-8")
tar_path = tmp_path / "bundle.tgz"
with tarfile.open(tar_path, "w:gz") as tf:
tf.add(bundle_dir, arcname="bundle")
result = _bundle_from_input(str(tar_path), sops_mode=False)
assert result.dir.exists()
assert result.tempdir is not None
result.tempdir.cleanup()
def test_bundle_from_input_sops_mode_no_sops(monkeypatch, tmp_path: Path):
# Create a fake .sops file
sops_file = tmp_path / "harvest.sops"
sops_file.write_bytes(b"encrypted")
def fake_require():
raise SopsError("sops not found")
import enroll.diff as d
monkeypatch.setattr(d, "require_sops_cmd", fake_require)
with pytest.raises(SopsError):
_bundle_from_input(str(sops_file), sops_mode=True)
def test_report_markdown_basic():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz", "host": "host1"},
"new": {"input": "new.tar.gz", "host": "host2"},
"packages": {"added": ["vim"], "removed": [], "version_changed": []},
"services": {"enabled_added": [], "enabled_removed": [], "changed": []},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
}
result = _report_markdown(report)
assert "## Packages" in result
assert "+ vim" in result
def test_report_markdown_with_enforcement_applied():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {"added": [], "removed": [], "version_changed": []},
"services": {"enabled_added": [], "enabled_removed": [], "changed": []},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
"enforcement": {
"status": "applied",
"tags": ["role_users"],
"returncode": 0,
"finished_at": "2024-01-01T00:01:00Z",
},
}
result = _report_markdown(report)
assert "Applied old harvest" in result
assert "role_users" in result
def test_report_markdown_with_enforcement_failed():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {"added": [], "removed": [], "version_changed": []},
"services": {"enabled_added": [], "enabled_removed": [], "changed": []},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
"enforcement": {
"status": "failed",
"returncode": 1,
},
}
result = _report_markdown(report)
assert "ansible-playbook failed" in result
def test_report_markdown_with_enforcement_skipped():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {"added": [], "removed": [], "version_changed": []},
"services": {"enabled_added": [], "enabled_removed": [], "changed": []},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
"enforcement": {
"status": "skipped",
"reason": "no drift",
},
}
result = _report_markdown(report)
assert "Skipped" in result
assert "no drift" in result
def test_report_markdown_with_version_ignored():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {
"added": [],
"removed": [],
"version_changed": [{"package": "vim", "old": "1.0", "new": "2.0"}],
"version_changed_ignored_count": 1,
},
"services": {"enabled_added": [], "enabled_removed": [], "changed": []},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
}
result = _report_markdown(report)
assert "ignored 1" in result
def test_report_markdown_with_service_package_changes():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {"added": [], "removed": [], "version_changed": []},
"services": {
"enabled_added": [],
"enabled_removed": [],
"changed": [
{
"unit": "nginx.service",
"changes": {"packages": {"added": ["nginx-extra"], "removed": []}},
}
],
},
"users": {"added": [], "removed": [], "changed": []},
"files": {"added": [], "removed": [], "changed": []},
}
result = _report_markdown(report)
assert "packages added" in result
def test_report_markdown_empty():
report = {
"generated_at": "2024-01-01T00:00:00Z",
"old": {"input": "old.tar.gz"},
"new": {"input": "new.tar.gz"},
"packages": {},
"services": {},
"users": {},
"files": {},
}
result = _report_markdown(report)
assert "## Packages" in result
assert "## Services" in result

View file

@ -1,4 +1,5 @@
import json import json
import enroll.harvest as harvest
from pathlib import Path from pathlib import Path
import enroll.harvest as h import enroll.harvest as h
@ -367,3 +368,149 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic(
assert all( assert all(
mf["path"] != "/etc/cron.d/ntpsec" for mf in svc_apparmor["managed_files"] mf["path"] != "/etc/cron.d/ntpsec" for mf in svc_apparmor["managed_files"]
) )
def test_files_differ_same_content(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file2 = tmp_path / "file2.txt"
file1.write_text("same content", encoding="utf-8")
file2.write_text("same content", encoding="utf-8")
assert harvest._files_differ(str(file1), str(file2)) is False
def test_files_differ_different_content(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file2 = tmp_path / "file2.txt"
file1.write_text("content1", encoding="utf-8")
file2.write_text("content2", encoding="utf-8")
assert harvest._files_differ(str(file1), str(file2)) is True
def test_files_differ_missing_file(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file1.write_text("content", encoding="utf-8")
file2 = tmp_path / "file2.txt"
assert harvest._files_differ(str(file1), str(file2)) is True
def test_files_differ_both_missing(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file2 = tmp_path / "file2.txt"
assert harvest._files_differ(str(file1), str(file2)) is True
def test_files_differ_binary(tmp_path: Path):
file1 = tmp_path / "file1.bin"
file2 = tmp_path / "file2.bin"
file1.write_bytes(b"\x00\x01\x02\x03")
file2.write_bytes(b"\x00\x01\x02\x03")
assert harvest._files_differ(str(file1), str(file2)) is False
def test_files_differ_binary_different(tmp_path: Path):
file1 = tmp_path / "file1.bin"
file2 = tmp_path / "file2.bin"
file1.write_bytes(b"\x00\x01\x02\x03")
file2.write_bytes(b"\x00\x01\x02\x04")
assert harvest._files_differ(str(file1), str(file2)) is True
def test_files_differ_non_regular_a(tmp_path: Path):
directory = tmp_path / "dir"
directory.mkdir()
file1 = tmp_path / "file1.txt"
file1.write_text("content", encoding="utf-8")
assert harvest._files_differ(str(directory), str(file1)) is True
def test_files_differ_non_regular_b(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file1.write_text("content", encoding="utf-8")
directory = tmp_path / "dir"
directory.mkdir()
assert harvest._files_differ(str(file1), str(directory)) is True
def test_files_differ_size_mismatch(tmp_path: Path):
file1 = tmp_path / "file1.txt"
file1.write_text("short", encoding="utf-8")
file2 = tmp_path / "file2.txt"
file2.write_text("much longer content", encoding="utf-8")
assert harvest._files_differ(str(file1), str(file2)) is True
def test_files_differ_large_files(tmp_path: Path):
file1 = tmp_path / "file1.bin"
file2 = tmp_path / "file2.bin"
file1.write_bytes(b"x" * 3_000_000)
file2.write_bytes(b"x" * 3_000_000)
assert harvest._files_differ(str(file1), str(file2)) is True
def test_is_confish_with_conf(tmp_path: Path):
file1 = tmp_path / "test.conf"
file1.write_text("content", encoding="utf-8")
assert harvest._is_confish(str(file1)) is True
def test_is_confish_with_yaml(tmp_path: Path):
file1 = tmp_path / "test.yaml"
file1.write_text("content", encoding="utf-8")
assert harvest._is_confish(str(file1)) is True
def test_is_confish_with_json(tmp_path: Path):
file1 = tmp_path / "test.json"
file1.write_text("{}", encoding="utf-8")
assert harvest._is_confish(str(file1)) is True
def test_is_confish_with_service(tmp_path: Path):
file1 = tmp_path / "test.service"
file1.write_text("[Unit]", encoding="utf-8")
assert harvest._is_confish(str(file1)) is True
def test_is_confish_with_extensionless(tmp_path: Path):
file1 = tmp_path / "default"
file1.write_text("OPTIONS=", encoding="utf-8")
assert harvest._is_confish(str(file1)) is True
def test_is_confish_not_config(tmp_path: Path):
file1 = tmp_path / "test.log"
file1.write_text("log", encoding="utf-8")
assert harvest._is_confish(str(file1)) is False
def test_is_confish_nonexistent():
assert harvest._is_confish("/nonexistent/file.xyz") is False
def test_topdirs_for_package_with_multiple_paths():
pkg_to_etc_paths = {
"nginx": ["/etc/nginx/nginx.conf", "/etc/nginx/sites-enabled/default"],
}
result = harvest._topdirs_for_package("nginx", pkg_to_etc_paths)
assert result == {"nginx"}
def test_topdirs_for_package_with_multiple_topdirs():
pkg_to_etc_paths = {
"multi": ["/etc/nginx/nginx.conf", "/etc/ssh/sshd_config"],
}
result = harvest._topdirs_for_package("multi", pkg_to_etc_paths)
assert result == {"nginx", "ssh"}
def test_topdirs_for_package_empty():
result = harvest._topdirs_for_package("empty", {})
assert result == set()
def test_topdirs_for_package_no_etc():
pkg_to_etc_paths = {
"other": ["/usr/share/doc/file"],
}
result = harvest._topdirs_for_package("other", pkg_to_etc_paths)
assert result == set()

View file

@ -1,3 +1,8 @@
from __future__ import annotations
import os
from pathlib import Path
from enroll.ignore import IgnorePolicy from enroll.ignore import IgnorePolicy
@ -8,3 +13,238 @@ def test_ignore_policy_denies_common_backup_files():
assert pol.deny_reason("/etc/group-") == "backup_file" assert pol.deny_reason("/etc/group-") == "backup_file"
assert pol.deny_reason("/etc/something~") == "backup_file" assert pol.deny_reason("/etc/something~") == "backup_file"
assert pol.deny_reason("/foobar") == "unreadable" assert pol.deny_reason("/foobar") == "unreadable"
def test_deny_reason_dir_with_denied_path():
pol = IgnorePolicy()
assert pol.deny_reason_dir("/etc/ssl/private/key") == "denied_path"
assert pol.deny_reason_dir("/etc/ssh/ssh_host_key") == "denied_path"
assert pol.deny_reason_dir("/etc/ssh") is None
def test_deny_reason_dir_unreadable(tmp_path: Path):
pol = IgnorePolicy()
nonexistent = tmp_path / "nonexistent"
assert pol.deny_reason_dir(str(nonexistent)) == "unreadable"
def test_deny_reason_dir_symlink(tmp_path: Path):
pol = IgnorePolicy()
real_dir = tmp_path / "real"
real_dir.mkdir()
link = tmp_path / "link"
os.symlink(str(real_dir), str(link))
assert pol.deny_reason_dir(str(link)) == "symlink"
def test_deny_reason_dir_not_directory(tmp_path: Path):
pol = IgnorePolicy()
regular_file = tmp_path / "file.txt"
regular_file.write_text("content", encoding="utf-8")
assert pol.deny_reason_dir(str(regular_file)) == "not_directory"
def test_deny_reason_dir_dangerous_mode(tmp_path: Path):
pol = IgnorePolicy(dangerous=True)
real_dir = tmp_path / "private"
real_dir.mkdir()
assert pol.deny_reason_dir(str(real_dir)) is None
def test_deny_reason_link_basic(tmp_path: Path):
pol = IgnorePolicy()
real_file = tmp_path / "real"
real_file.write_text("content", encoding="utf-8")
link = tmp_path / "link"
os.symlink(str(real_file), str(link))
assert pol.deny_reason_link(str(link)) is None
def test_deny_reason_link_denied_path():
pol = IgnorePolicy()
assert pol.deny_reason_link("/etc/ssh/ssh_host_rsa_key") == "denied_path"
def test_deny_reason_link_unreadable(tmp_path: Path):
pol = IgnorePolicy()
# Create a symlink in a directory that doesn't exist
# This simulates an unreadable path
broken_link = tmp_path / "broken_link"
os.symlink("/nonexistent/target", str(broken_link))
# Broken symlinks are still readable (we can readlink them)
# So they return None (allowed) unless they match deny globs
result = pol.deny_reason_link(str(broken_link))
# Broken symlinks are allowed - we can still read the link target
assert result is None
def test_deny_reason_link_not_symlink(tmp_path: Path):
pol = IgnorePolicy()
regular_file = tmp_path / "file.txt"
regular_file.write_text("content", encoding="utf-8")
assert pol.deny_reason_link(str(regular_file)) == "not_symlink"
def test_deny_reason_link_log_file():
pol = IgnorePolicy()
assert pol.deny_reason_link("/var/log/something.log") == "log_file"
def test_deny_reason_link_backup_file():
pol = IgnorePolicy()
assert pol.deny_reason_link("/etc/passwd-") == "backup_file"
assert pol.deny_reason_link("/etc/something~") == "backup_file"
def test_deny_reason_link_dangerous_mode(tmp_path: Path):
pol = IgnorePolicy(dangerous=True)
real_file = tmp_path / "real"
real_file.write_text("content", encoding="utf-8")
link = tmp_path / "link"
os.symlink(str(real_file), str(link))
assert pol.deny_reason_link(str(link)) is None
def test_iter_effective_lines_with_comments():
pol = IgnorePolicy()
content = b"""
# This is a comment
; This is also a comment
* continuation
def main():
pass
"""
lines = list(pol.iter_effective_lines(content))
assert b"def main():" in lines
assert b"# This is a comment" not in lines
def test_iter_effective_lines_with_block_comments():
pol = IgnorePolicy()
content = b"""
/* This is a block comment
spanning multiple lines */
int x = 5;
"""
lines = list(pol.iter_effective_lines(content))
assert b"int x = 5;" in lines
assert b"/*" not in lines
def test_iter_effective_lines_empty():
pol = IgnorePolicy()
content = b""
lines = list(pol.iter_effective_lines(content))
assert lines == []
def test_deny_reason_binary_not_allowed(tmp_path: Path):
pol = IgnorePolicy()
binary = tmp_path / "random.bin"
binary.write_bytes(b"\x00\x01\x02\x03")
reason = pol.deny_reason(str(binary))
assert reason == "binary_like"
def test_deny_reason_sensitive_content(tmp_path: Path):
pol = IgnorePolicy()
config = tmp_path / "config.txt"
config.write_text("password=secret123", encoding="utf-8")
reason = pol.deny_reason(str(config))
assert reason == "sensitive_content"
def test_deny_reason_sensitive_api_key(tmp_path: Path):
pol = IgnorePolicy()
config = tmp_path / "config.txt"
config.write_text("api_key=abc123", encoding="utf-8")
reason = pol.deny_reason(str(config))
assert reason == "sensitive_content"
def test_deny_reason_private_key(tmp_path: Path):
pol = IgnorePolicy()
key = tmp_path / "key.pem"
key.write_text(
"-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEA...", encoding="utf-8"
)
reason = pol.deny_reason(str(key))
assert reason == "sensitive_content"
def test_deny_reason_too_large(tmp_path: Path):
pol = IgnorePolicy(max_file_bytes=100)
large = tmp_path / "large.txt"
large.write_bytes(b"x" * 200)
reason = pol.deny_reason(str(large))
assert reason == "too_large"
def test_deny_reason_unreadable(tmp_path: Path):
pol = IgnorePolicy()
nonexistent = tmp_path / "nonexistent"
reason = pol.deny_reason(str(nonexistent))
assert reason == "unreadable"
def test_deny_reason_not_regular_file(tmp_path: Path):
pol = IgnorePolicy()
directory = tmp_path / "dir"
directory.mkdir()
reason = pol.deny_reason(str(directory))
assert reason == "not_regular_file"
def test_deny_reason_symlink_file(tmp_path: Path):
pol = IgnorePolicy()
real_file = tmp_path / "real"
real_file.write_text("content", encoding="utf-8")
link = tmp_path / "link"
os.symlink(str(real_file), str(link))
reason = pol.deny_reason(str(link))
assert reason == "not_regular_file"
def test_deny_reason_logs(tmp_path: Path):
pol = IgnorePolicy()
log = tmp_path / "test.log"
log.write_text("log content", encoding="utf-8")
assert pol.deny_reason(str(log)) == "log_file"
def test_deny_reason_backup_file(tmp_path: Path):
pol = IgnorePolicy()
backup = tmp_path / "file~"
backup.write_text("backup", encoding="utf-8")
assert pol.deny_reason(str(backup)) == "backup_file"
def test_deny_reason_shadow_file():
pol = IgnorePolicy()
assert pol.deny_reason("/etc/shadow") == "denied_path"
assert pol.deny_reason("/etc/gshadow") == "denied_path"
def test_deny_reason_ssl_private():
pol = IgnorePolicy()
assert pol.deny_reason("/etc/ssl/private/key.pem") == "denied_path"
def test_deny_reason_ssh_host_keys():
pol = IgnorePolicy()
assert pol.deny_reason("/etc/ssh/ssh_host_rsa_key") == "denied_path"
assert pol.deny_reason("/etc/ssh/ssh_host_ed25519_key") == "denied_path"
def test_deny_reason_letsencrypt():
pol = IgnorePolicy()
assert (
pol.deny_reason("/etc/letsencrypt/live/example.com/fullchain.pem")
== "denied_path"
)
def test_deny_reason_shadow_backup():
pol = IgnorePolicy()
assert pol.deny_reason("/etc/shadow-") == "backup_file"
assert pol.deny_reason("/etc/passwd-") == "backup_file"

View file

@ -892,3 +892,175 @@ def test_manifest_writes_firewall_runtime_role(tmp_path: Path):
assert ( assert (
out / "roles" / "firewall_runtime" / "files" / "firewall" / "ipset.save" out / "roles" / "firewall_runtime" / "files" / "firewall" / "ipset.save"
).exists() ).exists()
def test_try_yaml_with_yaml_installed():
result = manifest._try_yaml()
# PyYAML should be installed for tests
if result is None:
pytest.skip("PyYAML not installed")
assert hasattr(result, "safe_load")
assert hasattr(result, "dump")
def test_yaml_load_mapping_with_yaml(tmp_path: Path):
text = """
key1: value1
key2:
nested: value
list:
- item1
- item2
"""
result = manifest._yaml_load_mapping(text)
assert result["key1"] == "value1"
assert result["key2"]["nested"] == "value"
assert result["list"] == ["item1", "item2"]
def test_yaml_load_mapping_empty():
result = manifest._yaml_load_mapping("")
assert result == {}
def test_yaml_load_mapping_invalid():
result = manifest._yaml_load_mapping("invalid: yaml: :")
assert result == {}
def test_yaml_load_mapping_not_dict():
result = manifest._yaml_load_mapping("- item1\n- item2")
assert result == {}
def test_yaml_load_mapping_none():
result = manifest._yaml_load_mapping("~")
assert result == {}
def test_yaml_dump_mapping_with_yaml(tmp_path: Path):
obj = {"key1": "value1", "key2": 123}
result = manifest._yaml_dump_mapping(obj)
assert "key1: value1" in result
assert "key2:" in result
def test_yaml_dump_mapping_empty():
result = manifest._yaml_dump_mapping({})
# Empty dict produces '{}'
assert result.strip() == "{}"
def test_yaml_dump_mapping_with_nested(tmp_path: Path):
obj = {"key1": {"nested": "value"}}
result = manifest._yaml_dump_mapping(obj)
assert "nested:" in result
def test_merge_mappings_overwrite_simple():
existing = {"key1": "old", "key2": "keep"}
incoming = {"key1": "new", "key3": "added"}
result = manifest._merge_mappings_overwrite(existing, incoming)
assert result["key1"] == "new"
assert result["key2"] == "keep"
assert result["key3"] == "added"
def test_merge_mappings_overwrite_nested():
existing = {"key1": {"a": 1}}
incoming = {"key1": {"b": 2}}
result = manifest._merge_mappings_overwrite(existing, incoming)
# Nested dicts are replaced, not merged
assert result["key1"] == {"b": 2}
def test_merge_mappings_overwrite_empty():
result = manifest._merge_mappings_overwrite({}, {"key": "value"})
assert result == {"key": "value"}
result = manifest._merge_mappings_overwrite({"key": "value"}, {})
assert result == {"key": "value"}
def test_copy2_replace(tmp_path: Path):
src = tmp_path / "src.txt"
src.write_text("content", encoding="utf-8")
dst = tmp_path / "dst" / "subdir" / "dst.txt"
manifest._copy2_replace(str(src), str(dst))
assert dst.exists()
assert dst.read_text(encoding="utf-8") == "content"
def test_copy2_replace_preserves_metadata(tmp_path: Path):
src = tmp_path / "src.txt"
src.write_text("content", encoding="utf-8")
os.chmod(str(src), 0o644)
dst = tmp_path / "dst.txt"
manifest._copy2_replace(str(src), str(dst))
assert dst.exists()
st = dst.stat()
assert stat.S_IMODE(st.st_mode) == 0o644
def test_copy2_replace_atomic(tmp_path: Path):
src = tmp_path / "src.txt"
src.write_text("content", encoding="utf-8")
dst = tmp_path / "dst.txt"
# Write initial content
dst.write_text("old", encoding="utf-8")
manifest._copy2_replace(str(src), str(dst))
assert dst.read_text(encoding="utf-8") == "content"
def test_render_firewall_runtime_tasks_empty():
state = {"roles": {}}
result = manifest._render_firewall_runtime_tasks(state)
# Function always returns at least a basic playbook structure
assert isinstance(result, str)
assert len(result) > 0
def test_render_firewall_runtime_tasks_with_iptables():
state = {
"roles": {
"firewall_runtime": {
"role_name": "firewall_runtime",
"iptables_v4_save": "artifacts/firewall_runtime/iptables.save",
}
}
}
result = manifest._render_firewall_runtime_tasks(state)
assert len(result) >= 1
def test_render_firewall_runtime_tasks_with_ipset():
state = {
"roles": {
"firewall_runtime": {
"role_name": "firewall_runtime",
"ipset_save": "artifacts/firewall_runtime/ipset.save",
}
}
}
result = manifest._render_firewall_runtime_tasks(state)
assert len(result) >= 1
def test_render_firewall_runtime_tasks_with_ipv6():
state = {
"roles": {
"firewall_runtime": {
"role_name": "firewall_runtime",
"iptables_v6_save": "artifacts/firewall_runtime/ip6tables.save",
}
}
}
result = manifest._render_firewall_runtime_tasks(state)
assert len(result) >= 1

View file

@ -1,416 +0,0 @@
from __future__ import annotations
import json
import os
import stat
import subprocess
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import pytest
from enroll.cache import _safe_component, new_harvest_cache_dir
from enroll.ignore import IgnorePolicy
from enroll.sopsutil import (
SopsError,
_pgp_arg,
decrypt_file_binary_to,
encrypt_file_binary,
)
def test_safe_component_sanitizes_and_bounds_length():
assert _safe_component(" ") == "unknown"
assert _safe_component("a/b c") == "a_b_c"
assert _safe_component("x" * 200) == "x" * 64
def test_new_harvest_cache_dir_uses_xdg_cache_home(tmp_path: Path, monkeypatch):
monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "xdg"))
hc = new_harvest_cache_dir(hint="my host/01")
assert hc.dir.exists()
assert "my_host_01" in hc.dir.name
assert str(hc.dir).startswith(str(tmp_path / "xdg"))
# best-effort: ensure directory is not world-readable on typical FS
try:
mode = stat.S_IMODE(hc.dir.stat().st_mode)
assert mode & 0o077 == 0
except OSError:
pass
def test_ignore_policy_denies_binary_and_sensitive_content(tmp_path: Path):
p_bin = tmp_path / "binfile"
p_bin.write_bytes(b"abc\x00def")
assert IgnorePolicy().deny_reason(str(p_bin)) == "binary_like"
p_secret = tmp_path / "secret.conf"
p_secret.write_text("password=foo\n", encoding="utf-8")
assert IgnorePolicy().deny_reason(str(p_secret)) == "sensitive_content"
# dangerous mode disables heuristic scanning (but still checks file-ness/size)
assert IgnorePolicy(dangerous=True).deny_reason(str(p_secret)) is None
def test_ignore_policy_denies_usr_local_shadow_by_glob():
# This should short-circuit before stat() (path doesn't need to exist).
assert IgnorePolicy().deny_reason("/usr/local/etc/shadow") == "denied_path"
def test_sops_pgp_arg_and_encrypt_decrypt_roundtrip(tmp_path: Path, monkeypatch):
assert _pgp_arg([" ABC ", "DEF"]) == "ABC,DEF"
with pytest.raises(SopsError):
_pgp_arg([])
# Stub out sops and subprocess.
import enroll.sopsutil as s
monkeypatch.setattr(s, "require_sops_cmd", lambda: "sops")
class R:
def __init__(self, rc: int, out: bytes, err: bytes = b""):
self.returncode = rc
self.stdout = out
self.stderr = err
calls = []
def fake_run(cmd, capture_output, check):
calls.append(cmd)
# Return a deterministic payload so we can assert file writes.
if "--encrypt" in cmd:
return R(0, b"ENCRYPTED")
if "--decrypt" in cmd:
return R(0, b"PLAINTEXT")
return R(1, b"", b"bad")
monkeypatch.setattr(s.subprocess, "run", fake_run)
src = tmp_path / "src.bin"
src.write_bytes(b"x")
enc = tmp_path / "out.sops"
dec = tmp_path / "out.bin"
encrypt_file_binary(src, enc, pgp_fingerprints=["ABC"], mode=0o600)
assert enc.read_bytes() == b"ENCRYPTED"
decrypt_file_binary_to(enc, dec, mode=0o644)
assert dec.read_bytes() == b"PLAINTEXT"
# Sanity: we invoked encrypt and decrypt.
assert any("--encrypt" in c for c in calls)
assert any("--decrypt" in c for c in calls)
def test_cache_dir_defaults_to_home_cache(monkeypatch, tmp_path: Path):
# Ensure default path uses ~/.cache when XDG_CACHE_HOME is unset.
from enroll.cache import enroll_cache_dir
monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
monkeypatch.setattr(Path, "home", lambda: tmp_path)
p = enroll_cache_dir()
assert str(p).startswith(str(tmp_path))
assert p.name == "enroll"
def test_harvest_cache_state_json_property(tmp_path: Path):
from enroll.cache import HarvestCache
hc = HarvestCache(tmp_path / "h1")
assert hc.state_json == hc.dir / "state.json"
def test_cache_dir_security_rejects_symlink(tmp_path: Path):
from enroll.cache import _ensure_dir_secure
real = tmp_path / "real"
real.mkdir()
link = tmp_path / "link"
link.symlink_to(real, target_is_directory=True)
with pytest.raises(RuntimeError, match="Refusing to use symlink"):
_ensure_dir_secure(link)
def test_cache_dir_chmod_failures_are_ignored(monkeypatch, tmp_path: Path):
from enroll import cache
# Make the cache base path deterministic and writable.
monkeypatch.setattr(cache, "enroll_cache_dir", lambda: tmp_path)
# Force os.chmod to fail to cover the "except OSError: pass" paths.
monkeypatch.setattr(
os, "chmod", lambda *a, **k: (_ for _ in ()).throw(OSError("nope"))
)
hc = cache.new_harvest_cache_dir()
assert hc.dir.exists()
assert hc.dir.is_dir()
def test_stat_triplet_falls_back_to_numeric_ids(monkeypatch, tmp_path: Path):
from enroll.fsutil import stat_triplet
import pwd
import grp
p = tmp_path / "x"
p.write_text("x", encoding="utf-8")
# Force username/group resolution failures.
monkeypatch.setattr(
pwd, "getpwuid", lambda _uid: (_ for _ in ()).throw(KeyError("no user"))
)
monkeypatch.setattr(
grp, "getgrgid", lambda _gid: (_ for _ in ()).throw(KeyError("no group"))
)
owner, group, mode = stat_triplet(str(p))
assert owner.isdigit()
assert group.isdigit()
assert len(mode) == 4
def test_ignore_policy_iter_effective_lines_removes_block_comments():
from enroll.ignore import IgnorePolicy
pol = IgnorePolicy()
data = b"""keep1
/*
drop me
*/
keep2
"""
assert list(pol.iter_effective_lines(data)) == [b"keep1", b"keep2"]
def test_ignore_policy_deny_reason_dir_variants(tmp_path: Path):
from enroll.ignore import IgnorePolicy
pol = IgnorePolicy()
# denied by glob
assert pol.deny_reason_dir("/etc/shadow") == "denied_path"
# symlink rejected
d = tmp_path / "d"
d.mkdir()
link = tmp_path / "l"
link.symlink_to(d, target_is_directory=True)
assert pol.deny_reason_dir(str(link)) == "symlink"
# not a directory
f = tmp_path / "f"
f.write_text("x", encoding="utf-8")
assert pol.deny_reason_dir(str(f)) == "not_directory"
# ok
assert pol.deny_reason_dir(str(d)) is None
def test_run_jinjaturtle_parses_outputs(monkeypatch, tmp_path: Path):
# Fully unit-test enroll.jinjaturtle.run_jinjaturtle by stubbing subprocess.run.
from enroll.jinjaturtle import run_jinjaturtle
def fake_run(cmd, **kwargs): # noqa: ARG001
# cmd includes "-d <defaults> -t <template>"
d_idx = cmd.index("-d") + 1
t_idx = cmd.index("-t") + 1
defaults = Path(cmd[d_idx])
template = Path(cmd[t_idx])
defaults.write_text("---\nfoo: 1\n", encoding="utf-8")
template.write_text("value={{ foo }}\n", encoding="utf-8")
return SimpleNamespace(returncode=0, stdout="ok", stderr="")
monkeypatch.setattr(subprocess, "run", fake_run)
src = tmp_path / "src.ini"
src.write_text("foo=1\n", encoding="utf-8")
res = run_jinjaturtle("/bin/jinjaturtle", str(src), role_name="role1")
assert "foo: 1" in res.vars_text
assert "value=" in res.template_text
def test_run_jinjaturtle_raises_on_failure(monkeypatch, tmp_path: Path):
from enroll.jinjaturtle import run_jinjaturtle
def fake_run(cmd, **kwargs): # noqa: ARG001
return SimpleNamespace(returncode=2, stdout="out", stderr="bad")
monkeypatch.setattr(subprocess, "run", fake_run)
src = tmp_path / "src.ini"
src.write_text("x", encoding="utf-8")
with pytest.raises(RuntimeError, match="jinjaturtle failed"):
run_jinjaturtle("/bin/jinjaturtle", str(src), role_name="role1")
def test_require_sops_cmd_errors_when_missing(monkeypatch):
from enroll.sopsutil import require_sops_cmd, SopsError
monkeypatch.setattr("enroll.sopsutil.shutil.which", lambda _: None)
with pytest.raises(SopsError, match="not found on PATH"):
require_sops_cmd()
def test_get_enroll_version_reports_unknown_on_metadata_failure(monkeypatch):
import enroll.version as v
fake_meta = types.ModuleType("importlib.metadata")
def boom():
raise RuntimeError("boom")
fake_meta.packages_distributions = boom
fake_meta.version = lambda _dist: boom()
monkeypatch.setitem(sys.modules, "importlib.metadata", fake_meta)
assert v.get_enroll_version() == "unknown"
def test_get_enroll_version_returns_unknown_if_importlib_metadata_unavailable(
monkeypatch,
):
import builtins
import enroll.version as v
real_import = builtins.__import__
def fake_import(
name, globals=None, locals=None, fromlist=(), level=0
): # noqa: A002
if name == "importlib.metadata":
raise ImportError("no metadata")
return real_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
assert v.get_enroll_version() == "unknown"
def test_compare_harvests_and_format_report(tmp_path: Path):
from enroll.diff import compare_harvests, format_report
old = tmp_path / "old"
new = tmp_path / "new"
(old / "artifacts").mkdir(parents=True)
(new / "artifacts").mkdir(parents=True)
def write_state(base: Path, state: dict) -> None:
base.mkdir(parents=True, exist_ok=True)
(base / "state.json").write_text(json.dumps(state, indent=2), encoding="utf-8")
# Old bundle: pkg a@1.0, pkg b@1.0, one service, one user, one managed file.
old_state = {
"schema_version": 3,
"host": {"hostname": "h1"},
"inventory": {"packages": {"a": {"version": "1.0"}, "b": {"version": "1.0"}}},
"roles": {
"services": [
{
"unit": "svc.service",
"role_name": "svc",
"packages": ["a"],
"active_state": "inactive",
"sub_state": "dead",
"unit_file_state": "enabled",
"condition_result": None,
"managed_files": [
{
"path": "/etc/foo.conf",
"src_rel": "etc/foo.conf",
"owner": "root",
"group": "root",
"mode": "0644",
"reason": "modified_conffile",
}
],
}
],
"packages": [],
"users": {
"role_name": "users",
"users": [{"name": "alice", "shell": "/bin/sh"}],
},
"apt_config": {"role_name": "apt_config", "managed_files": []},
"etc_custom": {"role_name": "etc_custom", "managed_files": []},
"usr_local_custom": {"role_name": "usr_local_custom", "managed_files": []},
"extra_paths": {"role_name": "extra_paths", "managed_files": []},
},
}
(old / "artifacts" / "svc" / "etc").mkdir(parents=True, exist_ok=True)
(old / "artifacts" / "svc" / "etc" / "foo.conf").write_text("old", encoding="utf-8")
write_state(old, old_state)
# New bundle: pkg a@2.0, pkg c@1.0, service changed, user changed, file moved role+content.
new_state = {
"schema_version": 3,
"host": {"hostname": "h2"},
"inventory": {"packages": {"a": {"version": "2.0"}, "c": {"version": "1.0"}}},
"roles": {
"services": [
{
"unit": "svc.service",
"role_name": "svc",
"packages": ["a", "c"],
"active_state": "active",
"sub_state": "running",
"unit_file_state": "enabled",
"condition_result": None,
"managed_files": [],
}
],
"packages": [],
"users": {
"role_name": "users",
"users": [{"name": "alice", "shell": "/bin/bash"}, {"name": "bob"}],
},
"apt_config": {"role_name": "apt_config", "managed_files": []},
"etc_custom": {"role_name": "etc_custom", "managed_files": []},
"usr_local_custom": {"role_name": "usr_local_custom", "managed_files": []},
"extra_paths": {
"role_name": "extra_paths",
"managed_files": [
{
"path": "/etc/foo.conf",
"src_rel": "etc/foo.conf",
"owner": "root",
"group": "root",
"mode": "0600",
"reason": "user_include",
},
{
"path": "/etc/added.conf",
"src_rel": "etc/added.conf",
"owner": "root",
"group": "root",
"mode": "0644",
"reason": "user_include",
},
],
},
},
}
(new / "artifacts" / "extra_paths" / "etc").mkdir(parents=True, exist_ok=True)
(new / "artifacts" / "extra_paths" / "etc" / "foo.conf").write_text(
"new", encoding="utf-8"
)
(new / "artifacts" / "extra_paths" / "etc" / "added.conf").write_text(
"x", encoding="utf-8"
)
write_state(new, new_state)
report, changed = compare_harvests(str(old), str(new))
assert changed is True
txt = format_report(report, fmt="text")
assert "Packages" in txt
md = format_report(report, fmt="markdown")
assert "# enroll diff report" in md
js = format_report(report, fmt="json")
parsed = json.loads(js)
assert parsed["packages"]["added"] == ["c"]

View file

@ -184,3 +184,157 @@ def test_expand_includes_respects_max_files(monkeypatch):
paths, notes = pf.expand_includes(include, max_files=2) paths, notes = pf.expand_includes(include, max_files=2)
assert len(paths) == 2 assert len(paths) == 2
assert "/root/c" not in paths assert "/root/c" not in paths
def test_has_glob_chars():
assert pf._has_glob_chars("*.txt") is True
assert pf._has_glob_chars("file?.log") is True
assert pf._has_glob_chars("[abc]") is True
assert pf._has_glob_chars("file.txt") is False
assert pf._has_glob_chars("") is False
def test_compile_path_pattern_regex_valid():
result = pf.compile_path_pattern("re:^/home/.*$")
assert result.kind == "regex"
assert result.regex is not None
assert result.regex.search("/home/user/file.txt") is not None
assert result.regex.search("/var/file.txt") is None
def test_compile_path_pattern_glob_forced():
result = pf.compile_path_pattern("glob:/etc/*.conf")
assert result.kind == "glob"
assert result.value == "/etc/*.conf"
def test_compile_path_pattern_glob_heuristic():
result = pf.compile_path_pattern("/etc/*.conf")
assert result.kind == "glob"
def test_compile_path_pattern_prefix():
result = pf.compile_path_pattern("/etc/nginx")
assert result.kind == "prefix"
assert result.value == "/etc/nginx"
def test_compiled_pattern_matches_prefix():
pat = pf.compile_path_pattern("/etc/nginx")
assert pat.matches("/etc/nginx") is True
assert pat.matches("/etc/nginx/conf.d") is True
assert pat.matches("/etc/ssh") is False
def test_compiled_pattern_matches_glob():
pat = pf.compile_path_pattern("/etc/*.conf")
assert pat.matches("/etc/ssh.conf") is True
assert pat.matches("/etc/ssh/sshd.conf") is False
def test_compiled_pattern_matches_regex():
pat = pf.compile_path_pattern("re:^/home/[^/]+/.bashrc$")
assert pat.matches("/home/alice/.bashrc") is True
assert pat.matches("/home/bob/.bashrc") is True
assert pat.matches("/home/alice/.profile") is False
assert pat.matches("/var/.bashrc") is False
def test_path_filter_is_excluded():
pf_filter = pf.PathFilter(exclude=["/tmp/*", "/var/log"])
assert pf_filter.is_excluded("/tmp/file.txt") is True
assert pf_filter.is_excluded("/var/log/syslog") is True
assert pf_filter.is_excluded("/etc/ssh") is False
def test_path_filter_empty():
pf_filter = pf.PathFilter()
assert pf_filter.is_excluded("/anything") is False
assert pf_filter.iter_include_patterns() == []
def test_expand_includes_prefix_existing(tmp_path: Path):
etc_dir = tmp_path / "etc"
etc_dir.mkdir()
(etc_dir / "file1.txt").write_text("a")
(etc_dir / "file2.txt").write_text("b")
patterns = [pf.compile_path_pattern(str(etc_dir))]
paths, notes = pf.expand_includes(patterns, max_files=10)
assert len(paths) == 2
assert notes == []
def test_expand_includes_prefix_nonexistent():
patterns = [pf.compile_path_pattern("/nonexistent/path")]
paths, notes = pf.expand_includes(patterns, max_files=10)
assert paths == []
assert len(notes) == 1
assert "matched no files" in notes[0]
def test_expand_includes_glob_no_matches():
patterns = [pf.compile_path_pattern("/nonexistent/*.txt")]
paths, notes = pf.expand_includes(patterns, max_files=10)
assert paths == []
assert len(notes) == 1
def test_expand_includes_skips_symlinks(tmp_path: Path):
real_file = tmp_path / "real.txt"
real_file.write_text("x")
link = tmp_path / "link.txt"
os.symlink(str(real_file), str(link))
patterns = [pf.compile_path_pattern(str(tmp_path))]
paths, notes = pf.expand_includes(patterns, max_files=10)
assert len(paths) == 1
assert paths[0].endswith("real.txt")
def test_expand_includes_excludes_pattern(tmp_path: Path):
etc_dir = tmp_path / "etc"
etc_dir.mkdir()
(etc_dir / "include.txt").write_text("a")
(etc_dir / "exclude.txt").write_text("b")
patterns = [pf.compile_path_pattern(str(etc_dir))]
exclude = pf.PathFilter(exclude=["*exclude*"])
paths, notes = pf.expand_includes(patterns, exclude=exclude, max_files=10)
assert len(paths) == 1
assert paths[0].endswith("include.txt")
def test_expand_includes_skips_directories(tmp_path: Path):
subdir = tmp_path / "subdir"
subdir.mkdir()
(tmp_path / "file.txt").write_text("x")
patterns = [pf.compile_path_pattern(str(subdir))]
paths, notes = pf.expand_includes(patterns, max_files=10)
assert paths == []
def test_regex_literal_prefix_simple():
assert pf._regex_literal_prefix("/etc/nginx/") == "/etc/nginx/"
def test_regex_literal_prefix_with_anchor():
assert pf._regex_literal_prefix("^/etc/nginx/") == "/etc/nginx/"
def test_regex_literal_prefix_with_regex_chars():
assert pf._regex_literal_prefix("^/etc/.*\\.conf$") == "/etc/"
def test_path_filter_with_include_patterns():
pf_filter = pf.PathFilter(include=["/etc/*.conf"], exclude=["/etc/secret.conf"])
patterns = pf_filter.iter_include_patterns()
assert len(patterns) == 1
assert patterns[0].kind == "glob"

View file

@ -91,3 +91,176 @@ def test_specific_paths_for_hints_differs_between_backends():
paths = set(r.specific_paths_for_hints({"nginx"})) paths = set(r.specific_paths_for_hints({"nginx"}))
assert "/etc/sysconfig/nginx" in paths assert "/etc/sysconfig/nginx" in paths
assert "/etc/sysconfig/nginx.conf" in paths assert "/etc/sysconfig/nginx.conf" in paths
def test_read_os_release_file_not_found(tmp_path: Path):
result = platform._read_os_release(str(tmp_path / "nonexistent"))
assert result == {}
def test_read_os_release_handles_invalid_line(tmp_path: Path):
p = tmp_path / "os-release"
p.write_text(
"ID=ubuntu\n" "NO_EQUALS_SIGN\n" 'VERSION="22.04"\n',
encoding="utf-8",
)
result = platform._read_os_release(str(p))
assert result["ID"] == "ubuntu"
assert result["VERSION"] == "22.04"
assert "NO_EQUALS_SIGN" not in result
def test_detect_platform_debian(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "debian", "VERSION_ID": "11"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "debian"
assert result.pkg_backend == "dpkg"
def test_detect_platform_ubuntu(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "ubuntu", "VERSION_ID": "22.04"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "debian"
assert result.pkg_backend == "dpkg"
def test_detect_platform_fedora(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "fedora", "VERSION_ID": "38"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "redhat"
assert result.pkg_backend == "rpm"
def test_detect_platform_rocky(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "rocky", "VERSION_ID": "9"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "redhat"
assert result.pkg_backend == "rpm"
def test_detect_platform_unknown_fallback_to_dpkg(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "unknown"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "dpkg")
result = platform.detect_platform()
assert result.os_family == "debian"
assert result.pkg_backend == "dpkg"
def test_detect_platform_unknown_fallback_to_rpm(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "unknown"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "rpm")
result = platform.detect_platform()
assert result.os_family == "redhat"
assert result.pkg_backend == "rpm"
def test_detect_platform_completely_unknown(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "unknown"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
monkeypatch.setattr(platform.shutil, "which", lambda x: False)
result = platform.detect_platform()
assert result.os_family == "unknown"
assert result.pkg_backend == "unknown"
def test_detect_platform_debian_like(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "linuxmint", "ID_LIKE": "debian"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "debian"
assert result.pkg_backend == "dpkg"
def test_detect_platform_rhel_like(monkeypatch):
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
return {"ID": "centos", "ID_LIKE": "rhel fedora"}
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
result = platform.detect_platform()
assert result.os_family == "redhat"
assert result.pkg_backend == "rpm"
def test_get_backend_returns_dpkg(monkeypatch):
info = platform.PlatformInfo(os_family="debian", pkg_backend="dpkg", os_release={})
backend = platform.get_backend(info)
assert isinstance(backend, platform.DpkgBackend)
assert backend.name == "dpkg"
def test_get_backend_returns_rpm(monkeypatch):
info = platform.PlatformInfo(os_family="redhat", pkg_backend="rpm", os_release={})
backend = platform.get_backend(info)
assert isinstance(backend, platform.RpmBackend)
assert backend.name == "rpm"
def test_get_backend_unknown_with_rpm(monkeypatch):
info = platform.PlatformInfo(
os_family="unknown", pkg_backend="unknown", os_release={}
)
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "rpm")
backend = platform.get_backend(info)
assert isinstance(backend, platform.RpmBackend)
def test_get_backend_unknown_with_dpkg(monkeypatch):
info = platform.PlatformInfo(
os_family="unknown", pkg_backend="unknown", os_release={}
)
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "dpkg")
backend = platform.get_backend(info)
assert isinstance(backend, platform.DpkgBackend)
def test_dpkg_backend_specific_paths():
backend = platform.DpkgBackend()
paths = backend.specific_paths_for_hints({"nginx"})
assert "/etc/default/nginx" in paths
assert "/etc/init.d/nginx" in paths
assert "/etc/sysctl.d/nginx.conf" in paths
def test_rpm_backend_specific_paths():
backend = platform.RpmBackend()
paths = backend.specific_paths_for_hints({"nginx"})
assert "/etc/sysconfig/nginx" in paths
assert "/etc/sysconfig/nginx.conf" in paths
assert "/etc/sysctl.d/nginx.conf" in paths
def test_is_pkg_config_path_dpkg():
backend = platform.DpkgBackend()
assert backend.is_pkg_config_path("/etc/apt/sources.list") is True
assert backend.is_pkg_config_path("/etc/apt/trusted.gpg") is True
assert backend.is_pkg_config_path("/etc/ssh/sshd_config") is False
def test_is_pkg_config_path_rpm():
backend = platform.RpmBackend()
assert backend.is_pkg_config_path("/etc/dnf/dnf.conf") is True
assert backend.is_pkg_config_path("/etc/yum.conf") is True
assert backend.is_pkg_config_path("/etc/yum.repos.d/custom.repo") is True
assert backend.is_pkg_config_path("/etc/ssh/sshd_config") is False

View file

@ -565,3 +565,452 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
# Ensure the password was written to stdin for the -S invocation. # Ensure the password was written to stdin for the -S invocation.
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"] assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]
def test_sudo_password_required_detection():
from enroll.remote import _sudo_password_required
assert _sudo_password_required("", "a password is required") is True
assert _sudo_password_required("", "password is required") is True
assert (
_sudo_password_required("", "a terminal is required to read the password")
is True
)
assert (
_sudo_password_required("", "no tty present and no askpass program specified")
is True
)
assert _sudo_password_required("", "must have a tty to run sudo") is True
assert _sudo_password_required("", "sudo: sorry, you must have a tty") is True
assert _sudo_password_required("", "askpass") is True
assert _sudo_password_required("success", "") is False
def test_sudo_not_permitted_detection():
from enroll.remote import _sudo_not_permitted
assert _sudo_not_permitted("", "user is not in the sudoers file") is True
assert _sudo_not_permitted("", "not allowed to execute") is True
assert _sudo_not_permitted("", "may not run sudo") is True
assert _sudo_not_permitted("", "sorry, user") is True
assert _sudo_not_permitted("success", "") is False
def test_sudo_tty_required_detection():
from enroll.remote import _sudo_tty_required
assert _sudo_tty_required("", "must have a tty") is True
assert _sudo_tty_required("", "sorry, you must have a tty") is True
assert _sudo_tty_required("", "sudo: sorry, you must have a tty") is True
assert _sudo_tty_required("", "must have a tty to run sudo") is True
assert _sudo_tty_required("success", "") is False
def test_resolve_become_password_prompts_when_asked(monkeypatch):
from enroll.remote import _resolve_become_password
prompted = []
def fake_getpass(prompt):
prompted.append(prompt)
return "secret"
result = _resolve_become_password(
True, prompt="sudo password: ", getpass_fn=fake_getpass
)
assert result == "secret"
assert len(prompted) == 1
def test_resolve_become_password_returns_none_when_not_asked():
from enroll.remote import _resolve_become_password
result = _resolve_become_password(False)
assert result is None
def test_resolve_ssh_key_passphrase_from_env(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
monkeypatch.setenv("SSH_KEY_PASS", "env_secret")
result = _resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
assert result == "env_secret"
def test_resolve_ssh_key_passphrase_raises_when_env_not_set(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
monkeypatch.delenv("SSH_KEY_PASS", raising=False)
with pytest.raises(RuntimeError, match="SSH key passphrase environment variable"):
_resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
def test_resolve_ssh_key_passphrase_prompts_when_asked(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
prompted = []
def fake_getpass(prompt):
prompted.append(prompt)
return "prompt_secret"
result = _resolve_ssh_key_passphrase(
True, prompt="SSH key passphrase: ", getpass_fn=fake_getpass
)
assert result == "prompt_secret"
assert len(prompted) == 1
def test_resolve_ssh_key_passphrase_returns_none_when_not_asked():
from enroll.remote import _resolve_ssh_key_passphrase
result = _resolve_ssh_key_passphrase(False, env_var=None)
assert result is None
def test_safe_extract_tar_rejects_absolute_paths(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="/etc/passwd")
ti.size = 1
tf.addfile(ti, io.BytesIO(b"x"))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Unsafe tar member path"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_rejects_hardlinks(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="hardlink")
ti.type = tarfile.LNKTYPE
ti.linkname = "/etc/passwd"
tf.addfile(ti)
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Refusing to extract"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_rejects_device_nodes(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="device")
ti.type = tarfile.CHRTYPE
tf.addfile(ti)
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Refusing to extract"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_accepts_dot_entry(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name=".")
ti.size = 0
tf.addfile(ti, io.BytesIO(b""))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_accepts_valid_files(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="foo/bar.txt")
ti.size = 5
tf.addfile(ti, io.BytesIO(b"hello"))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"hello"
def test_remote_harvest_ssh_key_passphrase_retry(monkeypatch, tmp_path: Path):
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
class _Chan:
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
self._out = out
self._err = err
self._out_i = 0
self._err_i = 0
self._rc = rc
self._closed = False
def recv_ready(self) -> bool:
return (not self._closed) and self._out_i < len(self._out)
def recv(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._out[self._out_i : self._out_i + n]
self._out_i += len(chunk)
return chunk
def recv_stderr_ready(self) -> bool:
return (not self._closed) and self._err_i < len(self._err)
def recv_stderr(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._err[self._err_i : self._err_i + n]
self._err_i += len(chunk)
return chunk
def exit_status_ready(self) -> bool:
return self._closed or (
self._out_i >= len(self._out) and self._err_i >= len(self._err)
)
def recv_exit_status(self) -> int:
return self._rc
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
self._bio = io.BytesIO(payload)
self.channel = _Chan(out=payload, err=err, rc=rc)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stderr:
def __init__(self, payload: bytes = b""):
self._bio = io.BytesIO(payload)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stdin:
def __init__(self, cmd: str):
self._cmd = cmd
def write(self, s: str) -> None:
pass
def flush(self) -> None:
return
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
return
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
if cmd.startswith("tar -cz -C"):
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
if cmd == "mktemp -d":
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
if cmd.startswith("chmod 700"):
return (_Stdin(cmd), _Stdout(b""), _Stderr())
if " harvest " in cmd:
return (_Stdin(cmd), _Stdout(b""), _Stderr())
if cmd.startswith("rm -rf"):
return (_Stdin(cmd), _Stdout(b""), _Stderr())
return (_Stdin(cmd), _Stdout(b""), _Stderr())
def close(self):
return
RejectPolicy4 = type("RejectPolicy", (), {})
class FakeParamiko:
SSHClient = FakeSSH
RejectPolicy = RejectPolicy4 # type: ignore
PasswordRequiredException = Exception # type: ignore
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
prompts = []
def fake_getpass(prompt):
prompts.append(prompt)
return "passphrase"
out_dir = tmp_path / "out"
state_path = r.remote_harvest(
ask_key_passphrase=True,
getpass_fn=fake_getpass,
local_out_dir=out_dir,
remote_host="example.com",
remote_user="alice",
no_sudo=True,
)
assert state_path.exists()
assert len(prompts) == 1
def test_remote_harvest_ssh_key_passphrase_raises_when_not_interactive(
monkeypatch, tmp_path: Path
):
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
class _Chan:
def __init__(self):
self._closed = False
def recv_ready(self) -> bool:
return False
def recv(self, n: int) -> bytes:
return b""
def recv_stderr_ready(self) -> bool:
return False
def recv_stderr(self, n: int) -> bytes:
return b""
def exit_status_ready(self) -> bool:
return True
def recv_exit_status(self) -> int:
return 0
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout:
def __init__(self):
self.channel = _Chan()
def read(self, n: int = -1) -> bytes:
return b""
class _Stderr:
def read(self, n: int = -1) -> bytes:
return b""
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
raise Exception("PasswordRequired")
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, **_kwargs):
return (_Stdout(), _Stdout(), _Stderr())
def close(self):
return
class RejectPolicy:
pass
RejectPolicy3 = RejectPolicy
class FakeParamiko:
SSHClient = FakeSSH
RejectPolicy = RejectPolicy3 # type: ignore
PasswordRequiredException = Exception # type: ignore
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
out_dir = tmp_path / "out"
with pytest.raises(RuntimeError, match="SSH private key is encrypted"):
r.remote_harvest(
ask_key_passphrase=False,
local_out_dir=out_dir,
remote_host="example.com",
stdin=io.StringIO(),
)

View file

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import pytest
import enroll.rpm as rpm import enroll.rpm as rpm
@ -176,3 +177,33 @@ def test_rpm_owner_strips_epoch_prefix_when_present(monkeypatch):
lambda cmd, allow_fail=False, merge_err=False: (0, "1:bash\n"), lambda cmd, allow_fail=False, merge_err=False: (0, "1:bash\n"),
) )
assert rpm.rpm_owner("/bin/bash") == "bash" assert rpm.rpm_owner("/bin/bash") == "bash"
def test_strip_arch_no_suffix():
assert rpm._strip_arch("vim") == "vim"
assert rpm._strip_arch("nginx ") == "nginx"
def test_strip_arch_with_unknown_suffix():
assert rpm._strip_arch("package.unknown") == "package.unknown"
def test_run_command_raises_on_fail():
with pytest.raises(RuntimeError):
rpm._run(["sh", "-c", "echo stderr >&2; exit 1"], allow_fail=False)
def test_rpm_owner_empty_path():
assert rpm.rpm_owner("") is None
def test_rpm_modified_files_empty(monkeypatch):
monkeypatch.setattr(
rpm, "_run", lambda cmd, allow_fail=False, merge_err=False: (0, "")
)
assert rpm.rpm_modified_files("vim") == set()
def test_list_manual_packages_no_commands_available(monkeypatch):
monkeypatch.setattr(rpm.shutil, "which", lambda exe: None)
assert rpm.list_manual_packages() == []

54
tests/test_sopsutil.py Normal file
View file

@ -0,0 +1,54 @@
from __future__ import annotations
import pytest
from enroll.sopsutil import SopsError, _pgp_arg, find_sops_cmd, require_sops_cmd
def test_find_sops_cmd():
result = find_sops_cmd()
if result is None:
pytest.skip("sops not installed")
assert result.endswith("sops")
def test_require_sops_cmd():
exe = require_sops_cmd()
assert exe is not None
assert "sops" in exe
def test_require_sops_cmd_raises_when_not_found(monkeypatch):
import enroll.sopsutil as s
def fake_find():
return None
monkeypatch.setattr(s, "find_sops_cmd", fake_find)
with pytest.raises(SopsError) as exc_info:
require_sops_cmd()
assert "sops" in str(exc_info.value).lower()
assert "not found" in str(exc_info.value).lower()
def test_pgp_arg_with_empty_fingerprints():
with pytest.raises(SopsError) as exc_info:
_pgp_arg([])
assert "No GPG fingerprints" in str(exc_info.value)
def test_pgp_arg_with_whitespace_fingerprints():
result = _pgp_arg([" ", "ABC123", " DEF456 "])
assert result == "ABC123,DEF456"
def test_pgp_arg_with_single_fingerprint():
result = _pgp_arg(["ABC123DEF456"])
assert result == "ABC123DEF456"
def test_pgp_arg_with_multiple_fingerprints():
result = _pgp_arg(["ABC123", "DEF456", "GHI789"])
assert result == "ABC123,DEF456,GHI789"

View file

@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import pytest import pytest
import enroll.systemd as s
def test_list_enabled_services_and_timers_filters_templates(monkeypatch): def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
import enroll.systemd as s
def fake_run(cmd: list[str]) -> str: def fake_run(cmd: list[str]) -> str:
if "--type=service" in cmd: if "--type=service" in cmd:
return "\n".join( return "\n".join(
@ -35,8 +34,6 @@ def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
def test_get_unit_info_parses_fields(monkeypatch): def test_get_unit_info_parses_fields(monkeypatch):
import enroll.systemd as s
class P: class P:
def __init__(self, rc: int, out: str, err: str = ""): def __init__(self, rc: int, out: str, err: str = ""):
self.returncode = rc self.returncode = rc
@ -71,8 +68,6 @@ def test_get_unit_info_parses_fields(monkeypatch):
def test_get_unit_info_raises_unit_query_error(monkeypatch): def test_get_unit_info_raises_unit_query_error(monkeypatch):
import enroll.systemd as s
class P: class P:
def __init__(self, rc: int, out: str, err: str): def __init__(self, rc: int, out: str, err: str):
self.returncode = rc self.returncode = rc
@ -90,8 +85,6 @@ def test_get_unit_info_raises_unit_query_error(monkeypatch):
def test_get_timer_info_parses_fields(monkeypatch): def test_get_timer_info_parses_fields(monkeypatch):
import enroll.systemd as s
class P: class P:
def __init__(self, rc: int, out: str, err: str = ""): def __init__(self, rc: int, out: str, err: str = ""):
self.returncode = rc self.returncode = rc
@ -119,3 +112,123 @@ def test_get_timer_info_parses_fields(monkeypatch):
ti = s.get_timer_info("apt-daily.timer") ti = s.get_timer_info("apt-daily.timer")
assert ti.trigger_unit == "apt-daily.service" assert ti.trigger_unit == "apt-daily.service"
assert "/etc/default/apt" in ti.env_files assert "/etc/default/apt" in ti.env_files
def test_list_enabled_services_empty_output(monkeypatch):
def fake_run(cmd: list[str]) -> str:
return ""
monkeypatch.setattr(s, "_run", fake_run)
result = s.list_enabled_services()
assert result == []
def test_list_enabled_timers_empty_output(monkeypatch):
def fake_run(cmd: list[str]) -> str:
return ""
monkeypatch.setattr(s, "_run", fake_run)
result = s.list_enabled_timers()
assert result == []
def test_list_enabled_services_with_only_templates(monkeypatch):
def fake_run(cmd: list[str]) -> str:
return "getty@.service enabled\n"
monkeypatch.setattr(s, "_run", fake_run)
result = s.list_enabled_services()
assert result == []
def test_list_enabled_timers_with_only_templates(monkeypatch):
def fake_run(cmd: list[str]) -> str:
return "foo@.timer enabled\n"
monkeypatch.setattr(s, "_run", fake_run)
result = s.list_enabled_timers()
assert result == []
def test_get_timer_info_raises_on_failure(monkeypatch):
class P:
def __init__(self, rc: int, out: str, err: str):
self.returncode = rc
self.stdout = out
self.stderr = err
def fake_run(cmd, text, capture_output):
return P(1, "", "timer not found")
monkeypatch.setattr(s.subprocess, "run", fake_run)
with pytest.raises(RuntimeError) as exc_info:
s.get_timer_info("nonexistent.timer")
assert "nonexistent.timer" in str(exc_info.value)
def test_get_timer_info_with_empty_fields(monkeypatch):
class P:
def __init__(self, rc: int, out: str, err: str = ""):
self.returncode = rc
self.stdout = out
self.stderr = err
def fake_run(cmd, text, capture_output):
return P(
0,
"\n".join(
[
"FragmentPath=",
"DropInPaths=",
"EnvironmentFiles=",
"Unit=",
"ActiveState=",
"SubState=",
"UnitFileState=",
"ConditionResult=",
]
),
)
monkeypatch.setattr(s.subprocess, "run", fake_run)
ti = s.get_timer_info("empty.timer")
assert ti.fragment_path is None
assert ti.dropin_paths == []
assert ti.env_files == []
assert ti.trigger_unit is None
assert ti.active_state is None
def test_get_unit_info_with_empty_fields(monkeypatch):
class P:
def __init__(self, rc: int, out: str, err: str = ""):
self.returncode = rc
self.stdout = out
self.stderr = err
def fake_run(cmd, check, text, capture_output):
return P(
0,
"\n".join(
[
"FragmentPath=",
"DropInPaths=",
"EnvironmentFiles=",
"ExecStart=",
"ActiveState=",
"SubState=",
"UnitFileState=",
"ConditionResult=",
]
),
)
monkeypatch.setattr(s.subprocess, "run", fake_run)
ui = s.get_unit_info("empty.service")
assert ui.fragment_path is None
assert ui.dropin_paths == []
assert ui.env_files == []
assert ui.exec_paths == []
assert ui.active_state is None

View file

@ -180,3 +180,234 @@ def test_cli_validate_exits_1_on_validation_warning_with_flag(
with pytest.raises(SystemExit) as e: with pytest.raises(SystemExit) as e:
cli.main() cli.main()
assert e.value.code == 1 assert e.value.code == 1
def test_validation_result_ok():
from enroll.validate import ValidationResult
result = ValidationResult(errors=[], warnings=[])
assert result.ok is True
assert result.to_text() == "OK: harvest bundle validated\n"
def test_validation_result_with_errors():
from enroll.validate import ValidationResult
result = ValidationResult(errors=["error1", "error2"], warnings=[])
assert result.ok is False
text = result.to_text()
assert "ERROR: 2 validation error(s)" in text
assert "error1" in text
assert "error2" in text
def test_validation_result_with_warnings():
from enroll.validate import ValidationResult
result = ValidationResult(errors=[], warnings=["warn1"])
assert result.ok is True
text = result.to_text()
assert "WARN: 1 warning(s)" in text
assert "warn1" in text
def test_validation_result_to_dict():
from enroll.validate import ValidationResult
result = ValidationResult(errors=["e1"], warnings=["w1"])
d = result.to_dict()
assert d["ok"] is False
assert d["errors"] == ["e1"]
assert d["warnings"] == ["w1"]
def test_iter_managed_files_singleton_roles():
from enroll.validate import _iter_managed_files
state = {
"roles": {
"users": {"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]},
"packages": [
{
"role_name": "vim",
"managed_files": [{"path": "/usr/bin/vim", "src_rel": "vim"}],
}
],
}
}
files = _iter_managed_files(state)
assert len(files) == 2
assert ("users", {"path": "/etc/passwd", "src_rel": "passwd"}) in files
def test_iter_managed_files_services_role():
from enroll.validate import _iter_managed_files
state = {
"roles": {
"services": [
{
"role_name": "nginx",
"managed_files": [
{"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"}
],
}
]
}
}
files = _iter_managed_files(state)
assert len(files) == 1
assert files[0][0] == "nginx"
def test_iter_managed_files_handles_non_dict_items():
from enroll.validate import _iter_managed_files
state = {
"roles": {
"users": {
"managed_files": [
"not_a_dict",
{"path": "/etc/passwd", "src_rel": "passwd"},
]
},
"services": ["not_a_dict", {"role_name": "nginx", "managed_files": []}],
"packages": ["not_a_dict"],
}
}
files = _iter_managed_files(state)
assert len(files) == 1
def test_iter_managed_files_empty_state():
from enroll.validate import _iter_managed_files
state = {"roles": {}}
files = _iter_managed_files(state)
assert files == []
def test_validate_harvest_missing_state_json(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("missing state.json" in e for e in result.errors)
def test_validate_harvest_invalid_json(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state_file = bundle_dir / "state.json"
state_file.write_text("not valid json", encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("failed to parse" in e for e in result.errors)
def test_validate_harvest_schema_error(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state_file = bundle_dir / "state.json"
state_file.write_text("{}", encoding="utf-8")
result = validate_harvest(
str(bundle_dir), schema="https://invalid.invalid/schema.json"
)
assert result.ok is False
assert any("failed to load/validate schema" in e for e in result.errors)
def test_validate_harvest_missing_artifact(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
artifacts_dir = bundle_dir / "artifacts"
artifacts_dir.mkdir()
state = {
"roles": {
"users": {"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]}
}
}
state_file = bundle_dir / "state.json"
state_file.write_text(json.dumps(state), encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("missing artifact" in e for e in result.errors)
def test_validate_harvest_suspicious_src_rel(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state = {
"roles": {
"users": {
"managed_files": [
{"path": "/etc/passwd", "src_rel": "../../../etc/passwd"}
]
}
}
}
state_file = bundle_dir / "state.json"
state_file.write_text(json.dumps(state), encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("suspicious src_rel" in e for e in result.errors)
def test_validate_harvest_missing_src_rel(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state = {"roles": {"users": {"managed_files": [{"path": "/etc/passwd"}]}}}
state_file = bundle_dir / "state.json"
state_file.write_text(json.dumps(state), encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("missing src_rel" in e for e in result.errors)
def test_validate_harvest_firewall_runtime_missing(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
artifacts_dir = bundle_dir / "artifacts"
fw_dir = artifacts_dir / "firewall_runtime"
fw_dir.mkdir(parents=True)
state = {
"roles": {
"firewall_runtime": {
"role_name": "firewall_runtime",
"iptables_v4_save": "iptables.save",
}
}
}
state_file = bundle_dir / "state.json"
state_file.write_text(json.dumps(state), encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("missing firewall runtime artifact" in e for e in result.errors)
def test_validate_harvest_firewall_runtime_suspicious(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state = {
"roles": {
"firewall_runtime": {
"role_name": "firewall_runtime",
"iptables_v4_save": "../../../etc/passwd",
}
}
}
state_file = bundle_dir / "state.json"
state_file.write_text(json.dumps(state), encoding="utf-8")
result = validate_harvest(str(bundle_dir))
assert result.ok is False
assert any("suspicious src_rel" in e for e in result.errors)
def test_validate_harvest_no_schema_option(tmp_path: Path):
bundle_dir = tmp_path / "bundle"
bundle_dir.mkdir()
state_file = bundle_dir / "state.json"
state_file.write_text("invalid json", encoding="utf-8")
result = validate_harvest(str(bundle_dir), no_schema=True)
assert result.ok is False
assert any("failed to parse" in e for e in result.errors)