Compare commits
2 commits
b25dd1e314
...
bf735c8328
| Author | SHA1 | Date | |
|---|---|---|---|
| bf735c8328 | |||
| 1544dc0295 |
17 changed files with 4004 additions and 454 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,3 +8,4 @@ dist
|
|||
*.pdf
|
||||
*.csv
|
||||
*.html
|
||||
coverage.xml
|
||||
|
|
|
|||
|
|
@ -141,3 +141,174 @@ def test_collect_non_system_users(monkeypatch, tmp_path: Path):
|
|||
assert u.primary_group == "users"
|
||||
assert u.supplementary_groups == ["admins"]
|
||||
assert u.ssh_files == ["/home/alice/.ssh/authorized_keys"]
|
||||
|
||||
|
||||
def test_parse_login_defs_file_not_found(tmp_path: Path):
|
||||
from enroll.accounts import parse_login_defs
|
||||
|
||||
nonexistent = tmp_path / "nonexistent" / "login.defs"
|
||||
vals = parse_login_defs(str(nonexistent))
|
||||
assert vals == {}
|
||||
|
||||
|
||||
def test_parse_login_defs_handles_invalid_numbers(tmp_path: Path):
|
||||
from enroll.accounts import parse_login_defs
|
||||
|
||||
p = tmp_path / "login.defs"
|
||||
p.write_text("UID_MIN not_a_number\nUID_MAX 60000\n", encoding="utf-8")
|
||||
vals = parse_login_defs(str(p))
|
||||
assert "UID_MIN" not in vals
|
||||
assert vals["UID_MAX"] == 60000
|
||||
|
||||
|
||||
def test_parse_group_handles_invalid_gid(tmp_path: Path):
|
||||
from enroll.accounts import parse_group
|
||||
|
||||
p = tmp_path / "group"
|
||||
p.write_text(
|
||||
"valid:x:1000:user1\n" "invalid_gid:x:notanint:user2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
gid_to_name, name_to_gid, members = parse_group(str(p))
|
||||
assert 1000 in gid_to_name
|
||||
assert gid_to_name[1000] == "valid"
|
||||
assert "invalid_gid" not in name_to_gid
|
||||
|
||||
|
||||
def test_parse_group_line_too_short(tmp_path: Path):
|
||||
from enroll.accounts import parse_group
|
||||
|
||||
p = tmp_path / "group"
|
||||
p.write_text(
|
||||
"valid:x:1000:user1\n" "shortline:x:1001\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
gid_to_name, name_to_gid, members = parse_group(str(p))
|
||||
assert 1000 in gid_to_name
|
||||
assert 1001 not in gid_to_name
|
||||
|
||||
|
||||
def test_is_human_user_filters_by_uid_and_shell():
|
||||
from enroll.accounts import is_human_user
|
||||
|
||||
assert is_human_user(1000, "/bin/bash", 1000) is True
|
||||
assert is_human_user(999, "/bin/bash", 1000) is False
|
||||
assert is_human_user(1000, "/usr/sbin/nologin", 1000) is False
|
||||
assert is_human_user(1000, "/usr/bin/nologin", 1000) is False
|
||||
assert is_human_user(1000, "/bin/false", 1000) is False
|
||||
assert is_human_user(1000, "", 1000) is True
|
||||
|
||||
|
||||
def test_find_user_ssh_files_no_ssh_dir(tmp_path: Path):
|
||||
from enroll.accounts import find_user_ssh_files
|
||||
|
||||
home = tmp_path / "home" / "user"
|
||||
home.mkdir(parents=True)
|
||||
assert find_user_ssh_files(str(home)) == []
|
||||
|
||||
|
||||
def test_find_user_ssh_files_ignores_symlink(tmp_path: Path):
|
||||
from enroll.accounts import find_user_ssh_files
|
||||
|
||||
home = tmp_path / "home" / "user"
|
||||
sshdir = home / ".ssh"
|
||||
sshdir.mkdir(parents=True)
|
||||
target = sshdir / "real_file"
|
||||
target.write_text("x", encoding="utf-8")
|
||||
os.symlink(str(target), str(sshdir / "authorized_keys"))
|
||||
|
||||
result = find_user_ssh_files(str(home))
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_find_user_ssh_files_handles_home_not_starting_with_slash():
|
||||
from enroll.accounts import find_user_ssh_files
|
||||
|
||||
assert find_user_ssh_files("relative/path") == []
|
||||
assert find_user_ssh_files("") == []
|
||||
|
||||
|
||||
def test_collect_non_system_users_skips_nologin_users(tmp_path: Path):
|
||||
import enroll.accounts as a
|
||||
|
||||
orig_parse_login_defs = a.parse_login_defs
|
||||
orig_parse_passwd = a.parse_passwd
|
||||
orig_parse_group = a.parse_group
|
||||
|
||||
passwd = tmp_path / "passwd"
|
||||
passwd.write_text(
|
||||
"root:x:0:0:root:/root:/bin/bash\n"
|
||||
"alice:x:1000:1000:Alice:/home/alice:/bin/bash\n"
|
||||
"nobody:x:65534:65534:nobody:/nonexistent:/usr/sbin/nologin\n"
|
||||
"sysuser:x:100:100:Sys:/home/sys:/bin/bash\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
group = tmp_path / "group"
|
||||
group.write_text("users:x:1000:alice\n", encoding="utf-8")
|
||||
defs = tmp_path / "login.defs"
|
||||
defs.write_text("UID_MIN 1000\n", encoding="utf-8")
|
||||
|
||||
monkeypatch_wrapper = lambda fn, p: lambda path=str(p): fn(path)
|
||||
|
||||
a.parse_login_defs = monkeypatch_wrapper(orig_parse_login_defs, defs)
|
||||
a.parse_passwd = monkeypatch_wrapper(orig_parse_passwd, passwd)
|
||||
a.parse_group = monkeypatch_wrapper(orig_parse_group, group)
|
||||
a.find_user_ssh_files = lambda home: []
|
||||
|
||||
users = a.collect_non_system_users()
|
||||
assert [u.name for u in users] == ["alice"]
|
||||
|
||||
|
||||
def test_collect_non_system_users_skips_below_uid_min(tmp_path: Path):
|
||||
import enroll.accounts as a
|
||||
|
||||
orig_parse_login_defs = a.parse_login_defs
|
||||
orig_parse_passwd = a.parse_passwd
|
||||
orig_parse_group = a.parse_group
|
||||
|
||||
passwd = tmp_path / "passwd"
|
||||
passwd.write_text(
|
||||
"root:x:0:0:root:/root:/bin/bash\n"
|
||||
"sysuser:x:999:999:Sys:/home/sys:/bin/bash\n"
|
||||
"alice:x:1000:1000:Alice:/home/alice:/bin/bash\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
group = tmp_path / "group"
|
||||
group.write_text("users:x:1000:alice\n", encoding="utf-8")
|
||||
defs = tmp_path / "login.defs"
|
||||
defs.write_text("UID_MIN 1000\n", encoding="utf-8")
|
||||
|
||||
a.parse_login_defs = lambda path=str(defs): orig_parse_login_defs(path)
|
||||
a.parse_passwd = lambda path=str(passwd): orig_parse_passwd(path)
|
||||
a.parse_group = lambda path=str(group): orig_parse_group(path)
|
||||
a.find_user_ssh_files = lambda home: []
|
||||
|
||||
users = a.collect_non_system_users()
|
||||
assert [u.name for u in users] == ["alice"]
|
||||
|
||||
|
||||
def test_parse_group_handles_empty_lines(tmp_path: Path):
|
||||
from enroll.accounts import parse_group
|
||||
|
||||
p = tmp_path / "group"
|
||||
p.write_text(
|
||||
"valid:x:1000:user1\n" "\n" "another:x:1001:user2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
gid_to_name, name_to_gid, members = parse_group(str(p))
|
||||
assert 1000 in gid_to_name
|
||||
assert 1001 in gid_to_name
|
||||
|
||||
|
||||
def test_parse_group_handles_short_lines(tmp_path: Path):
|
||||
from enroll.accounts import parse_group
|
||||
|
||||
p = tmp_path / "group"
|
||||
p.write_text(
|
||||
"valid:x:1000:user1\n" "short:x:1001\n" "another:x:1002:user2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
gid_to_name, name_to_gid, members = parse_group(str(p))
|
||||
assert 1000 in gid_to_name
|
||||
assert 1001 not in gid_to_name # skipped due to short line
|
||||
assert 1002 in gid_to_name
|
||||
|
|
|
|||
|
|
@ -31,3 +31,67 @@ def test_ensure_dir_secure_ignores_chmod_failures(tmp_path: Path, monkeypatch):
|
|||
# Should not raise.
|
||||
_ensure_dir_secure(d)
|
||||
assert d.exists() and d.is_dir()
|
||||
|
||||
|
||||
def test_safe_component_returns_unknown_for_empty_string():
|
||||
from enroll.cache import _safe_component
|
||||
|
||||
assert _safe_component("") == "unknown"
|
||||
assert _safe_component(" ") == "unknown"
|
||||
|
||||
|
||||
def test_safe_component_truncates_long_strings():
|
||||
from enroll.cache import _safe_component
|
||||
|
||||
long_str = "a" * 100
|
||||
result = _safe_component(long_str)
|
||||
assert len(result) <= 64
|
||||
|
||||
|
||||
def test_safe_component_replaces_special_chars():
|
||||
from enroll.cache import _safe_component
|
||||
|
||||
result = _safe_component("hello world!")
|
||||
assert result == "hello_world_"
|
||||
|
||||
|
||||
def test_enroll_cache_dir_uses_xdg_cache_home(monkeypatch):
|
||||
from enroll.cache import enroll_cache_dir
|
||||
|
||||
monkeypatch.setenv("XDG_CACHE_HOME", "/custom/cache")
|
||||
result = enroll_cache_dir()
|
||||
assert str(result) == "/custom/cache/enroll"
|
||||
|
||||
|
||||
def test_harvest_cache_state_json_property():
|
||||
from enroll.cache import HarvestCache
|
||||
|
||||
cache_dir = HarvestCache(dir=Path("/tmp/test"))
|
||||
assert cache_dir.state_json == Path("/tmp/test/state.json")
|
||||
|
||||
|
||||
def test_new_harvest_cache_dir_chmod_fails(tmp_path: Path, monkeypatch):
|
||||
from enroll.cache import new_harvest_cache_dir
|
||||
|
||||
def fake_enroll_cache_dir():
|
||||
return tmp_path / "enroll"
|
||||
|
||||
def fake_chmod(path, mode):
|
||||
raise OSError("no")
|
||||
|
||||
monkeypatch.setattr("enroll.cache.enroll_cache_dir", fake_enroll_cache_dir)
|
||||
monkeypatch.setattr(os, "chmod", fake_chmod)
|
||||
|
||||
# Should not raise even though chmod fails
|
||||
cache = new_harvest_cache_dir(hint="test")
|
||||
assert cache.dir.exists()
|
||||
assert isinstance(cache.dir, Path)
|
||||
|
||||
|
||||
def test_enroll_cache_dir_uses_default_when_xdg_not_set(monkeypatch):
|
||||
from enroll.cache import enroll_cache_dir
|
||||
|
||||
# Remove XDG_CACHE_HOME if it exists
|
||||
monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
|
||||
result = enroll_cache_dir()
|
||||
assert str(result).endswith("/.local/cache/enroll")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
|
||||
def test_dpkg_owner_parses_output(monkeypatch):
|
||||
|
|
@ -96,3 +97,441 @@ def test_parse_status_conffiles_handles_continuations(tmp_path: Path):
|
|||
assert m["nginx"]["/etc/nginx/nginx.conf"] == "abcdef"
|
||||
assert m["nginx"]["/etc/nginx/mime.types"] == "123456"
|
||||
assert "other" not in m
|
||||
|
||||
|
||||
def test_dpkg_owner_returns_none_on_diversion_only(monkeypatch):
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
def fake_run(cmd, text, capture_output):
|
||||
return P(0, "diversion by foo from: /etc/something\n")
|
||||
|
||||
monkeypatch.setattr(d.subprocess, "run", fake_run)
|
||||
assert d.dpkg_owner("/etc/something") is None
|
||||
|
||||
|
||||
def test_dpkg_owner_handles_line_without_colon(monkeypatch):
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
def fake_run(cmd, text, capture_output):
|
||||
return P(0, "invalid line without colon\n")
|
||||
|
||||
monkeypatch.setattr(d.subprocess, "run", fake_run)
|
||||
assert d.dpkg_owner("/etc/foo") is None
|
||||
|
||||
|
||||
def test_list_manual_packages_returns_empty_on_error(monkeypatch):
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
def fake_run(cmd, text, capture_output):
|
||||
return P(1, "error")
|
||||
|
||||
monkeypatch.setattr(d.subprocess, "run", fake_run)
|
||||
assert d.list_manual_packages() == []
|
||||
|
||||
|
||||
def test_list_installed_packages_handles_exception(monkeypatch):
|
||||
import enroll.debian as d
|
||||
|
||||
def fake_run(*args, **kwargs):
|
||||
raise Exception("simulated error")
|
||||
|
||||
monkeypatch.setattr(d.subprocess, "run", fake_run)
|
||||
assert d.list_installed_packages() == {}
|
||||
|
||||
|
||||
def test_list_installed_packages_parses_output():
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
original_run = d.subprocess.run
|
||||
|
||||
def fake_run(cmd, text, capture_output, check):
|
||||
return P(0, "nginx\t1.18.0\tamd64\nvim\t8.2\tamd64\n")
|
||||
|
||||
d.subprocess.run = fake_run
|
||||
try:
|
||||
result = d.list_installed_packages()
|
||||
assert "nginx" in result
|
||||
assert result["nginx"][0]["version"] == "1.18.0"
|
||||
assert result["nginx"][0]["arch"] == "amd64"
|
||||
assert "vim" in result
|
||||
finally:
|
||||
d.subprocess.run = original_run
|
||||
|
||||
|
||||
def test_list_installed_packages_skips_invalid_lines():
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
original_run = d.subprocess.run
|
||||
|
||||
def fake_run(cmd, text, capture_output, check):
|
||||
return P(0, "nginx\t1.18.0\tamd64\ninvalid_line\n\t1.0\tamd64\n")
|
||||
|
||||
d.subprocess.run = fake_run
|
||||
try:
|
||||
result = d.list_installed_packages()
|
||||
assert "nginx" in result
|
||||
assert "invalid_line" not in result
|
||||
finally:
|
||||
d.subprocess.run = original_run
|
||||
|
||||
|
||||
def test_list_installed_packages_handles_empty_name():
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
original_run = d.subprocess.run
|
||||
|
||||
def fake_run(cmd, text, capture_output, check):
|
||||
return P(0, "\t1.0\tamd64\nnginx\t1.18.0\tamd64\n")
|
||||
|
||||
d.subprocess.run = fake_run
|
||||
try:
|
||||
result = d.list_installed_packages()
|
||||
assert "" not in result
|
||||
assert "nginx" in result
|
||||
finally:
|
||||
d.subprocess.run = original_run
|
||||
|
||||
|
||||
def test_list_installed_packages_sorts_output():
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = ""
|
||||
|
||||
original_run = d.subprocess.run
|
||||
|
||||
def fake_run(cmd, text, capture_output, check):
|
||||
return P(0, "nginx\t1.18.0\tamd64\nnginx\t1.19.0\tarm64\n")
|
||||
|
||||
d.subprocess.run = fake_run
|
||||
try:
|
||||
result = d.list_installed_packages()
|
||||
assert len(result["nginx"]) == 2
|
||||
assert result["nginx"][0]["arch"] == "amd64"
|
||||
assert result["nginx"][1]["arch"] == "arm64"
|
||||
finally:
|
||||
d.subprocess.run = original_run
|
||||
|
||||
|
||||
def test_build_dpkg_etc_index_handles_missing_file(tmp_path: Path):
|
||||
import enroll.debian as d
|
||||
|
||||
info = tmp_path / "info"
|
||||
info.mkdir()
|
||||
# Don't create any .list files
|
||||
|
||||
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
|
||||
assert owned == set()
|
||||
assert owner_map == {}
|
||||
assert topdir_to_pkgs == {}
|
||||
assert pkg_to_etc == {}
|
||||
|
||||
|
||||
def test_build_dpkg_etc_index_skips_non_etc_paths(tmp_path: Path):
|
||||
import enroll.debian as d
|
||||
|
||||
info = tmp_path / "info"
|
||||
info.mkdir()
|
||||
(info / "foo.list").write_text("/usr/bin/foo\n/etc/bar\n", encoding="utf-8")
|
||||
|
||||
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
|
||||
assert "/usr/bin/foo" not in owned
|
||||
assert "/etc/bar" in owned
|
||||
assert "foo" not in topdir_to_pkgs
|
||||
|
||||
|
||||
def test_parse_status_conffiles_handles_empty_status(tmp_path: Path):
|
||||
import enroll.debian as d
|
||||
|
||||
status = tmp_path / "status"
|
||||
status.write_text("", encoding="utf-8")
|
||||
m = d.parse_status_conffiles(str(status))
|
||||
assert m == {}
|
||||
|
||||
|
||||
def test_parse_status_conffiles_handles_package_without_conffiles(tmp_path: Path):
|
||||
import enroll.debian as d
|
||||
|
||||
status = tmp_path / "status"
|
||||
status.write_text(
|
||||
"Package: nginx\nVersion: 1\nStatus: install ok installed\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
m = d.parse_status_conffiles(str(status))
|
||||
assert m == {}
|
||||
|
||||
|
||||
def test_read_pkg_md5sums_returns_empty_if_file_not_exists(tmp_path: Path):
|
||||
import enroll.debian as d
|
||||
|
||||
result = d.read_pkg_md5sums("nonexistent_package")
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_pkg_md5sums_parses_md5sums_file(tmp_path: Path, monkeypatch):
|
||||
import enroll.debian as d
|
||||
|
||||
info_dir = tmp_path / "info"
|
||||
info_dir.mkdir()
|
||||
md5_file = info_dir / "nginx.md5sums"
|
||||
md5_file.write_text(
|
||||
"abcdef1234567890abcdef1234567890 etc/nginx/nginx.conf\n"
|
||||
"1234567890abcdef1234567890abcdef etc/nginx/sites-enabled/default\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def fake_exists(path):
|
||||
return str(path).endswith("nginx.md5sums")
|
||||
|
||||
monkeypatch.setattr(d.os.path, "exists", fake_exists)
|
||||
|
||||
original_open = open
|
||||
|
||||
def fake_open(path, *args, **kwargs):
|
||||
if "nginx.md5sums" in str(path):
|
||||
return original_open(md5_file, *args, **kwargs)
|
||||
return original_open(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open, raising=False)
|
||||
|
||||
result = d.read_pkg_md5sums("nginx")
|
||||
assert result["etc/nginx/nginx.conf"] == "abcdef1234567890abcdef1234567890"
|
||||
assert (
|
||||
result["etc/nginx/sites-enabled/default"] == "1234567890abcdef1234567890abcdef"
|
||||
)
|
||||
|
||||
|
||||
def test_dpkg_owner_raises_on_command_failure(monkeypatch):
|
||||
"""Test _run raises RuntimeError on non-zero exit."""
|
||||
import enroll.debian as d
|
||||
|
||||
class P:
|
||||
returncode = 1
|
||||
stdout = ""
|
||||
stderr = "command failed"
|
||||
|
||||
def fake_run(cmd, text, capture_output, check=False):
|
||||
return P()
|
||||
|
||||
monkeypatch.setattr(d.subprocess, "run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
d._run(["fake", "command"])
|
||||
|
||||
assert "Command failed" in str(exc_info.value)
|
||||
assert "fake" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_build_dpkg_etc_index_skips_invalid_line_formats(tmp_path: Path):
|
||||
"""Test that lines with less than 3 parts are skipped."""
|
||||
import enroll.debian as d
|
||||
|
||||
info = tmp_path / "info"
|
||||
info.mkdir()
|
||||
# Create a .list file with invalid format (missing tab-separated fields)
|
||||
(info / "foo.list").write_text(
|
||||
"/etc/foo/bar\n" # This is a path, not a tab-separated line
|
||||
"/etc/foo/baz\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Should handle gracefully
|
||||
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
|
||||
# The path lines should be processed normally
|
||||
assert "/etc/foo/bar" in owned or "/etc/foo/baz" in owned
|
||||
|
||||
|
||||
def test_build_dpkg_etc_index_handles_file_not_found(tmp_path: Path):
|
||||
"""Test that FileNotFoundError is handled gracefully."""
|
||||
import enroll.debian as d
|
||||
|
||||
info = tmp_path / "info"
|
||||
info.mkdir()
|
||||
# Create a .list file that references a non-existent path
|
||||
(info / "foo.list").write_text(
|
||||
"/nonexistent/path\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
|
||||
# The non-existent path should be skipped
|
||||
assert "/nonexistent/path" not in owned
|
||||
|
||||
|
||||
def test_parse_status_conffiles_skips_empty_lines(tmp_path: Path):
|
||||
"""Test that empty lines in conffiles are skipped."""
|
||||
import enroll.debian as d
|
||||
|
||||
status = tmp_path / "status"
|
||||
status.write_text(
|
||||
"Package: nginx\n"
|
||||
"Version: 1\n"
|
||||
"Conffiles:\n"
|
||||
" /etc/nginx/nginx.conf abcdef\n"
|
||||
" /etc/nginx/mime.types 123456\n"
|
||||
"\n", # Empty line to trigger flush
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
m = d.parse_status_conffiles(str(status))
|
||||
assert "/etc/nginx/nginx.conf" in m["nginx"]
|
||||
assert "/etc/nginx/mime.types" in m["nginx"]
|
||||
|
||||
|
||||
def test_read_pkg_md5sums_skips_invalid_md5_lines(tmp_path: Path, monkeypatch):
|
||||
"""Test that lines without proper MD5 format are skipped."""
|
||||
import enroll.debian as d
|
||||
|
||||
info_dir = tmp_path / "info"
|
||||
info_dir.mkdir()
|
||||
md5_file = info_dir / "foo.md5sums"
|
||||
md5_file.write_text(
|
||||
"abcdef1234567890abcdef1234567890 etc/foo/bar\n"
|
||||
"invalid line without proper format\n"
|
||||
"1234567890abcdef1234567890abcdef etc/foo/baz\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def fake_exists(path):
|
||||
return str(path).endswith("foo.md5sums")
|
||||
|
||||
monkeypatch.setattr(d.os.path, "exists", fake_exists)
|
||||
|
||||
original_open = open
|
||||
|
||||
def fake_open(path, *args, **kwargs):
|
||||
if "foo.md5sums" in str(path):
|
||||
return original_open(md5_file, *args, **kwargs)
|
||||
return original_open(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open, raising=False)
|
||||
|
||||
result = d.read_pkg_md5sums("foo")
|
||||
assert "etc/foo/bar" in result
|
||||
assert "etc/foo/baz" in result
|
||||
|
||||
|
||||
def test_build_dpkg_etc_index_skips_lines_without_tabs(tmp_path: Path):
|
||||
"""Test that lines without tab separators are skipped (parts < 3)."""
|
||||
import enroll.debian as d
|
||||
|
||||
info = tmp_path / "info"
|
||||
info.mkdir()
|
||||
# Create file with lines that don't have tab separators
|
||||
(info / "foo.list").write_text(
|
||||
"notabseparator\n" # No tab - should be skipped
|
||||
"/etc/foo/bar\n", # This is a path line, processed differently
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
owned, owner_map, topdir_to_pkgs, pkg_to_etc = d.build_dpkg_etc_index(str(info))
|
||||
# Path lines are still processed
|
||||
assert "/etc/foo/bar" in owned
|
||||
|
||||
|
||||
def test_read_pkg_md5sums_skips_empty_lines(tmp_path: Path, monkeypatch):
|
||||
"""Test that empty lines in md5sums are skipped."""
|
||||
import enroll.debian as d
|
||||
|
||||
info_dir = tmp_path / "info"
|
||||
info_dir.mkdir()
|
||||
md5_file = info_dir / "bar.md5sums"
|
||||
md5_file.write_text(
|
||||
"abcdef1234567890abcdef1234567890 etc/bar/file1\n"
|
||||
"\n" # Empty line
|
||||
"1234567890abcdef1234567890abcdef etc/bar/file2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def fake_exists(path):
|
||||
return str(path).endswith("bar.md5sums")
|
||||
|
||||
monkeypatch.setattr(d.os.path, "exists", fake_exists)
|
||||
|
||||
original_open = open
|
||||
|
||||
def fake_open(path, *args, **kwargs):
|
||||
if "bar.md5sums" in str(path):
|
||||
return original_open(md5_file, *args, **kwargs)
|
||||
return original_open(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open, raising=False)
|
||||
|
||||
result = d.read_pkg_md5sums("bar")
|
||||
assert "etc/bar/file1" in result
|
||||
assert "etc/bar/file2" in result
|
||||
|
||||
|
||||
def test_read_pkg_md5sums_skips_lines_not_starting_with_path(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Test that lines not starting with / are skipped."""
|
||||
import enroll.debian as d
|
||||
|
||||
info_dir = tmp_path / "info"
|
||||
info_dir.mkdir()
|
||||
md5_file = info_dir / "baz.md5sums"
|
||||
md5_file.write_text(
|
||||
"abcdef1234567890abcdef1234567890 etc/baz/file1\n"
|
||||
"invalid line\n" # Doesn't start with /
|
||||
"1234567890abcdef1234567890abcdef etc/baz/file2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def fake_exists(path):
|
||||
return str(path).endswith("baz.md5sums")
|
||||
|
||||
monkeypatch.setattr(d.os.path, "exists", fake_exists)
|
||||
|
||||
original_open = open
|
||||
|
||||
def fake_open(path, *args, **kwargs):
|
||||
if "baz.md5sums" in str(path):
|
||||
return original_open(md5_file, *args, **kwargs)
|
||||
return original_open(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open, raising=False)
|
||||
|
||||
result = d.read_pkg_md5sums("baz")
|
||||
assert "etc/baz/file1" in result
|
||||
assert "etc/baz/file2" in result
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import enroll.harvest as harvest
|
||||
from pathlib import Path
|
||||
|
||||
import enroll.harvest as h
|
||||
|
|
@ -367,3 +368,149 @@ def test_shared_cron_snippet_prefers_matching_role_over_lexicographic(
|
|||
assert all(
|
||||
mf["path"] != "/etc/cron.d/ntpsec" for mf in svc_apparmor["managed_files"]
|
||||
)
|
||||
|
||||
|
||||
def test_files_differ_same_content(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file2 = tmp_path / "file2.txt"
|
||||
file1.write_text("same content", encoding="utf-8")
|
||||
file2.write_text("same content", encoding="utf-8")
|
||||
assert harvest._files_differ(str(file1), str(file2)) is False
|
||||
|
||||
|
||||
def test_files_differ_different_content(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file2 = tmp_path / "file2.txt"
|
||||
file1.write_text("content1", encoding="utf-8")
|
||||
file2.write_text("content2", encoding="utf-8")
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_files_differ_missing_file(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file1.write_text("content", encoding="utf-8")
|
||||
file2 = tmp_path / "file2.txt"
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_files_differ_both_missing(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file2 = tmp_path / "file2.txt"
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_files_differ_binary(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.bin"
|
||||
file2 = tmp_path / "file2.bin"
|
||||
file1.write_bytes(b"\x00\x01\x02\x03")
|
||||
file2.write_bytes(b"\x00\x01\x02\x03")
|
||||
assert harvest._files_differ(str(file1), str(file2)) is False
|
||||
|
||||
|
||||
def test_files_differ_binary_different(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.bin"
|
||||
file2 = tmp_path / "file2.bin"
|
||||
file1.write_bytes(b"\x00\x01\x02\x03")
|
||||
file2.write_bytes(b"\x00\x01\x02\x04")
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_files_differ_non_regular_a(tmp_path: Path):
|
||||
directory = tmp_path / "dir"
|
||||
directory.mkdir()
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file1.write_text("content", encoding="utf-8")
|
||||
assert harvest._files_differ(str(directory), str(file1)) is True
|
||||
|
||||
|
||||
def test_files_differ_non_regular_b(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file1.write_text("content", encoding="utf-8")
|
||||
directory = tmp_path / "dir"
|
||||
directory.mkdir()
|
||||
assert harvest._files_differ(str(file1), str(directory)) is True
|
||||
|
||||
|
||||
def test_files_differ_size_mismatch(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.txt"
|
||||
file1.write_text("short", encoding="utf-8")
|
||||
file2 = tmp_path / "file2.txt"
|
||||
file2.write_text("much longer content", encoding="utf-8")
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_files_differ_large_files(tmp_path: Path):
|
||||
file1 = tmp_path / "file1.bin"
|
||||
file2 = tmp_path / "file2.bin"
|
||||
file1.write_bytes(b"x" * 3_000_000)
|
||||
file2.write_bytes(b"x" * 3_000_000)
|
||||
assert harvest._files_differ(str(file1), str(file2)) is True
|
||||
|
||||
|
||||
def test_is_confish_with_conf(tmp_path: Path):
|
||||
file1 = tmp_path / "test.conf"
|
||||
file1.write_text("content", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is True
|
||||
|
||||
|
||||
def test_is_confish_with_yaml(tmp_path: Path):
|
||||
file1 = tmp_path / "test.yaml"
|
||||
file1.write_text("content", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is True
|
||||
|
||||
|
||||
def test_is_confish_with_json(tmp_path: Path):
|
||||
file1 = tmp_path / "test.json"
|
||||
file1.write_text("{}", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is True
|
||||
|
||||
|
||||
def test_is_confish_with_service(tmp_path: Path):
|
||||
file1 = tmp_path / "test.service"
|
||||
file1.write_text("[Unit]", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is True
|
||||
|
||||
|
||||
def test_is_confish_with_extensionless(tmp_path: Path):
|
||||
file1 = tmp_path / "default"
|
||||
file1.write_text("OPTIONS=", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is True
|
||||
|
||||
|
||||
def test_is_confish_not_config(tmp_path: Path):
|
||||
file1 = tmp_path / "test.log"
|
||||
file1.write_text("log", encoding="utf-8")
|
||||
assert harvest._is_confish(str(file1)) is False
|
||||
|
||||
|
||||
def test_is_confish_nonexistent():
|
||||
assert harvest._is_confish("/nonexistent/file.xyz") is False
|
||||
|
||||
|
||||
def test_topdirs_for_package_with_multiple_paths():
|
||||
pkg_to_etc_paths = {
|
||||
"nginx": ["/etc/nginx/nginx.conf", "/etc/nginx/sites-enabled/default"],
|
||||
}
|
||||
result = harvest._topdirs_for_package("nginx", pkg_to_etc_paths)
|
||||
assert result == {"nginx"}
|
||||
|
||||
|
||||
def test_topdirs_for_package_with_multiple_topdirs():
|
||||
pkg_to_etc_paths = {
|
||||
"multi": ["/etc/nginx/nginx.conf", "/etc/ssh/sshd_config"],
|
||||
}
|
||||
result = harvest._topdirs_for_package("multi", pkg_to_etc_paths)
|
||||
assert result == {"nginx", "ssh"}
|
||||
|
||||
|
||||
def test_topdirs_for_package_empty():
|
||||
result = harvest._topdirs_for_package("empty", {})
|
||||
assert result == set()
|
||||
|
||||
|
||||
def test_topdirs_for_package_no_etc():
|
||||
pkg_to_etc_paths = {
|
||||
"other": ["/usr/share/doc/file"],
|
||||
}
|
||||
result = harvest._topdirs_for_package("other", pkg_to_etc_paths)
|
||||
assert result == set()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from enroll.ignore import IgnorePolicy
|
||||
|
||||
|
||||
|
|
@ -8,3 +13,238 @@ def test_ignore_policy_denies_common_backup_files():
|
|||
assert pol.deny_reason("/etc/group-") == "backup_file"
|
||||
assert pol.deny_reason("/etc/something~") == "backup_file"
|
||||
assert pol.deny_reason("/foobar") == "unreadable"
|
||||
|
||||
|
||||
def test_deny_reason_dir_with_denied_path():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason_dir("/etc/ssl/private/key") == "denied_path"
|
||||
assert pol.deny_reason_dir("/etc/ssh/ssh_host_key") == "denied_path"
|
||||
assert pol.deny_reason_dir("/etc/ssh") is None
|
||||
|
||||
|
||||
def test_deny_reason_dir_unreadable(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
nonexistent = tmp_path / "nonexistent"
|
||||
assert pol.deny_reason_dir(str(nonexistent)) == "unreadable"
|
||||
|
||||
|
||||
def test_deny_reason_dir_symlink(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
real_dir = tmp_path / "real"
|
||||
real_dir.mkdir()
|
||||
link = tmp_path / "link"
|
||||
os.symlink(str(real_dir), str(link))
|
||||
assert pol.deny_reason_dir(str(link)) == "symlink"
|
||||
|
||||
|
||||
def test_deny_reason_dir_not_directory(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
regular_file = tmp_path / "file.txt"
|
||||
regular_file.write_text("content", encoding="utf-8")
|
||||
assert pol.deny_reason_dir(str(regular_file)) == "not_directory"
|
||||
|
||||
|
||||
def test_deny_reason_dir_dangerous_mode(tmp_path: Path):
|
||||
pol = IgnorePolicy(dangerous=True)
|
||||
real_dir = tmp_path / "private"
|
||||
real_dir.mkdir()
|
||||
assert pol.deny_reason_dir(str(real_dir)) is None
|
||||
|
||||
|
||||
def test_deny_reason_link_basic(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
real_file = tmp_path / "real"
|
||||
real_file.write_text("content", encoding="utf-8")
|
||||
link = tmp_path / "link"
|
||||
os.symlink(str(real_file), str(link))
|
||||
assert pol.deny_reason_link(str(link)) is None
|
||||
|
||||
|
||||
def test_deny_reason_link_denied_path():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason_link("/etc/ssh/ssh_host_rsa_key") == "denied_path"
|
||||
|
||||
|
||||
def test_deny_reason_link_unreadable(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
# Create a symlink in a directory that doesn't exist
|
||||
# This simulates an unreadable path
|
||||
broken_link = tmp_path / "broken_link"
|
||||
os.symlink("/nonexistent/target", str(broken_link))
|
||||
# Broken symlinks are still readable (we can readlink them)
|
||||
# So they return None (allowed) unless they match deny globs
|
||||
result = pol.deny_reason_link(str(broken_link))
|
||||
# Broken symlinks are allowed - we can still read the link target
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_deny_reason_link_not_symlink(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
regular_file = tmp_path / "file.txt"
|
||||
regular_file.write_text("content", encoding="utf-8")
|
||||
assert pol.deny_reason_link(str(regular_file)) == "not_symlink"
|
||||
|
||||
|
||||
def test_deny_reason_link_log_file():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason_link("/var/log/something.log") == "log_file"
|
||||
|
||||
|
||||
def test_deny_reason_link_backup_file():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason_link("/etc/passwd-") == "backup_file"
|
||||
assert pol.deny_reason_link("/etc/something~") == "backup_file"
|
||||
|
||||
|
||||
def test_deny_reason_link_dangerous_mode(tmp_path: Path):
|
||||
pol = IgnorePolicy(dangerous=True)
|
||||
real_file = tmp_path / "real"
|
||||
real_file.write_text("content", encoding="utf-8")
|
||||
link = tmp_path / "link"
|
||||
os.symlink(str(real_file), str(link))
|
||||
assert pol.deny_reason_link(str(link)) is None
|
||||
|
||||
|
||||
def test_iter_effective_lines_with_comments():
|
||||
pol = IgnorePolicy()
|
||||
content = b"""
|
||||
# This is a comment
|
||||
; This is also a comment
|
||||
* continuation
|
||||
def main():
|
||||
pass
|
||||
"""
|
||||
lines = list(pol.iter_effective_lines(content))
|
||||
assert b"def main():" in lines
|
||||
assert b"# This is a comment" not in lines
|
||||
|
||||
|
||||
def test_iter_effective_lines_with_block_comments():
|
||||
pol = IgnorePolicy()
|
||||
content = b"""
|
||||
/* This is a block comment
|
||||
spanning multiple lines */
|
||||
int x = 5;
|
||||
"""
|
||||
lines = list(pol.iter_effective_lines(content))
|
||||
assert b"int x = 5;" in lines
|
||||
assert b"/*" not in lines
|
||||
|
||||
|
||||
def test_iter_effective_lines_empty():
|
||||
pol = IgnorePolicy()
|
||||
content = b""
|
||||
lines = list(pol.iter_effective_lines(content))
|
||||
assert lines == []
|
||||
|
||||
|
||||
def test_deny_reason_binary_not_allowed(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
binary = tmp_path / "random.bin"
|
||||
binary.write_bytes(b"\x00\x01\x02\x03")
|
||||
reason = pol.deny_reason(str(binary))
|
||||
assert reason == "binary_like"
|
||||
|
||||
|
||||
def test_deny_reason_sensitive_content(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
config = tmp_path / "config.txt"
|
||||
config.write_text("password=secret123", encoding="utf-8")
|
||||
reason = pol.deny_reason(str(config))
|
||||
assert reason == "sensitive_content"
|
||||
|
||||
|
||||
def test_deny_reason_sensitive_api_key(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
config = tmp_path / "config.txt"
|
||||
config.write_text("api_key=abc123", encoding="utf-8")
|
||||
reason = pol.deny_reason(str(config))
|
||||
assert reason == "sensitive_content"
|
||||
|
||||
|
||||
def test_deny_reason_private_key(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
key = tmp_path / "key.pem"
|
||||
key.write_text(
|
||||
"-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEA...", encoding="utf-8"
|
||||
)
|
||||
reason = pol.deny_reason(str(key))
|
||||
assert reason == "sensitive_content"
|
||||
|
||||
|
||||
def test_deny_reason_too_large(tmp_path: Path):
|
||||
pol = IgnorePolicy(max_file_bytes=100)
|
||||
large = tmp_path / "large.txt"
|
||||
large.write_bytes(b"x" * 200)
|
||||
reason = pol.deny_reason(str(large))
|
||||
assert reason == "too_large"
|
||||
|
||||
|
||||
def test_deny_reason_unreadable(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
nonexistent = tmp_path / "nonexistent"
|
||||
reason = pol.deny_reason(str(nonexistent))
|
||||
assert reason == "unreadable"
|
||||
|
||||
|
||||
def test_deny_reason_not_regular_file(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
directory = tmp_path / "dir"
|
||||
directory.mkdir()
|
||||
reason = pol.deny_reason(str(directory))
|
||||
assert reason == "not_regular_file"
|
||||
|
||||
|
||||
def test_deny_reason_symlink_file(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
real_file = tmp_path / "real"
|
||||
real_file.write_text("content", encoding="utf-8")
|
||||
link = tmp_path / "link"
|
||||
os.symlink(str(real_file), str(link))
|
||||
reason = pol.deny_reason(str(link))
|
||||
assert reason == "not_regular_file"
|
||||
|
||||
|
||||
def test_deny_reason_logs(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
log = tmp_path / "test.log"
|
||||
log.write_text("log content", encoding="utf-8")
|
||||
assert pol.deny_reason(str(log)) == "log_file"
|
||||
|
||||
|
||||
def test_deny_reason_backup_file(tmp_path: Path):
|
||||
pol = IgnorePolicy()
|
||||
backup = tmp_path / "file~"
|
||||
backup.write_text("backup", encoding="utf-8")
|
||||
assert pol.deny_reason(str(backup)) == "backup_file"
|
||||
|
||||
|
||||
def test_deny_reason_shadow_file():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason("/etc/shadow") == "denied_path"
|
||||
assert pol.deny_reason("/etc/gshadow") == "denied_path"
|
||||
|
||||
|
||||
def test_deny_reason_ssl_private():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason("/etc/ssl/private/key.pem") == "denied_path"
|
||||
|
||||
|
||||
def test_deny_reason_ssh_host_keys():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason("/etc/ssh/ssh_host_rsa_key") == "denied_path"
|
||||
assert pol.deny_reason("/etc/ssh/ssh_host_ed25519_key") == "denied_path"
|
||||
|
||||
|
||||
def test_deny_reason_letsencrypt():
|
||||
pol = IgnorePolicy()
|
||||
assert (
|
||||
pol.deny_reason("/etc/letsencrypt/live/example.com/fullchain.pem")
|
||||
== "denied_path"
|
||||
)
|
||||
|
||||
|
||||
def test_deny_reason_shadow_backup():
|
||||
pol = IgnorePolicy()
|
||||
assert pol.deny_reason("/etc/shadow-") == "backup_file"
|
||||
assert pol.deny_reason("/etc/passwd-") == "backup_file"
|
||||
|
|
|
|||
|
|
@ -892,3 +892,175 @@ def test_manifest_writes_firewall_runtime_role(tmp_path: Path):
|
|||
assert (
|
||||
out / "roles" / "firewall_runtime" / "files" / "firewall" / "ipset.save"
|
||||
).exists()
|
||||
|
||||
|
||||
def test_try_yaml_with_yaml_installed():
|
||||
result = manifest._try_yaml()
|
||||
# PyYAML should be installed for tests
|
||||
if result is None:
|
||||
pytest.skip("PyYAML not installed")
|
||||
assert hasattr(result, "safe_load")
|
||||
assert hasattr(result, "dump")
|
||||
|
||||
|
||||
def test_yaml_load_mapping_with_yaml(tmp_path: Path):
|
||||
text = """
|
||||
key1: value1
|
||||
key2:
|
||||
nested: value
|
||||
list:
|
||||
- item1
|
||||
- item2
|
||||
"""
|
||||
result = manifest._yaml_load_mapping(text)
|
||||
assert result["key1"] == "value1"
|
||||
assert result["key2"]["nested"] == "value"
|
||||
assert result["list"] == ["item1", "item2"]
|
||||
|
||||
|
||||
def test_yaml_load_mapping_empty():
|
||||
result = manifest._yaml_load_mapping("")
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_yaml_load_mapping_invalid():
|
||||
result = manifest._yaml_load_mapping("invalid: yaml: :")
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_yaml_load_mapping_not_dict():
|
||||
result = manifest._yaml_load_mapping("- item1\n- item2")
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_yaml_load_mapping_none():
|
||||
result = manifest._yaml_load_mapping("~")
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_yaml_dump_mapping_with_yaml(tmp_path: Path):
|
||||
obj = {"key1": "value1", "key2": 123}
|
||||
result = manifest._yaml_dump_mapping(obj)
|
||||
assert "key1: value1" in result
|
||||
assert "key2:" in result
|
||||
|
||||
|
||||
def test_yaml_dump_mapping_empty():
|
||||
result = manifest._yaml_dump_mapping({})
|
||||
# Empty dict produces '{}'
|
||||
assert result.strip() == "{}"
|
||||
|
||||
|
||||
def test_yaml_dump_mapping_with_nested(tmp_path: Path):
|
||||
obj = {"key1": {"nested": "value"}}
|
||||
result = manifest._yaml_dump_mapping(obj)
|
||||
assert "nested:" in result
|
||||
|
||||
|
||||
def test_merge_mappings_overwrite_simple():
|
||||
existing = {"key1": "old", "key2": "keep"}
|
||||
incoming = {"key1": "new", "key3": "added"}
|
||||
result = manifest._merge_mappings_overwrite(existing, incoming)
|
||||
assert result["key1"] == "new"
|
||||
assert result["key2"] == "keep"
|
||||
assert result["key3"] == "added"
|
||||
|
||||
|
||||
def test_merge_mappings_overwrite_nested():
|
||||
existing = {"key1": {"a": 1}}
|
||||
incoming = {"key1": {"b": 2}}
|
||||
result = manifest._merge_mappings_overwrite(existing, incoming)
|
||||
# Nested dicts are replaced, not merged
|
||||
assert result["key1"] == {"b": 2}
|
||||
|
||||
|
||||
def test_merge_mappings_overwrite_empty():
|
||||
result = manifest._merge_mappings_overwrite({}, {"key": "value"})
|
||||
assert result == {"key": "value"}
|
||||
|
||||
result = manifest._merge_mappings_overwrite({"key": "value"}, {})
|
||||
assert result == {"key": "value"}
|
||||
|
||||
|
||||
def test_copy2_replace(tmp_path: Path):
|
||||
src = tmp_path / "src.txt"
|
||||
src.write_text("content", encoding="utf-8")
|
||||
dst = tmp_path / "dst" / "subdir" / "dst.txt"
|
||||
|
||||
manifest._copy2_replace(str(src), str(dst))
|
||||
|
||||
assert dst.exists()
|
||||
assert dst.read_text(encoding="utf-8") == "content"
|
||||
|
||||
|
||||
def test_copy2_replace_preserves_metadata(tmp_path: Path):
|
||||
src = tmp_path / "src.txt"
|
||||
src.write_text("content", encoding="utf-8")
|
||||
os.chmod(str(src), 0o644)
|
||||
dst = tmp_path / "dst.txt"
|
||||
|
||||
manifest._copy2_replace(str(src), str(dst))
|
||||
|
||||
assert dst.exists()
|
||||
st = dst.stat()
|
||||
assert stat.S_IMODE(st.st_mode) == 0o644
|
||||
|
||||
|
||||
def test_copy2_replace_atomic(tmp_path: Path):
|
||||
src = tmp_path / "src.txt"
|
||||
src.write_text("content", encoding="utf-8")
|
||||
dst = tmp_path / "dst.txt"
|
||||
|
||||
# Write initial content
|
||||
dst.write_text("old", encoding="utf-8")
|
||||
|
||||
manifest._copy2_replace(str(src), str(dst))
|
||||
|
||||
assert dst.read_text(encoding="utf-8") == "content"
|
||||
|
||||
|
||||
def test_render_firewall_runtime_tasks_empty():
|
||||
state = {"roles": {}}
|
||||
result = manifest._render_firewall_runtime_tasks(state)
|
||||
# Function always returns at least a basic playbook structure
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_render_firewall_runtime_tasks_with_iptables():
|
||||
state = {
|
||||
"roles": {
|
||||
"firewall_runtime": {
|
||||
"role_name": "firewall_runtime",
|
||||
"iptables_v4_save": "artifacts/firewall_runtime/iptables.save",
|
||||
}
|
||||
}
|
||||
}
|
||||
result = manifest._render_firewall_runtime_tasks(state)
|
||||
assert len(result) >= 1
|
||||
|
||||
|
||||
def test_render_firewall_runtime_tasks_with_ipset():
|
||||
state = {
|
||||
"roles": {
|
||||
"firewall_runtime": {
|
||||
"role_name": "firewall_runtime",
|
||||
"ipset_save": "artifacts/firewall_runtime/ipset.save",
|
||||
}
|
||||
}
|
||||
}
|
||||
result = manifest._render_firewall_runtime_tasks(state)
|
||||
assert len(result) >= 1
|
||||
|
||||
|
||||
def test_render_firewall_runtime_tasks_with_ipv6():
|
||||
state = {
|
||||
"roles": {
|
||||
"firewall_runtime": {
|
||||
"role_name": "firewall_runtime",
|
||||
"iptables_v6_save": "artifacts/firewall_runtime/ip6tables.save",
|
||||
}
|
||||
}
|
||||
}
|
||||
result = manifest._render_firewall_runtime_tasks(state)
|
||||
assert len(result) >= 1
|
||||
|
|
|
|||
|
|
@ -1,416 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from enroll.cache import _safe_component, new_harvest_cache_dir
|
||||
from enroll.ignore import IgnorePolicy
|
||||
from enroll.sopsutil import (
|
||||
SopsError,
|
||||
_pgp_arg,
|
||||
decrypt_file_binary_to,
|
||||
encrypt_file_binary,
|
||||
)
|
||||
|
||||
|
||||
def test_safe_component_sanitizes_and_bounds_length():
|
||||
assert _safe_component(" ") == "unknown"
|
||||
assert _safe_component("a/b c") == "a_b_c"
|
||||
assert _safe_component("x" * 200) == "x" * 64
|
||||
|
||||
|
||||
def test_new_harvest_cache_dir_uses_xdg_cache_home(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "xdg"))
|
||||
hc = new_harvest_cache_dir(hint="my host/01")
|
||||
assert hc.dir.exists()
|
||||
assert "my_host_01" in hc.dir.name
|
||||
assert str(hc.dir).startswith(str(tmp_path / "xdg"))
|
||||
# best-effort: ensure directory is not world-readable on typical FS
|
||||
try:
|
||||
mode = stat.S_IMODE(hc.dir.stat().st_mode)
|
||||
assert mode & 0o077 == 0
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def test_ignore_policy_denies_binary_and_sensitive_content(tmp_path: Path):
|
||||
p_bin = tmp_path / "binfile"
|
||||
p_bin.write_bytes(b"abc\x00def")
|
||||
assert IgnorePolicy().deny_reason(str(p_bin)) == "binary_like"
|
||||
|
||||
p_secret = tmp_path / "secret.conf"
|
||||
p_secret.write_text("password=foo\n", encoding="utf-8")
|
||||
assert IgnorePolicy().deny_reason(str(p_secret)) == "sensitive_content"
|
||||
|
||||
# dangerous mode disables heuristic scanning (but still checks file-ness/size)
|
||||
assert IgnorePolicy(dangerous=True).deny_reason(str(p_secret)) is None
|
||||
|
||||
|
||||
def test_ignore_policy_denies_usr_local_shadow_by_glob():
|
||||
# This should short-circuit before stat() (path doesn't need to exist).
|
||||
assert IgnorePolicy().deny_reason("/usr/local/etc/shadow") == "denied_path"
|
||||
|
||||
|
||||
def test_sops_pgp_arg_and_encrypt_decrypt_roundtrip(tmp_path: Path, monkeypatch):
|
||||
assert _pgp_arg([" ABC ", "DEF"]) == "ABC,DEF"
|
||||
with pytest.raises(SopsError):
|
||||
_pgp_arg([])
|
||||
|
||||
# Stub out sops and subprocess.
|
||||
import enroll.sopsutil as s
|
||||
|
||||
monkeypatch.setattr(s, "require_sops_cmd", lambda: "sops")
|
||||
|
||||
class R:
|
||||
def __init__(self, rc: int, out: bytes, err: bytes = b""):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = err
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
calls.append(cmd)
|
||||
# Return a deterministic payload so we can assert file writes.
|
||||
if "--encrypt" in cmd:
|
||||
return R(0, b"ENCRYPTED")
|
||||
if "--decrypt" in cmd:
|
||||
return R(0, b"PLAINTEXT")
|
||||
return R(1, b"", b"bad")
|
||||
|
||||
monkeypatch.setattr(s.subprocess, "run", fake_run)
|
||||
|
||||
src = tmp_path / "src.bin"
|
||||
src.write_bytes(b"x")
|
||||
enc = tmp_path / "out.sops"
|
||||
dec = tmp_path / "out.bin"
|
||||
|
||||
encrypt_file_binary(src, enc, pgp_fingerprints=["ABC"], mode=0o600)
|
||||
assert enc.read_bytes() == b"ENCRYPTED"
|
||||
|
||||
decrypt_file_binary_to(enc, dec, mode=0o644)
|
||||
assert dec.read_bytes() == b"PLAINTEXT"
|
||||
|
||||
# Sanity: we invoked encrypt and decrypt.
|
||||
assert any("--encrypt" in c for c in calls)
|
||||
assert any("--decrypt" in c for c in calls)
|
||||
|
||||
|
||||
def test_cache_dir_defaults_to_home_cache(monkeypatch, tmp_path: Path):
|
||||
# Ensure default path uses ~/.cache when XDG_CACHE_HOME is unset.
|
||||
from enroll.cache import enroll_cache_dir
|
||||
|
||||
monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
p = enroll_cache_dir()
|
||||
assert str(p).startswith(str(tmp_path))
|
||||
assert p.name == "enroll"
|
||||
|
||||
|
||||
def test_harvest_cache_state_json_property(tmp_path: Path):
|
||||
from enroll.cache import HarvestCache
|
||||
|
||||
hc = HarvestCache(tmp_path / "h1")
|
||||
assert hc.state_json == hc.dir / "state.json"
|
||||
|
||||
|
||||
def test_cache_dir_security_rejects_symlink(tmp_path: Path):
|
||||
from enroll.cache import _ensure_dir_secure
|
||||
|
||||
real = tmp_path / "real"
|
||||
real.mkdir()
|
||||
link = tmp_path / "link"
|
||||
link.symlink_to(real, target_is_directory=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Refusing to use symlink"):
|
||||
_ensure_dir_secure(link)
|
||||
|
||||
|
||||
def test_cache_dir_chmod_failures_are_ignored(monkeypatch, tmp_path: Path):
|
||||
from enroll import cache
|
||||
|
||||
# Make the cache base path deterministic and writable.
|
||||
monkeypatch.setattr(cache, "enroll_cache_dir", lambda: tmp_path)
|
||||
|
||||
# Force os.chmod to fail to cover the "except OSError: pass" paths.
|
||||
monkeypatch.setattr(
|
||||
os, "chmod", lambda *a, **k: (_ for _ in ()).throw(OSError("nope"))
|
||||
)
|
||||
|
||||
hc = cache.new_harvest_cache_dir()
|
||||
assert hc.dir.exists()
|
||||
assert hc.dir.is_dir()
|
||||
|
||||
|
||||
def test_stat_triplet_falls_back_to_numeric_ids(monkeypatch, tmp_path: Path):
|
||||
from enroll.fsutil import stat_triplet
|
||||
import pwd
|
||||
import grp
|
||||
|
||||
p = tmp_path / "x"
|
||||
p.write_text("x", encoding="utf-8")
|
||||
|
||||
# Force username/group resolution failures.
|
||||
monkeypatch.setattr(
|
||||
pwd, "getpwuid", lambda _uid: (_ for _ in ()).throw(KeyError("no user"))
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
grp, "getgrgid", lambda _gid: (_ for _ in ()).throw(KeyError("no group"))
|
||||
)
|
||||
|
||||
owner, group, mode = stat_triplet(str(p))
|
||||
assert owner.isdigit()
|
||||
assert group.isdigit()
|
||||
assert len(mode) == 4
|
||||
|
||||
|
||||
def test_ignore_policy_iter_effective_lines_removes_block_comments():
|
||||
from enroll.ignore import IgnorePolicy
|
||||
|
||||
pol = IgnorePolicy()
|
||||
data = b"""keep1
|
||||
/*
|
||||
drop me
|
||||
*/
|
||||
keep2
|
||||
"""
|
||||
assert list(pol.iter_effective_lines(data)) == [b"keep1", b"keep2"]
|
||||
|
||||
|
||||
def test_ignore_policy_deny_reason_dir_variants(tmp_path: Path):
|
||||
from enroll.ignore import IgnorePolicy
|
||||
|
||||
pol = IgnorePolicy()
|
||||
|
||||
# denied by glob
|
||||
assert pol.deny_reason_dir("/etc/shadow") == "denied_path"
|
||||
|
||||
# symlink rejected
|
||||
d = tmp_path / "d"
|
||||
d.mkdir()
|
||||
link = tmp_path / "l"
|
||||
link.symlink_to(d, target_is_directory=True)
|
||||
assert pol.deny_reason_dir(str(link)) == "symlink"
|
||||
|
||||
# not a directory
|
||||
f = tmp_path / "f"
|
||||
f.write_text("x", encoding="utf-8")
|
||||
assert pol.deny_reason_dir(str(f)) == "not_directory"
|
||||
|
||||
# ok
|
||||
assert pol.deny_reason_dir(str(d)) is None
|
||||
|
||||
|
||||
def test_run_jinjaturtle_parses_outputs(monkeypatch, tmp_path: Path):
|
||||
# Fully unit-test enroll.jinjaturtle.run_jinjaturtle by stubbing subprocess.run.
|
||||
from enroll.jinjaturtle import run_jinjaturtle
|
||||
|
||||
def fake_run(cmd, **kwargs): # noqa: ARG001
|
||||
# cmd includes "-d <defaults> -t <template>"
|
||||
d_idx = cmd.index("-d") + 1
|
||||
t_idx = cmd.index("-t") + 1
|
||||
defaults = Path(cmd[d_idx])
|
||||
template = Path(cmd[t_idx])
|
||||
defaults.write_text("---\nfoo: 1\n", encoding="utf-8")
|
||||
template.write_text("value={{ foo }}\n", encoding="utf-8")
|
||||
return SimpleNamespace(returncode=0, stdout="ok", stderr="")
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
|
||||
src = tmp_path / "src.ini"
|
||||
src.write_text("foo=1\n", encoding="utf-8")
|
||||
|
||||
res = run_jinjaturtle("/bin/jinjaturtle", str(src), role_name="role1")
|
||||
assert "foo: 1" in res.vars_text
|
||||
assert "value=" in res.template_text
|
||||
|
||||
|
||||
def test_run_jinjaturtle_raises_on_failure(monkeypatch, tmp_path: Path):
|
||||
from enroll.jinjaturtle import run_jinjaturtle
|
||||
|
||||
def fake_run(cmd, **kwargs): # noqa: ARG001
|
||||
return SimpleNamespace(returncode=2, stdout="out", stderr="bad")
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
|
||||
src = tmp_path / "src.ini"
|
||||
src.write_text("x", encoding="utf-8")
|
||||
with pytest.raises(RuntimeError, match="jinjaturtle failed"):
|
||||
run_jinjaturtle("/bin/jinjaturtle", str(src), role_name="role1")
|
||||
|
||||
|
||||
def test_require_sops_cmd_errors_when_missing(monkeypatch):
|
||||
from enroll.sopsutil import require_sops_cmd, SopsError
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.shutil.which", lambda _: None)
|
||||
with pytest.raises(SopsError, match="not found on PATH"):
|
||||
require_sops_cmd()
|
||||
|
||||
|
||||
def test_get_enroll_version_reports_unknown_on_metadata_failure(monkeypatch):
|
||||
import enroll.version as v
|
||||
|
||||
fake_meta = types.ModuleType("importlib.metadata")
|
||||
|
||||
def boom():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
fake_meta.packages_distributions = boom
|
||||
fake_meta.version = lambda _dist: boom()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "importlib.metadata", fake_meta)
|
||||
assert v.get_enroll_version() == "unknown"
|
||||
|
||||
|
||||
def test_get_enroll_version_returns_unknown_if_importlib_metadata_unavailable(
|
||||
monkeypatch,
|
||||
):
|
||||
import builtins
|
||||
import enroll.version as v
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(
|
||||
name, globals=None, locals=None, fromlist=(), level=0
|
||||
): # noqa: A002
|
||||
if name == "importlib.metadata":
|
||||
raise ImportError("no metadata")
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
assert v.get_enroll_version() == "unknown"
|
||||
|
||||
|
||||
def test_compare_harvests_and_format_report(tmp_path: Path):
|
||||
from enroll.diff import compare_harvests, format_report
|
||||
|
||||
old = tmp_path / "old"
|
||||
new = tmp_path / "new"
|
||||
(old / "artifacts").mkdir(parents=True)
|
||||
(new / "artifacts").mkdir(parents=True)
|
||||
|
||||
def write_state(base: Path, state: dict) -> None:
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
(base / "state.json").write_text(json.dumps(state, indent=2), encoding="utf-8")
|
||||
|
||||
# Old bundle: pkg a@1.0, pkg b@1.0, one service, one user, one managed file.
|
||||
old_state = {
|
||||
"schema_version": 3,
|
||||
"host": {"hostname": "h1"},
|
||||
"inventory": {"packages": {"a": {"version": "1.0"}, "b": {"version": "1.0"}}},
|
||||
"roles": {
|
||||
"services": [
|
||||
{
|
||||
"unit": "svc.service",
|
||||
"role_name": "svc",
|
||||
"packages": ["a"],
|
||||
"active_state": "inactive",
|
||||
"sub_state": "dead",
|
||||
"unit_file_state": "enabled",
|
||||
"condition_result": None,
|
||||
"managed_files": [
|
||||
{
|
||||
"path": "/etc/foo.conf",
|
||||
"src_rel": "etc/foo.conf",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"mode": "0644",
|
||||
"reason": "modified_conffile",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"packages": [],
|
||||
"users": {
|
||||
"role_name": "users",
|
||||
"users": [{"name": "alice", "shell": "/bin/sh"}],
|
||||
},
|
||||
"apt_config": {"role_name": "apt_config", "managed_files": []},
|
||||
"etc_custom": {"role_name": "etc_custom", "managed_files": []},
|
||||
"usr_local_custom": {"role_name": "usr_local_custom", "managed_files": []},
|
||||
"extra_paths": {"role_name": "extra_paths", "managed_files": []},
|
||||
},
|
||||
}
|
||||
(old / "artifacts" / "svc" / "etc").mkdir(parents=True, exist_ok=True)
|
||||
(old / "artifacts" / "svc" / "etc" / "foo.conf").write_text("old", encoding="utf-8")
|
||||
write_state(old, old_state)
|
||||
|
||||
# New bundle: pkg a@2.0, pkg c@1.0, service changed, user changed, file moved role+content.
|
||||
new_state = {
|
||||
"schema_version": 3,
|
||||
"host": {"hostname": "h2"},
|
||||
"inventory": {"packages": {"a": {"version": "2.0"}, "c": {"version": "1.0"}}},
|
||||
"roles": {
|
||||
"services": [
|
||||
{
|
||||
"unit": "svc.service",
|
||||
"role_name": "svc",
|
||||
"packages": ["a", "c"],
|
||||
"active_state": "active",
|
||||
"sub_state": "running",
|
||||
"unit_file_state": "enabled",
|
||||
"condition_result": None,
|
||||
"managed_files": [],
|
||||
}
|
||||
],
|
||||
"packages": [],
|
||||
"users": {
|
||||
"role_name": "users",
|
||||
"users": [{"name": "alice", "shell": "/bin/bash"}, {"name": "bob"}],
|
||||
},
|
||||
"apt_config": {"role_name": "apt_config", "managed_files": []},
|
||||
"etc_custom": {"role_name": "etc_custom", "managed_files": []},
|
||||
"usr_local_custom": {"role_name": "usr_local_custom", "managed_files": []},
|
||||
"extra_paths": {
|
||||
"role_name": "extra_paths",
|
||||
"managed_files": [
|
||||
{
|
||||
"path": "/etc/foo.conf",
|
||||
"src_rel": "etc/foo.conf",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"mode": "0600",
|
||||
"reason": "user_include",
|
||||
},
|
||||
{
|
||||
"path": "/etc/added.conf",
|
||||
"src_rel": "etc/added.conf",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"mode": "0644",
|
||||
"reason": "user_include",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
(new / "artifacts" / "extra_paths" / "etc").mkdir(parents=True, exist_ok=True)
|
||||
(new / "artifacts" / "extra_paths" / "etc" / "foo.conf").write_text(
|
||||
"new", encoding="utf-8"
|
||||
)
|
||||
(new / "artifacts" / "extra_paths" / "etc" / "added.conf").write_text(
|
||||
"x", encoding="utf-8"
|
||||
)
|
||||
write_state(new, new_state)
|
||||
|
||||
report, changed = compare_harvests(str(old), str(new))
|
||||
assert changed is True
|
||||
|
||||
txt = format_report(report, fmt="text")
|
||||
assert "Packages" in txt
|
||||
|
||||
md = format_report(report, fmt="markdown")
|
||||
assert "# enroll diff report" in md
|
||||
|
||||
js = format_report(report, fmt="json")
|
||||
parsed = json.loads(js)
|
||||
assert parsed["packages"]["added"] == ["c"]
|
||||
|
|
@ -184,3 +184,157 @@ def test_expand_includes_respects_max_files(monkeypatch):
|
|||
paths, notes = pf.expand_includes(include, max_files=2)
|
||||
assert len(paths) == 2
|
||||
assert "/root/c" not in paths
|
||||
|
||||
|
||||
def test_has_glob_chars():
|
||||
assert pf._has_glob_chars("*.txt") is True
|
||||
assert pf._has_glob_chars("file?.log") is True
|
||||
assert pf._has_glob_chars("[abc]") is True
|
||||
assert pf._has_glob_chars("file.txt") is False
|
||||
assert pf._has_glob_chars("") is False
|
||||
|
||||
|
||||
def test_compile_path_pattern_regex_valid():
|
||||
result = pf.compile_path_pattern("re:^/home/.*$")
|
||||
assert result.kind == "regex"
|
||||
assert result.regex is not None
|
||||
assert result.regex.search("/home/user/file.txt") is not None
|
||||
assert result.regex.search("/var/file.txt") is None
|
||||
|
||||
|
||||
def test_compile_path_pattern_glob_forced():
|
||||
result = pf.compile_path_pattern("glob:/etc/*.conf")
|
||||
assert result.kind == "glob"
|
||||
assert result.value == "/etc/*.conf"
|
||||
|
||||
|
||||
def test_compile_path_pattern_glob_heuristic():
|
||||
result = pf.compile_path_pattern("/etc/*.conf")
|
||||
assert result.kind == "glob"
|
||||
|
||||
|
||||
def test_compile_path_pattern_prefix():
|
||||
result = pf.compile_path_pattern("/etc/nginx")
|
||||
assert result.kind == "prefix"
|
||||
assert result.value == "/etc/nginx"
|
||||
|
||||
|
||||
def test_compiled_pattern_matches_prefix():
|
||||
pat = pf.compile_path_pattern("/etc/nginx")
|
||||
assert pat.matches("/etc/nginx") is True
|
||||
assert pat.matches("/etc/nginx/conf.d") is True
|
||||
assert pat.matches("/etc/ssh") is False
|
||||
|
||||
|
||||
def test_compiled_pattern_matches_glob():
|
||||
pat = pf.compile_path_pattern("/etc/*.conf")
|
||||
assert pat.matches("/etc/ssh.conf") is True
|
||||
assert pat.matches("/etc/ssh/sshd.conf") is False
|
||||
|
||||
|
||||
def test_compiled_pattern_matches_regex():
|
||||
pat = pf.compile_path_pattern("re:^/home/[^/]+/.bashrc$")
|
||||
assert pat.matches("/home/alice/.bashrc") is True
|
||||
assert pat.matches("/home/bob/.bashrc") is True
|
||||
assert pat.matches("/home/alice/.profile") is False
|
||||
assert pat.matches("/var/.bashrc") is False
|
||||
|
||||
|
||||
def test_path_filter_is_excluded():
|
||||
pf_filter = pf.PathFilter(exclude=["/tmp/*", "/var/log"])
|
||||
assert pf_filter.is_excluded("/tmp/file.txt") is True
|
||||
assert pf_filter.is_excluded("/var/log/syslog") is True
|
||||
assert pf_filter.is_excluded("/etc/ssh") is False
|
||||
|
||||
|
||||
def test_path_filter_empty():
|
||||
pf_filter = pf.PathFilter()
|
||||
assert pf_filter.is_excluded("/anything") is False
|
||||
assert pf_filter.iter_include_patterns() == []
|
||||
|
||||
|
||||
def test_expand_includes_prefix_existing(tmp_path: Path):
|
||||
etc_dir = tmp_path / "etc"
|
||||
etc_dir.mkdir()
|
||||
(etc_dir / "file1.txt").write_text("a")
|
||||
(etc_dir / "file2.txt").write_text("b")
|
||||
|
||||
patterns = [pf.compile_path_pattern(str(etc_dir))]
|
||||
paths, notes = pf.expand_includes(patterns, max_files=10)
|
||||
|
||||
assert len(paths) == 2
|
||||
assert notes == []
|
||||
|
||||
|
||||
def test_expand_includes_prefix_nonexistent():
|
||||
patterns = [pf.compile_path_pattern("/nonexistent/path")]
|
||||
paths, notes = pf.expand_includes(patterns, max_files=10)
|
||||
|
||||
assert paths == []
|
||||
assert len(notes) == 1
|
||||
assert "matched no files" in notes[0]
|
||||
|
||||
|
||||
def test_expand_includes_glob_no_matches():
|
||||
patterns = [pf.compile_path_pattern("/nonexistent/*.txt")]
|
||||
paths, notes = pf.expand_includes(patterns, max_files=10)
|
||||
|
||||
assert paths == []
|
||||
assert len(notes) == 1
|
||||
|
||||
|
||||
def test_expand_includes_skips_symlinks(tmp_path: Path):
|
||||
real_file = tmp_path / "real.txt"
|
||||
real_file.write_text("x")
|
||||
link = tmp_path / "link.txt"
|
||||
os.symlink(str(real_file), str(link))
|
||||
|
||||
patterns = [pf.compile_path_pattern(str(tmp_path))]
|
||||
paths, notes = pf.expand_includes(patterns, max_files=10)
|
||||
|
||||
assert len(paths) == 1
|
||||
assert paths[0].endswith("real.txt")
|
||||
|
||||
|
||||
def test_expand_includes_excludes_pattern(tmp_path: Path):
|
||||
etc_dir = tmp_path / "etc"
|
||||
etc_dir.mkdir()
|
||||
(etc_dir / "include.txt").write_text("a")
|
||||
(etc_dir / "exclude.txt").write_text("b")
|
||||
|
||||
patterns = [pf.compile_path_pattern(str(etc_dir))]
|
||||
exclude = pf.PathFilter(exclude=["*exclude*"])
|
||||
paths, notes = pf.expand_includes(patterns, exclude=exclude, max_files=10)
|
||||
|
||||
assert len(paths) == 1
|
||||
assert paths[0].endswith("include.txt")
|
||||
|
||||
|
||||
def test_expand_includes_skips_directories(tmp_path: Path):
|
||||
subdir = tmp_path / "subdir"
|
||||
subdir.mkdir()
|
||||
(tmp_path / "file.txt").write_text("x")
|
||||
|
||||
patterns = [pf.compile_path_pattern(str(subdir))]
|
||||
paths, notes = pf.expand_includes(patterns, max_files=10)
|
||||
|
||||
assert paths == []
|
||||
|
||||
|
||||
def test_regex_literal_prefix_simple():
|
||||
assert pf._regex_literal_prefix("/etc/nginx/") == "/etc/nginx/"
|
||||
|
||||
|
||||
def test_regex_literal_prefix_with_anchor():
|
||||
assert pf._regex_literal_prefix("^/etc/nginx/") == "/etc/nginx/"
|
||||
|
||||
|
||||
def test_regex_literal_prefix_with_regex_chars():
|
||||
assert pf._regex_literal_prefix("^/etc/.*\\.conf$") == "/etc/"
|
||||
|
||||
|
||||
def test_path_filter_with_include_patterns():
|
||||
pf_filter = pf.PathFilter(include=["/etc/*.conf"], exclude=["/etc/secret.conf"])
|
||||
patterns = pf_filter.iter_include_patterns()
|
||||
assert len(patterns) == 1
|
||||
assert patterns[0].kind == "glob"
|
||||
|
|
|
|||
|
|
@ -91,3 +91,176 @@ def test_specific_paths_for_hints_differs_between_backends():
|
|||
paths = set(r.specific_paths_for_hints({"nginx"}))
|
||||
assert "/etc/sysconfig/nginx" in paths
|
||||
assert "/etc/sysconfig/nginx.conf" in paths
|
||||
|
||||
|
||||
def test_read_os_release_file_not_found(tmp_path: Path):
|
||||
result = platform._read_os_release(str(tmp_path / "nonexistent"))
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_os_release_handles_invalid_line(tmp_path: Path):
|
||||
p = tmp_path / "os-release"
|
||||
p.write_text(
|
||||
"ID=ubuntu\n" "NO_EQUALS_SIGN\n" 'VERSION="22.04"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = platform._read_os_release(str(p))
|
||||
assert result["ID"] == "ubuntu"
|
||||
assert result["VERSION"] == "22.04"
|
||||
assert "NO_EQUALS_SIGN" not in result
|
||||
|
||||
|
||||
def test_detect_platform_debian(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "debian", "VERSION_ID": "11"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "debian"
|
||||
assert result.pkg_backend == "dpkg"
|
||||
|
||||
|
||||
def test_detect_platform_ubuntu(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "ubuntu", "VERSION_ID": "22.04"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "debian"
|
||||
assert result.pkg_backend == "dpkg"
|
||||
|
||||
|
||||
def test_detect_platform_fedora(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "fedora", "VERSION_ID": "38"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "redhat"
|
||||
assert result.pkg_backend == "rpm"
|
||||
|
||||
|
||||
def test_detect_platform_rocky(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "rocky", "VERSION_ID": "9"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "redhat"
|
||||
assert result.pkg_backend == "rpm"
|
||||
|
||||
|
||||
def test_detect_platform_unknown_fallback_to_dpkg(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "unknown"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "dpkg")
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "debian"
|
||||
assert result.pkg_backend == "dpkg"
|
||||
|
||||
|
||||
def test_detect_platform_unknown_fallback_to_rpm(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "unknown"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "rpm")
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "redhat"
|
||||
assert result.pkg_backend == "rpm"
|
||||
|
||||
|
||||
def test_detect_platform_completely_unknown(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "unknown"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
monkeypatch.setattr(platform.shutil, "which", lambda x: False)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "unknown"
|
||||
assert result.pkg_backend == "unknown"
|
||||
|
||||
|
||||
def test_detect_platform_debian_like(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "linuxmint", "ID_LIKE": "debian"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "debian"
|
||||
assert result.pkg_backend == "dpkg"
|
||||
|
||||
|
||||
def test_detect_platform_rhel_like(monkeypatch):
|
||||
def fake_read_os_release(path: str = "/etc/os-release") -> dict:
|
||||
return {"ID": "centos", "ID_LIKE": "rhel fedora"}
|
||||
|
||||
monkeypatch.setattr(platform, "_read_os_release", fake_read_os_release)
|
||||
result = platform.detect_platform()
|
||||
assert result.os_family == "redhat"
|
||||
assert result.pkg_backend == "rpm"
|
||||
|
||||
|
||||
def test_get_backend_returns_dpkg(monkeypatch):
|
||||
info = platform.PlatformInfo(os_family="debian", pkg_backend="dpkg", os_release={})
|
||||
backend = platform.get_backend(info)
|
||||
assert isinstance(backend, platform.DpkgBackend)
|
||||
assert backend.name == "dpkg"
|
||||
|
||||
|
||||
def test_get_backend_returns_rpm(monkeypatch):
|
||||
info = platform.PlatformInfo(os_family="redhat", pkg_backend="rpm", os_release={})
|
||||
backend = platform.get_backend(info)
|
||||
assert isinstance(backend, platform.RpmBackend)
|
||||
assert backend.name == "rpm"
|
||||
|
||||
|
||||
def test_get_backend_unknown_with_rpm(monkeypatch):
|
||||
info = platform.PlatformInfo(
|
||||
os_family="unknown", pkg_backend="unknown", os_release={}
|
||||
)
|
||||
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "rpm")
|
||||
backend = platform.get_backend(info)
|
||||
assert isinstance(backend, platform.RpmBackend)
|
||||
|
||||
|
||||
def test_get_backend_unknown_with_dpkg(monkeypatch):
|
||||
info = platform.PlatformInfo(
|
||||
os_family="unknown", pkg_backend="unknown", os_release={}
|
||||
)
|
||||
monkeypatch.setattr(platform.shutil, "which", lambda x: x == "dpkg")
|
||||
backend = platform.get_backend(info)
|
||||
assert isinstance(backend, platform.DpkgBackend)
|
||||
|
||||
|
||||
def test_dpkg_backend_specific_paths():
|
||||
backend = platform.DpkgBackend()
|
||||
paths = backend.specific_paths_for_hints({"nginx"})
|
||||
assert "/etc/default/nginx" in paths
|
||||
assert "/etc/init.d/nginx" in paths
|
||||
assert "/etc/sysctl.d/nginx.conf" in paths
|
||||
|
||||
|
||||
def test_rpm_backend_specific_paths():
|
||||
backend = platform.RpmBackend()
|
||||
paths = backend.specific_paths_for_hints({"nginx"})
|
||||
assert "/etc/sysconfig/nginx" in paths
|
||||
assert "/etc/sysconfig/nginx.conf" in paths
|
||||
assert "/etc/sysctl.d/nginx.conf" in paths
|
||||
|
||||
|
||||
def test_is_pkg_config_path_dpkg():
|
||||
backend = platform.DpkgBackend()
|
||||
assert backend.is_pkg_config_path("/etc/apt/sources.list") is True
|
||||
assert backend.is_pkg_config_path("/etc/apt/trusted.gpg") is True
|
||||
assert backend.is_pkg_config_path("/etc/ssh/sshd_config") is False
|
||||
|
||||
|
||||
def test_is_pkg_config_path_rpm():
|
||||
backend = platform.RpmBackend()
|
||||
assert backend.is_pkg_config_path("/etc/dnf/dnf.conf") is True
|
||||
assert backend.is_pkg_config_path("/etc/yum.conf") is True
|
||||
assert backend.is_pkg_config_path("/etc/yum.repos.d/custom.repo") is True
|
||||
assert backend.is_pkg_config_path("/etc/ssh/sshd_config") is False
|
||||
|
|
|
|||
|
|
@ -565,3 +565,452 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
|
|||
|
||||
# Ensure the password was written to stdin for the -S invocation.
|
||||
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]
|
||||
|
||||
|
||||
def test_sudo_password_required_detection():
|
||||
from enroll.remote import _sudo_password_required
|
||||
|
||||
assert _sudo_password_required("", "a password is required") is True
|
||||
assert _sudo_password_required("", "password is required") is True
|
||||
assert (
|
||||
_sudo_password_required("", "a terminal is required to read the password")
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
_sudo_password_required("", "no tty present and no askpass program specified")
|
||||
is True
|
||||
)
|
||||
assert _sudo_password_required("", "must have a tty to run sudo") is True
|
||||
assert _sudo_password_required("", "sudo: sorry, you must have a tty") is True
|
||||
assert _sudo_password_required("", "askpass") is True
|
||||
assert _sudo_password_required("success", "") is False
|
||||
|
||||
|
||||
def test_sudo_not_permitted_detection():
|
||||
from enroll.remote import _sudo_not_permitted
|
||||
|
||||
assert _sudo_not_permitted("", "user is not in the sudoers file") is True
|
||||
assert _sudo_not_permitted("", "not allowed to execute") is True
|
||||
assert _sudo_not_permitted("", "may not run sudo") is True
|
||||
assert _sudo_not_permitted("", "sorry, user") is True
|
||||
assert _sudo_not_permitted("success", "") is False
|
||||
|
||||
|
||||
def test_sudo_tty_required_detection():
|
||||
from enroll.remote import _sudo_tty_required
|
||||
|
||||
assert _sudo_tty_required("", "must have a tty") is True
|
||||
assert _sudo_tty_required("", "sorry, you must have a tty") is True
|
||||
assert _sudo_tty_required("", "sudo: sorry, you must have a tty") is True
|
||||
assert _sudo_tty_required("", "must have a tty to run sudo") is True
|
||||
assert _sudo_tty_required("success", "") is False
|
||||
|
||||
|
||||
def test_resolve_become_password_prompts_when_asked(monkeypatch):
|
||||
from enroll.remote import _resolve_become_password
|
||||
|
||||
prompted = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompted.append(prompt)
|
||||
return "secret"
|
||||
|
||||
result = _resolve_become_password(
|
||||
True, prompt="sudo password: ", getpass_fn=fake_getpass
|
||||
)
|
||||
assert result == "secret"
|
||||
assert len(prompted) == 1
|
||||
|
||||
|
||||
def test_resolve_become_password_returns_none_when_not_asked():
|
||||
from enroll.remote import _resolve_become_password
|
||||
|
||||
result = _resolve_become_password(False)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_from_env(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
monkeypatch.setenv("SSH_KEY_PASS", "env_secret")
|
||||
|
||||
result = _resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
|
||||
assert result == "env_secret"
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_raises_when_env_not_set(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
monkeypatch.delenv("SSH_KEY_PASS", raising=False)
|
||||
|
||||
with pytest.raises(RuntimeError, match="SSH key passphrase environment variable"):
|
||||
_resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_prompts_when_asked(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
prompted = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompted.append(prompt)
|
||||
return "prompt_secret"
|
||||
|
||||
result = _resolve_ssh_key_passphrase(
|
||||
True, prompt="SSH key passphrase: ", getpass_fn=fake_getpass
|
||||
)
|
||||
assert result == "prompt_secret"
|
||||
assert len(prompted) == 1
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_returns_none_when_not_asked():
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
result = _resolve_ssh_key_passphrase(False, env_var=None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_absolute_paths(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="/etc/passwd")
|
||||
ti.size = 1
|
||||
tf.addfile(ti, io.BytesIO(b"x"))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Unsafe tar member path"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_hardlinks(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="hardlink")
|
||||
ti.type = tarfile.LNKTYPE
|
||||
ti.linkname = "/etc/passwd"
|
||||
tf.addfile(ti)
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Refusing to extract"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_device_nodes(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="device")
|
||||
ti.type = tarfile.CHRTYPE
|
||||
tf.addfile(ti)
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Refusing to extract"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_accepts_dot_entry(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name=".")
|
||||
ti.size = 0
|
||||
tf.addfile(ti, io.BytesIO(b""))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_accepts_valid_files(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="foo/bar.txt")
|
||||
ti.size = 5
|
||||
tf.addfile(ti, io.BytesIO(b"hello"))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"hello"
|
||||
|
||||
|
||||
def test_remote_harvest_ssh_key_passphrase_retry(monkeypatch, tmp_path: Path):
|
||||
import sys
|
||||
|
||||
import enroll.remote as r
|
||||
|
||||
monkeypatch.setattr(
|
||||
r,
|
||||
"_build_enroll_pyz",
|
||||
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
|
||||
or (Path(td) / "enroll.pyz"),
|
||||
)
|
||||
|
||||
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
|
||||
|
||||
class _Chan:
|
||||
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
|
||||
self._out = out
|
||||
self._err = err
|
||||
self._out_i = 0
|
||||
self._err_i = 0
|
||||
self._rc = rc
|
||||
self._closed = False
|
||||
|
||||
def recv_ready(self) -> bool:
|
||||
return (not self._closed) and self._out_i < len(self._out)
|
||||
|
||||
def recv(self, n: int) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
chunk = self._out[self._out_i : self._out_i + n]
|
||||
self._out_i += len(chunk)
|
||||
return chunk
|
||||
|
||||
def recv_stderr_ready(self) -> bool:
|
||||
return (not self._closed) and self._err_i < len(self._err)
|
||||
|
||||
def recv_stderr(self, n: int) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
chunk = self._err[self._err_i : self._err_i + n]
|
||||
self._err_i += len(chunk)
|
||||
return chunk
|
||||
|
||||
def exit_status_ready(self) -> bool:
|
||||
return self._closed or (
|
||||
self._out_i >= len(self._out) and self._err_i >= len(self._err)
|
||||
)
|
||||
|
||||
def recv_exit_status(self) -> int:
|
||||
return self._rc
|
||||
|
||||
def shutdown_write(self) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
class _Stdout:
|
||||
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
|
||||
self._bio = io.BytesIO(payload)
|
||||
self.channel = _Chan(out=payload, err=err, rc=rc)
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return self._bio.read(n)
|
||||
|
||||
class _Stderr:
|
||||
def __init__(self, payload: bytes = b""):
|
||||
self._bio = io.BytesIO(payload)
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return self._bio.read(n)
|
||||
|
||||
class _Stdin:
|
||||
def __init__(self, cmd: str):
|
||||
self._cmd = cmd
|
||||
|
||||
def write(self, s: str) -> None:
|
||||
pass
|
||||
|
||||
def flush(self) -> None:
|
||||
return
|
||||
|
||||
class _SFTP:
|
||||
def put(self, _local: str, _remote: str) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
class FakeSSH:
|
||||
def __init__(self):
|
||||
self._sftp = _SFTP()
|
||||
|
||||
def load_system_host_keys(self):
|
||||
return
|
||||
|
||||
def set_missing_host_key_policy(self, _policy):
|
||||
return
|
||||
|
||||
def connect(self, **_kwargs):
|
||||
return
|
||||
|
||||
def open_sftp(self):
|
||||
return self._sftp
|
||||
|
||||
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
|
||||
if cmd.startswith("tar -cz -C"):
|
||||
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
|
||||
if cmd == "mktemp -d":
|
||||
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
|
||||
if cmd.startswith("chmod 700"):
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
if " harvest " in cmd:
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
if cmd.startswith("rm -rf"):
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
RejectPolicy4 = type("RejectPolicy", (), {})
|
||||
|
||||
class FakeParamiko:
|
||||
SSHClient = FakeSSH
|
||||
RejectPolicy = RejectPolicy4 # type: ignore
|
||||
PasswordRequiredException = Exception # type: ignore
|
||||
|
||||
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
|
||||
|
||||
prompts = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompts.append(prompt)
|
||||
return "passphrase"
|
||||
|
||||
out_dir = tmp_path / "out"
|
||||
state_path = r.remote_harvest(
|
||||
ask_key_passphrase=True,
|
||||
getpass_fn=fake_getpass,
|
||||
local_out_dir=out_dir,
|
||||
remote_host="example.com",
|
||||
remote_user="alice",
|
||||
no_sudo=True,
|
||||
)
|
||||
|
||||
assert state_path.exists()
|
||||
assert len(prompts) == 1
|
||||
|
||||
|
||||
def test_remote_harvest_ssh_key_passphrase_raises_when_not_interactive(
|
||||
monkeypatch, tmp_path: Path
|
||||
):
|
||||
import sys
|
||||
|
||||
import enroll.remote as r
|
||||
|
||||
monkeypatch.setattr(
|
||||
r,
|
||||
"_build_enroll_pyz",
|
||||
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
|
||||
or (Path(td) / "enroll.pyz"),
|
||||
)
|
||||
|
||||
class _Chan:
|
||||
def __init__(self):
|
||||
self._closed = False
|
||||
|
||||
def recv_ready(self) -> bool:
|
||||
return False
|
||||
|
||||
def recv(self, n: int) -> bytes:
|
||||
return b""
|
||||
|
||||
def recv_stderr_ready(self) -> bool:
|
||||
return False
|
||||
|
||||
def recv_stderr(self, n: int) -> bytes:
|
||||
return b""
|
||||
|
||||
def exit_status_ready(self) -> bool:
|
||||
return True
|
||||
|
||||
def recv_exit_status(self) -> int:
|
||||
return 0
|
||||
|
||||
def shutdown_write(self) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
class _Stdout:
|
||||
def __init__(self):
|
||||
self.channel = _Chan()
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return b""
|
||||
|
||||
class _Stderr:
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return b""
|
||||
|
||||
class _SFTP:
|
||||
def put(self, _local: str, _remote: str) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
class FakeSSH:
|
||||
def __init__(self):
|
||||
self._sftp = _SFTP()
|
||||
|
||||
def load_system_host_keys(self):
|
||||
return
|
||||
|
||||
def set_missing_host_key_policy(self, _policy):
|
||||
return
|
||||
|
||||
def connect(self, **_kwargs):
|
||||
raise Exception("PasswordRequired")
|
||||
|
||||
def open_sftp(self):
|
||||
return self._sftp
|
||||
|
||||
def exec_command(self, cmd: str, **_kwargs):
|
||||
return (_Stdout(), _Stdout(), _Stderr())
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
class RejectPolicy:
|
||||
pass
|
||||
|
||||
RejectPolicy3 = RejectPolicy
|
||||
|
||||
class FakeParamiko:
|
||||
SSHClient = FakeSSH
|
||||
RejectPolicy = RejectPolicy3 # type: ignore
|
||||
PasswordRequiredException = Exception # type: ignore
|
||||
|
||||
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
|
||||
|
||||
out_dir = tmp_path / "out"
|
||||
|
||||
with pytest.raises(RuntimeError, match="SSH private key is encrypted"):
|
||||
r.remote_harvest(
|
||||
ask_key_passphrase=False,
|
||||
local_out_dir=out_dir,
|
||||
remote_host="example.com",
|
||||
stdin=io.StringIO(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
import enroll.rpm as rpm
|
||||
|
||||
|
|
@ -176,3 +177,33 @@ def test_rpm_owner_strips_epoch_prefix_when_present(monkeypatch):
|
|||
lambda cmd, allow_fail=False, merge_err=False: (0, "1:bash\n"),
|
||||
)
|
||||
assert rpm.rpm_owner("/bin/bash") == "bash"
|
||||
|
||||
|
||||
def test_strip_arch_no_suffix():
|
||||
assert rpm._strip_arch("vim") == "vim"
|
||||
assert rpm._strip_arch("nginx ") == "nginx"
|
||||
|
||||
|
||||
def test_strip_arch_with_unknown_suffix():
|
||||
assert rpm._strip_arch("package.unknown") == "package.unknown"
|
||||
|
||||
|
||||
def test_run_command_raises_on_fail():
|
||||
with pytest.raises(RuntimeError):
|
||||
rpm._run(["sh", "-c", "echo stderr >&2; exit 1"], allow_fail=False)
|
||||
|
||||
|
||||
def test_rpm_owner_empty_path():
|
||||
assert rpm.rpm_owner("") is None
|
||||
|
||||
|
||||
def test_rpm_modified_files_empty(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
rpm, "_run", lambda cmd, allow_fail=False, merge_err=False: (0, "")
|
||||
)
|
||||
assert rpm.rpm_modified_files("vim") == set()
|
||||
|
||||
|
||||
def test_list_manual_packages_no_commands_available(monkeypatch):
|
||||
monkeypatch.setattr(rpm.shutil, "which", lambda exe: None)
|
||||
assert rpm.list_manual_packages() == []
|
||||
|
|
|
|||
234
tests/test_sopsutil.py
Normal file
234
tests/test_sopsutil.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
from enroll.sopsutil import SopsError, _pgp_arg, find_sops_cmd, require_sops_cmd
|
||||
|
||||
|
||||
def test_find_sops_cmd():
|
||||
result = find_sops_cmd()
|
||||
if result is None:
|
||||
pytest.skip("sops not installed")
|
||||
assert result.endswith("sops")
|
||||
|
||||
|
||||
def test_require_sops_cmd():
|
||||
exe = require_sops_cmd()
|
||||
assert exe is not None
|
||||
assert "sops" in exe
|
||||
|
||||
|
||||
def test_require_sops_cmd_raises_when_not_found(monkeypatch):
|
||||
import enroll.sopsutil as s
|
||||
|
||||
def fake_find():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(s, "find_sops_cmd", fake_find)
|
||||
|
||||
with pytest.raises(SopsError) as exc_info:
|
||||
require_sops_cmd()
|
||||
|
||||
assert "sops" in str(exc_info.value).lower()
|
||||
assert "not found" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
def test_pgp_arg_with_empty_fingerprints():
|
||||
with pytest.raises(SopsError) as exc_info:
|
||||
_pgp_arg([])
|
||||
assert "No GPG fingerprints" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_pgp_arg_with_whitespace_fingerprints():
|
||||
result = _pgp_arg([" ", "ABC123", " DEF456 "])
|
||||
assert result == "ABC123,DEF456"
|
||||
|
||||
|
||||
def test_pgp_arg_with_single_fingerprint():
|
||||
result = _pgp_arg(["ABC123DEF456"])
|
||||
assert result == "ABC123DEF456"
|
||||
|
||||
|
||||
def test_pgp_arg_with_multiple_fingerprints():
|
||||
result = _pgp_arg(["ABC123", "DEF456", "GHI789"])
|
||||
assert result == "ABC123,DEF456,GHI789"
|
||||
|
||||
|
||||
def test_encrypt_file_binary_success(monkeypatch, tmp_path: Path):
|
||||
"""Test successful encryption path."""
|
||||
# Create source file
|
||||
src = tmp_path / "secret.txt"
|
||||
src.write_text("secret data", encoding="utf-8")
|
||||
dst = tmp_path / "encrypted.sops"
|
||||
|
||||
# Mock subprocess.run to succeed
|
||||
class Result:
|
||||
returncode = 0
|
||||
stdout = b"encrypted data"
|
||||
stderr = b""
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
# Mock require_sops_cmd to return a fake path
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
|
||||
from enroll.sopsutil import encrypt_file_binary
|
||||
|
||||
encrypt_file_binary(src, dst, pgp_fingerprints=["ABC123"])
|
||||
|
||||
assert dst.exists()
|
||||
assert dst.read_bytes() == b"encrypted data"
|
||||
|
||||
|
||||
def test_encrypt_file_binary_fails(monkeypatch, tmp_path: Path):
|
||||
"""Test encryption failure path."""
|
||||
src = tmp_path / "secret.txt"
|
||||
src.write_text("secret data", encoding="utf-8")
|
||||
dst = tmp_path / "encrypted.sops"
|
||||
|
||||
class Result:
|
||||
returncode = 1
|
||||
stdout = b""
|
||||
stderr = b"sops: gpg error"
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
|
||||
from enroll.sopsutil import encrypt_file_binary, SopsError
|
||||
|
||||
with pytest.raises(SopsError) as exc_info:
|
||||
encrypt_file_binary(src, dst, pgp_fingerprints=["ABC123"])
|
||||
|
||||
assert "encryption failed" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
def test_encrypt_file_binary_chmod_fails(monkeypatch, tmp_path: Path):
|
||||
"""Test when chmod fails but file is still written."""
|
||||
src = tmp_path / "secret.txt"
|
||||
src.write_text("secret data", encoding="utf-8")
|
||||
dst = tmp_path / "encrypted.sops"
|
||||
|
||||
class Result:
|
||||
returncode = 0
|
||||
stdout = b"encrypted data"
|
||||
stderr = b""
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
def fake_chmod(path, mode):
|
||||
raise OSError("Permission denied")
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
monkeypatch.setattr("enroll.sopsutil.os.chmod", fake_chmod)
|
||||
|
||||
from enroll.sopsutil import encrypt_file_binary
|
||||
|
||||
# Should not raise even though chmod fails
|
||||
encrypt_file_binary(src, dst, pgp_fingerprints=["ABC123"])
|
||||
|
||||
assert dst.exists()
|
||||
|
||||
|
||||
def test_decrypt_file_binary_to_success(monkeypatch, tmp_path: Path):
|
||||
"""Test successful decryption path."""
|
||||
src = tmp_path / "encrypted.sops"
|
||||
src.write_bytes(b"encrypted data")
|
||||
dst = tmp_path / "decrypted.txt"
|
||||
|
||||
class Result:
|
||||
returncode = 0
|
||||
stdout = b"decrypted data"
|
||||
stderr = b""
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
|
||||
from enroll.sopsutil import decrypt_file_binary_to
|
||||
|
||||
decrypt_file_binary_to(src, dst)
|
||||
|
||||
assert dst.exists()
|
||||
assert dst.read_bytes() == b"decrypted data"
|
||||
|
||||
|
||||
def test_decrypt_file_binary_to_fails(monkeypatch, tmp_path: Path):
|
||||
"""Test decryption failure path."""
|
||||
src = tmp_path / "encrypted.sops"
|
||||
src.write_bytes(b"encrypted data")
|
||||
dst = tmp_path / "decrypted.txt"
|
||||
|
||||
class Result:
|
||||
returncode = 1
|
||||
stdout = b""
|
||||
stderr = b"sops: decryption failed"
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
|
||||
from enroll.sopsutil import decrypt_file_binary_to, SopsError
|
||||
|
||||
with pytest.raises(SopsError) as exc_info:
|
||||
decrypt_file_binary_to(src, dst)
|
||||
|
||||
assert "decryption failed" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
def test_decrypt_file_binary_to_chmod_fails(monkeypatch, tmp_path: Path):
|
||||
"""Test when chmod fails during decryption but file is still written."""
|
||||
src = tmp_path / "encrypted.sops"
|
||||
src.write_bytes(b"encrypted data")
|
||||
dst = tmp_path / "decrypted.txt"
|
||||
|
||||
class Result:
|
||||
returncode = 0
|
||||
stdout = b"decrypted data"
|
||||
stderr = b""
|
||||
|
||||
def fake_run(cmd, capture_output, check):
|
||||
return Result()
|
||||
|
||||
def fake_require():
|
||||
return "/fake/sops"
|
||||
|
||||
def fake_chmod(path, mode):
|
||||
raise OSError("Permission denied")
|
||||
|
||||
monkeypatch.setattr("enroll.sopsutil.subprocess.run", fake_run)
|
||||
monkeypatch.setattr("enroll.sopsutil.require_sops_cmd", fake_require)
|
||||
monkeypatch.setattr("enroll.sopsutil.os.chmod", fake_chmod)
|
||||
|
||||
from enroll.sopsutil import decrypt_file_binary_to
|
||||
|
||||
# Should not raise even though chmod fails
|
||||
decrypt_file_binary_to(src, dst)
|
||||
|
||||
assert dst.exists()
|
||||
|
|
@ -1,11 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import enroll.systemd as s
|
||||
|
||||
|
||||
def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
|
||||
import enroll.systemd as s
|
||||
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
if "--type=service" in cmd:
|
||||
return "\n".join(
|
||||
|
|
@ -35,8 +34,6 @@ def test_list_enabled_services_and_timers_filters_templates(monkeypatch):
|
|||
|
||||
|
||||
def test_get_unit_info_parses_fields(monkeypatch):
|
||||
import enroll.systemd as s
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str = ""):
|
||||
self.returncode = rc
|
||||
|
|
@ -71,8 +68,6 @@ def test_get_unit_info_parses_fields(monkeypatch):
|
|||
|
||||
|
||||
def test_get_unit_info_raises_unit_query_error(monkeypatch):
|
||||
import enroll.systemd as s
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str):
|
||||
self.returncode = rc
|
||||
|
|
@ -90,8 +85,6 @@ def test_get_unit_info_raises_unit_query_error(monkeypatch):
|
|||
|
||||
|
||||
def test_get_timer_info_parses_fields(monkeypatch):
|
||||
import enroll.systemd as s
|
||||
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str = ""):
|
||||
self.returncode = rc
|
||||
|
|
@ -119,3 +112,210 @@ def test_get_timer_info_parses_fields(monkeypatch):
|
|||
ti = s.get_timer_info("apt-daily.timer")
|
||||
assert ti.trigger_unit == "apt-daily.service"
|
||||
assert "/etc/default/apt" in ti.env_files
|
||||
|
||||
|
||||
def test_list_enabled_services_empty_output(monkeypatch):
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_services()
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_list_enabled_timers_empty_output(monkeypatch):
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_timers()
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_list_enabled_services_with_only_templates(monkeypatch):
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "getty@.service enabled\n"
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_services()
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_list_enabled_timers_with_only_templates(monkeypatch):
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "foo@.timer enabled\n"
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_timers()
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_get_timer_info_raises_on_failure(monkeypatch):
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = err
|
||||
|
||||
def fake_run(cmd, text, capture_output):
|
||||
return P(1, "", "timer not found")
|
||||
|
||||
monkeypatch.setattr(s.subprocess, "run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
s.get_timer_info("nonexistent.timer")
|
||||
|
||||
assert "nonexistent.timer" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_timer_info_with_empty_fields(monkeypatch):
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str = ""):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = err
|
||||
|
||||
def fake_run(cmd, text, capture_output):
|
||||
return P(
|
||||
0,
|
||||
"\n".join(
|
||||
[
|
||||
"FragmentPath=",
|
||||
"DropInPaths=",
|
||||
"EnvironmentFiles=",
|
||||
"Unit=",
|
||||
"ActiveState=",
|
||||
"SubState=",
|
||||
"UnitFileState=",
|
||||
"ConditionResult=",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s.subprocess, "run", fake_run)
|
||||
ti = s.get_timer_info("empty.timer")
|
||||
assert ti.fragment_path is None
|
||||
assert ti.dropin_paths == []
|
||||
assert ti.env_files == []
|
||||
assert ti.trigger_unit is None
|
||||
assert ti.active_state is None
|
||||
|
||||
|
||||
def test_get_unit_info_with_empty_fields(monkeypatch):
|
||||
class P:
|
||||
def __init__(self, rc: int, out: str, err: str = ""):
|
||||
self.returncode = rc
|
||||
self.stdout = out
|
||||
self.stderr = err
|
||||
|
||||
def fake_run(cmd, check, text, capture_output):
|
||||
return P(
|
||||
0,
|
||||
"\n".join(
|
||||
[
|
||||
"FragmentPath=",
|
||||
"DropInPaths=",
|
||||
"EnvironmentFiles=",
|
||||
"ExecStart=",
|
||||
"ActiveState=",
|
||||
"SubState=",
|
||||
"UnitFileState=",
|
||||
"ConditionResult=",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s.subprocess, "run", fake_run)
|
||||
ui = s.get_unit_info("empty.service")
|
||||
assert ui.fragment_path is None
|
||||
assert ui.dropin_paths == []
|
||||
assert ui.env_files == []
|
||||
assert ui.exec_paths == []
|
||||
assert ui.active_state is None
|
||||
|
||||
|
||||
def test_run_command_raises_on_error(monkeypatch):
|
||||
"""Test _run raises RuntimeError on non-zero exit."""
|
||||
|
||||
class P:
|
||||
returncode = 1
|
||||
stdout = ""
|
||||
stderr = "command failed"
|
||||
|
||||
def fake_run(cmd, check, text, capture_output):
|
||||
return P()
|
||||
|
||||
monkeypatch.setattr(s.subprocess, "run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
s._run(["fake", "command"])
|
||||
|
||||
assert "Command failed" in str(exc_info.value)
|
||||
assert "fake" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_list_enabled_services_filters_non_service_units(monkeypatch):
|
||||
"""Test that non-.service units are filtered out."""
|
||||
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
"nginx.service enabled",
|
||||
"network.target enabled", # not a service
|
||||
"multi-user.target enabled", # not a service
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_services()
|
||||
assert result == ["nginx.service"]
|
||||
|
||||
|
||||
def test_list_enabled_timers_filters_non_timer_units(monkeypatch):
|
||||
"""Test that non-.timer units are filtered out."""
|
||||
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
"apt-daily.timer enabled",
|
||||
"some.service enabled", # not a timer
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_timers()
|
||||
assert result == ["apt-daily.timer"]
|
||||
|
||||
|
||||
def test_list_enabled_services_filters_empty_lines(monkeypatch):
|
||||
"""Test that empty lines are skipped."""
|
||||
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
"nginx.service enabled",
|
||||
"", # empty line
|
||||
"ssh.service enabled",
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_services()
|
||||
assert result == ["nginx.service", "ssh.service"]
|
||||
|
||||
|
||||
def test_list_enabled_timers_filters_empty_lines(monkeypatch):
|
||||
"""Test that empty lines are skipped."""
|
||||
|
||||
def fake_run(cmd: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
"apt-daily.timer enabled",
|
||||
"", # empty line
|
||||
"daily.timer enabled",
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(s, "_run", fake_run)
|
||||
result = s.list_enabled_timers()
|
||||
assert result == ["apt-daily.timer", "daily.timer"]
|
||||
|
|
|
|||
|
|
@ -180,3 +180,234 @@ def test_cli_validate_exits_1_on_validation_warning_with_flag(
|
|||
with pytest.raises(SystemExit) as e:
|
||||
cli.main()
|
||||
assert e.value.code == 1
|
||||
|
||||
|
||||
def test_validation_result_ok():
|
||||
from enroll.validate import ValidationResult
|
||||
|
||||
result = ValidationResult(errors=[], warnings=[])
|
||||
assert result.ok is True
|
||||
assert result.to_text() == "OK: harvest bundle validated\n"
|
||||
|
||||
|
||||
def test_validation_result_with_errors():
|
||||
from enroll.validate import ValidationResult
|
||||
|
||||
result = ValidationResult(errors=["error1", "error2"], warnings=[])
|
||||
assert result.ok is False
|
||||
text = result.to_text()
|
||||
assert "ERROR: 2 validation error(s)" in text
|
||||
assert "error1" in text
|
||||
assert "error2" in text
|
||||
|
||||
|
||||
def test_validation_result_with_warnings():
|
||||
from enroll.validate import ValidationResult
|
||||
|
||||
result = ValidationResult(errors=[], warnings=["warn1"])
|
||||
assert result.ok is True
|
||||
text = result.to_text()
|
||||
assert "WARN: 1 warning(s)" in text
|
||||
assert "warn1" in text
|
||||
|
||||
|
||||
def test_validation_result_to_dict():
|
||||
from enroll.validate import ValidationResult
|
||||
|
||||
result = ValidationResult(errors=["e1"], warnings=["w1"])
|
||||
d = result.to_dict()
|
||||
assert d["ok"] is False
|
||||
assert d["errors"] == ["e1"]
|
||||
assert d["warnings"] == ["w1"]
|
||||
|
||||
|
||||
def test_iter_managed_files_singleton_roles():
|
||||
from enroll.validate import _iter_managed_files
|
||||
|
||||
state = {
|
||||
"roles": {
|
||||
"users": {"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]},
|
||||
"packages": [
|
||||
{
|
||||
"role_name": "vim",
|
||||
"managed_files": [{"path": "/usr/bin/vim", "src_rel": "vim"}],
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
files = _iter_managed_files(state)
|
||||
assert len(files) == 2
|
||||
assert ("users", {"path": "/etc/passwd", "src_rel": "passwd"}) in files
|
||||
|
||||
|
||||
def test_iter_managed_files_services_role():
|
||||
from enroll.validate import _iter_managed_files
|
||||
|
||||
state = {
|
||||
"roles": {
|
||||
"services": [
|
||||
{
|
||||
"role_name": "nginx",
|
||||
"managed_files": [
|
||||
{"path": "/etc/nginx/nginx.conf", "src_rel": "nginx.conf"}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
files = _iter_managed_files(state)
|
||||
assert len(files) == 1
|
||||
assert files[0][0] == "nginx"
|
||||
|
||||
|
||||
def test_iter_managed_files_handles_non_dict_items():
|
||||
from enroll.validate import _iter_managed_files
|
||||
|
||||
state = {
|
||||
"roles": {
|
||||
"users": {
|
||||
"managed_files": [
|
||||
"not_a_dict",
|
||||
{"path": "/etc/passwd", "src_rel": "passwd"},
|
||||
]
|
||||
},
|
||||
"services": ["not_a_dict", {"role_name": "nginx", "managed_files": []}],
|
||||
"packages": ["not_a_dict"],
|
||||
}
|
||||
}
|
||||
files = _iter_managed_files(state)
|
||||
assert len(files) == 1
|
||||
|
||||
|
||||
def test_iter_managed_files_empty_state():
|
||||
from enroll.validate import _iter_managed_files
|
||||
|
||||
state = {"roles": {}}
|
||||
files = _iter_managed_files(state)
|
||||
assert files == []
|
||||
|
||||
|
||||
def test_validate_harvest_missing_state_json(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("missing state.json" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_invalid_json(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text("not valid json", encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("failed to parse" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_schema_error(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text("{}", encoding="utf-8")
|
||||
result = validate_harvest(
|
||||
str(bundle_dir), schema="https://invalid.invalid/schema.json"
|
||||
)
|
||||
assert result.ok is False
|
||||
assert any("failed to load/validate schema" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_missing_artifact(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
artifacts_dir = bundle_dir / "artifacts"
|
||||
artifacts_dir.mkdir()
|
||||
state = {
|
||||
"roles": {
|
||||
"users": {"managed_files": [{"path": "/etc/passwd", "src_rel": "passwd"}]}
|
||||
}
|
||||
}
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text(json.dumps(state), encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("missing artifact" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_suspicious_src_rel(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state = {
|
||||
"roles": {
|
||||
"users": {
|
||||
"managed_files": [
|
||||
{"path": "/etc/passwd", "src_rel": "../../../etc/passwd"}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text(json.dumps(state), encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("suspicious src_rel" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_missing_src_rel(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state = {"roles": {"users": {"managed_files": [{"path": "/etc/passwd"}]}}}
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text(json.dumps(state), encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("missing src_rel" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_firewall_runtime_missing(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
artifacts_dir = bundle_dir / "artifacts"
|
||||
fw_dir = artifacts_dir / "firewall_runtime"
|
||||
fw_dir.mkdir(parents=True)
|
||||
state = {
|
||||
"roles": {
|
||||
"firewall_runtime": {
|
||||
"role_name": "firewall_runtime",
|
||||
"iptables_v4_save": "iptables.save",
|
||||
}
|
||||
}
|
||||
}
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text(json.dumps(state), encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("missing firewall runtime artifact" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_firewall_runtime_suspicious(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state = {
|
||||
"roles": {
|
||||
"firewall_runtime": {
|
||||
"role_name": "firewall_runtime",
|
||||
"iptables_v4_save": "../../../etc/passwd",
|
||||
}
|
||||
}
|
||||
}
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text(json.dumps(state), encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir))
|
||||
assert result.ok is False
|
||||
assert any("suspicious src_rel" in e for e in result.errors)
|
||||
|
||||
|
||||
def test_validate_harvest_no_schema_option(tmp_path: Path):
|
||||
bundle_dir = tmp_path / "bundle"
|
||||
bundle_dir.mkdir()
|
||||
state_file = bundle_dir / "state.json"
|
||||
state_file.write_text("invalid json", encoding="utf-8")
|
||||
result = validate_harvest(str(bundle_dir), no_schema=True)
|
||||
assert result.ok is False
|
||||
assert any("failed to parse" in e for e in result.errors)
|
||||
|
|
|
|||
|
|
@ -1,36 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
# The version module is hard to test fully because it uses importlib.metadata
|
||||
# which is difficult to mock. We'll test what we can.
|
||||
|
||||
|
||||
def test_get_enroll_version_returns_unknown_when_import_fails(monkeypatch):
|
||||
def test_get_enroll_version_returns_string():
|
||||
from enroll.version import get_enroll_version
|
||||
|
||||
# Ensure both the module cache and the parent package attribute are redirected.
|
||||
import importlib
|
||||
|
||||
dummy = types.ModuleType("importlib.metadata")
|
||||
# Missing attributes will cause ImportError when importing names.
|
||||
monkeypatch.setitem(sys.modules, "importlib.metadata", dummy)
|
||||
monkeypatch.setattr(importlib, "metadata", dummy, raising=False)
|
||||
|
||||
assert get_enroll_version() == "unknown"
|
||||
|
||||
|
||||
def test_get_enroll_version_uses_packages_distributions(monkeypatch):
|
||||
# Restore the real module for this test.
|
||||
monkeypatch.delitem(sys.modules, "importlib.metadata", raising=False)
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from enroll.version import get_enroll_version
|
||||
|
||||
monkeypatch.setattr(
|
||||
importlib.metadata,
|
||||
"packages_distributions",
|
||||
lambda: {"enroll": ["enroll-dist"]},
|
||||
)
|
||||
monkeypatch.setattr(importlib.metadata, "version", lambda dist: "9.9.9")
|
||||
|
||||
assert get_enroll_version() == "9.9.9"
|
||||
result = get_enroll_version()
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue