diff --git a/enroll/remote.py b/enroll/remote.py index 0a71d4c..ecb1c27 100644 --- a/enroll/remote.py +++ b/enroll/remote.py @@ -578,11 +578,14 @@ def _remote_harvest( sftp = ssh.open_sftp() rtmp: Optional[str] = None + remote_root_tmp: Optional[str] = None try: rc, out, err = _ssh_run(ssh, "mktemp -d") if rc != 0: raise RuntimeError(f"Remote mktemp failed: {err.strip()}") rtmp = out.strip() + if not rtmp: + raise RuntimeError("Remote mktemp returned an empty path") # Be explicit: restrict the remote staging area to the current user. rc, out, err = _ssh_run(ssh, f"chmod 700 -- {shlex.quote(rtmp)}") @@ -590,10 +593,35 @@ def _remote_harvest( raise RuntimeError(f"Remote chmod failed: {err.strip()}") rapp = f"{rtmp}/enroll.pyz" - rbundle = f"{rtmp}/bundle" - sftp.put(str(pyz), rapp) + if not no_sudo: + # The remote zipapp is staged as the SSH user, but the harvest + # itself runs as root. Root must not write its bundle under the + # SSH user's mktemp directory: the root-output safety checks + # deliberately reject user-owned parents to avoid symlink/race + # issues. Create a separate sudo-owned tempdir for the bundle. + rc, out, err = _ssh_run_sudo( + ssh, "mktemp -d", sudo_password=sudo_password, get_pty=True + ) + if rc != 0: + raise RuntimeError(f"Remote sudo mktemp failed: {err.strip()}") + remote_root_tmp = out.strip() + if not remote_root_tmp: + raise RuntimeError("Remote sudo mktemp returned an empty path") + + rc, out, err = _ssh_run_sudo( + ssh, + f"chmod 700 -- {shlex.quote(remote_root_tmp)}", + sudo_password=sudo_password, + get_pty=True, + ) + if rc != 0: + raise RuntimeError(f"Remote sudo chmod failed: {err.strip()}") + rbundle = f"{remote_root_tmp}/bundle" + else: + rbundle = f"{rtmp}/bundle" + # Run remote harvest. argv: list[str] = [ remote_python, @@ -635,9 +663,10 @@ def _remote_harvest( "Unable to determine remote username for chown. " "Pass --remote-user explicitly or use --no-sudo." ) + chown_target = remote_root_tmp or rbundle chown_cmd = ( "chown -R -- " - f"{shlex.quote(resolved_user)} {shlex.quote(rbundle)}" + f"{shlex.quote(resolved_user)} {shlex.quote(chown_target)}" ) rc, out, err = _ssh_run_sudo( ssh, @@ -678,7 +707,19 @@ def _remote_harvest( _safe_extract_tar(tf, local_out_dir) finally: - # Cleanup remote tmpdir even on failure. + # Cleanup remote tmpdirs even on failure. The sudo-owned harvest + # tempdir may still be root-owned if harvest/chown failed, so remove + # it via sudo and avoid masking the original error if cleanup fails. + if remote_root_tmp: + try: + _ssh_run_sudo( + ssh, + f"rm -rf -- {shlex.quote(remote_root_tmp)}", + sudo_password=sudo_password, + get_pty=True, + ) + except Exception: + pass # nosec - best-effort remote cleanup if rtmp: _ssh_run(ssh, f"rm -rf -- {shlex.quote(rtmp)}") try: diff --git a/tests/test_harvest_collectors.py b/tests/test_harvest_collectors.py index 80f43e4..94b5259 100644 --- a/tests/test_harvest_collectors.py +++ b/tests/test_harvest_collectors.py @@ -1,8 +1,11 @@ from __future__ import annotations +from pathlib import Path + from enroll.harvest_collectors.context import HarvestContext +from enroll.harvest_collectors.paths import ExtraPathsCollector, UsrLocalCustomCollector from enroll.harvest_collectors.runtime import RuntimeStateCollector -from enroll.harvest_types import FirewallRuntimeSnapshot, SysctlSnapshot +from enroll.harvest_types import FirewallRuntimeSnapshot, ManagedFile, SysctlSnapshot from enroll.ignore import IgnorePolicy from enroll.pathfilter import PathFilter @@ -11,11 +14,11 @@ class _Backend: name = "dpkg" -def _context(tmp_path): +def _context(tmp_path: Path, *, include=(), exclude=(), policy=None) -> HarvestContext: return HarvestContext( - bundle_dir=str(tmp_path), - policy=IgnorePolicy(), - path_filter=PathFilter(include=(), exclude=()), + bundle_dir=str(tmp_path / "bundle"), + policy=policy or IgnorePolicy(), + path_filter=PathFilter(include=include, exclude=exclude), platform={}, backend=_Backend(), installed_pkgs={}, @@ -245,3 +248,149 @@ def test_container_images_collector_notes_unexpected_inspect_shape( assert result.images == [] assert "Unexpected docker image inspect JSON shape" in result.notes[0] + + +def test_extra_paths_collector_records_dirs_files_notes_and_excludes( + monkeypatch, tmp_path +): + from enroll.harvest_collectors import paths + + root = tmp_path / "include" + sub = root / "sub" + skip = root / "skip" + sub.mkdir(parents=True) + skip.mkdir() + keep_file = sub / "keep.conf" + keep_file.write_text("ok", encoding="utf-8") + skip_file = skip / "skip.conf" + skip_file.write_text("no", encoding="utf-8") + + class Policy(IgnorePolicy): + def deny_reason_dir(self, path: str): + return "denied_dir" if path == str(sub) else None + + def fake_stat_triplet(path: str): + return ("root", "root", "0755") + + def fake_capture_file(**kwargs): + kwargs["managed_out"].append( + ManagedFile( + path=kwargs["abs_path"], + src_rel=kwargs["abs_path"].lstrip("/"), + owner="root", + group="root", + mode="0644", + reason=kwargs["reason"], + ) + ) + return True + + monkeypatch.setattr(paths.h, "stat_triplet", fake_stat_triplet) + monkeypatch.setattr(paths, "capture_file", lambda *a, **kw: fake_capture_file(**kw)) + + ctx = _context( + tmp_path, + include=[str(root)], + exclude=[str(skip)], + policy=Policy(), + ) + result = ExtraPathsCollector( + ctx, + seen_by_role={}, + already_all=set(), + include_paths=[str(root)], + exclude_paths=[str(skip)], + ).collect() + + managed_dirs = {d.path for d in result.managed_dirs} + assert str(root) in managed_dirs + assert str(sub) not in managed_dirs # denied by policy + assert str(skip) not in managed_dirs # pruned by exclude filter + assert [m.path for m in result.managed_files] == [str(keep_file)] + assert "User include patterns:" in result.notes + assert f"- {root}" in result.notes + assert f"- {skip}" in result.notes + + +def test_extra_paths_collector_skips_already_captured_files(monkeypatch, tmp_path): + from enroll.harvest_collectors import paths + + root = tmp_path / "include" + root.mkdir() + file_path = root / "keep.conf" + file_path.write_text("ok", encoding="utf-8") + calls: list[str] = [] + + monkeypatch.setattr(paths.h, "stat_triplet", lambda p: ("root", "root", "0755")) + monkeypatch.setattr( + paths, "capture_file", lambda *a, **kw: calls.append(kw["abs_path"]) or True + ) + + ctx = _context(tmp_path, include=[str(root)]) + result = ExtraPathsCollector( + ctx, + seen_by_role={}, + already_all={str(file_path)}, + include_paths=[str(root)], + ).collect() + + assert result.managed_files == [] + assert calls == [] + + +def test_usr_local_custom_collector_scans_executable_bin_and_notes_cap( + monkeypatch, tmp_path +): + from enroll.harvest_collectors import paths + + captured: list[str] = [] + + def fake_isdir(path: str) -> bool: + return path in {"/usr/local/etc", "/usr/local/bin"} + + def fake_walk(root: str): + if root == "/usr/local/etc": + yield root, [], ["app.conf"] + elif root == "/usr/local/bin": + yield root, [], ["tool", "not-exec"] + + def fake_isfile(path: str) -> bool: + return path in { + "/usr/local/etc/app.conf", + "/usr/local/bin/tool", + "/usr/local/bin/not-exec", + } + + def fake_stat_triplet(path: str): + mode = "0755" if path == "/usr/local/bin/tool" else "0644" + return ("root", "root", mode) + + def fake_capture_file(**kwargs): + captured.append(kwargs["abs_path"]) + kwargs["managed_out"].append( + ManagedFile( + path=kwargs["abs_path"], + src_rel=kwargs["abs_path"].lstrip("/"), + owner="root", + group="root", + mode="0644", + reason=kwargs["reason"], + ) + ) + return True + + monkeypatch.setattr(paths.os.path, "isdir", fake_isdir) + monkeypatch.setattr(paths.os, "walk", fake_walk) + monkeypatch.setattr(paths.os.path, "isfile", fake_isfile) + monkeypatch.setattr(paths.os.path, "islink", lambda p: False) + monkeypatch.setattr(paths.h, "stat_triplet", fake_stat_triplet) + monkeypatch.setattr(paths, "capture_file", lambda *a, **kw: fake_capture_file(**kw)) + + ctx = _context(tmp_path) + result = UsrLocalCustomCollector(ctx, seen_by_role={}, already_all=set()).collect() + + assert captured == ["/usr/local/etc/app.conf", "/usr/local/bin/tool"] + assert [m.reason for m in result.managed_files] == [ + "usr_local_etc_custom", + "usr_local_bin_script", + ] diff --git a/tests/test_harvest_collectors_package_manager.py b/tests/test_harvest_collectors_package_manager.py new file mode 100644 index 0000000..ca805dd --- /dev/null +++ b/tests/test_harvest_collectors_package_manager.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import Path + +from enroll.harvest_collectors.context import HarvestContext +from enroll.harvest_collectors.package_manager import PackageManagerConfigCollector +from enroll.harvest_types import ManagedFile +from enroll.ignore import IgnorePolicy +from enroll.pathfilter import PathFilter + + +class _Backend: + def __init__(self, name: str): + self.name = name + + +def _context(tmp_path: Path, backend_name: str) -> HarvestContext: + return HarvestContext( + bundle_dir=str(tmp_path / "bundle"), + policy=IgnorePolicy(), + path_filter=PathFilter(include=(), exclude=()), + platform={}, + backend=_Backend(backend_name), + installed_pkgs={}, + installed_names=set(), + owned_etc=set(), + etc_owner_map={}, + topdir_to_pkgs={}, + pkg_to_etc_paths={}, + captured_global=set(), + ) + + +def _fake_capture(**kwargs): + kwargs["managed_out"].append( + ManagedFile( + path=kwargs["abs_path"], + src_rel=kwargs["abs_path"].lstrip("/"), + owner="root", + group="root", + mode="0644", + reason=kwargs["reason"], + ) + ) + return True + + +def test_package_manager_config_collector_captures_apt_branch(monkeypatch, tmp_path): + from enroll.harvest_collectors import package_manager as pm + + monkeypatch.setattr( + pm, "iter_apt_capture_paths", lambda: [("/etc/apt/a.conf", "apt")] + ) + monkeypatch.setattr(pm, "capture_file", lambda *a, **kw: _fake_capture(**kw)) + + result = PackageManagerConfigCollector(_context(tmp_path, "dpkg"), {}).collect() + + assert [m.path for m in result.apt_config_snapshot.managed_files] == [ + "/etc/apt/a.conf" + ] + assert result.dnf_config_snapshot.managed_files == [] + + +def test_package_manager_config_collector_captures_dnf_branch(monkeypatch, tmp_path): + from enroll.harvest_collectors import package_manager as pm + + monkeypatch.setattr( + pm, "iter_dnf_capture_paths", lambda: [("/etc/dnf/d.conf", "dnf")] + ) + monkeypatch.setattr(pm, "capture_file", lambda *a, **kw: _fake_capture(**kw)) + + result = PackageManagerConfigCollector(_context(tmp_path, "rpm"), {}).collect() + + assert result.apt_config_snapshot.managed_files == [] + assert [m.path for m in result.dnf_config_snapshot.managed_files] == [ + "/etc/dnf/d.conf" + ] + + +def test_package_manager_config_collector_unknown_backend_returns_empty(tmp_path): + result = PackageManagerConfigCollector(_context(tmp_path, "apk"), {}).collect() + + assert result.apt_config_snapshot.managed_files == [] + assert result.dnf_config_snapshot.managed_files == [] diff --git a/tests/test_harvest_safety.py b/tests/test_harvest_safety.py index 5084788..6c5b338 100644 --- a/tests/test_harvest_safety.py +++ b/tests/test_harvest_safety.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import stat from pathlib import Path import pytest @@ -13,6 +14,8 @@ from enroll.manifest_safety import prepare_manifest_output_dir from enroll.harvest_safety import OutputSafetyError, prepare_new_private_dir from enroll.pathfilter import PathFilter +import enroll.harvest_safety as hs + class _RacePolicy(IgnorePolicy): def inspect_file(self, path: str): @@ -192,3 +195,90 @@ def test_safe_output_parent_does_not_descend_into_raced_symlink( hs.ensure_safe_output_parent(link / "subdir" / "report.txt", label="report") assert not (target / "subdir").exists() + + +def _stat_result(mode: int, *, uid: int = 0) -> os.stat_result: + return os.stat_result((mode, 1, 1, 1, uid, 0, 0, 0, 0, 0)) + + +def test_effective_uid_handles_missing_geteuid(monkeypatch): + monkeypatch.setattr(hs, "_OS_GETEUID", None) + assert hs._effective_uid() is None + + +def test_effective_uid_handles_geteuid_error(monkeypatch): + def boom(): + raise OSError("no euid") + + monkeypatch.setattr(hs, "_OS_GETEUID", boom) + assert hs._effective_uid() is None + + +def test_trusted_root_parent_skips_checks_when_not_root(monkeypatch): + monkeypatch.setattr(hs, "_effective_uid", lambda: 1000) + hs._assert_trusted_root_parent( + Path("not-a-dir"), _stat_result(stat.S_IFREG | 0o644, uid=1234), label="x" + ) + + +def test_trusted_root_parent_rejects_non_directory(monkeypatch): + monkeypatch.setattr(hs, "_effective_uid", lambda: 0) + with pytest.raises(OutputSafetyError, match="parent is not a directory"): + hs._assert_trusted_root_parent( + Path("file"), _stat_result(stat.S_IFREG | 0o644), label="x" + ) + + +def test_trusted_root_parent_rejects_group_or_world_writable(monkeypatch): + monkeypatch.setattr(hs, "_effective_uid", lambda: 0) + with pytest.raises(OutputSafetyError, match="writable by group/other"): + hs._assert_trusted_root_parent( + Path("open-dir"), _stat_result(stat.S_IFDIR | 0o777), label="x" + ) + + +def test_trusted_root_parent_allows_root_owned_sticky_shared_dir(monkeypatch): + monkeypatch.setattr(hs, "_effective_uid", lambda: 0) + hs._assert_trusted_root_parent( + Path("tmp"), _stat_result(stat.S_IFDIR | stat.S_ISVTX | 0o777), label="x" + ) + + +def test_assert_no_existing_symlink_components_without_root_trust_still_rejects_symlink( + tmp_path: Path, +): + real = tmp_path / "real" + real.mkdir() + link = tmp_path / "link" + link.symlink_to(real, target_is_directory=True) + + with pytest.raises(OutputSafetyError, match="parent path contains a symlink"): + hs._assert_no_existing_symlink_components( + link / "leaf", label="x", require_trusted_root_parents=False + ) + + +def test_ensure_private_empty_dir_rejects_bad_existing_paths(tmp_path: Path): + file_path = tmp_path / "file" + file_path.write_text("x", encoding="utf-8") + with pytest.raises(OutputSafetyError, match="not a directory"): + hs.ensure_private_empty_dir(file_path, label="cache") + + nonempty = tmp_path / "nonempty" + nonempty.mkdir() + (nonempty / "child").write_text("x", encoding="utf-8") + with pytest.raises(OutputSafetyError, match="not empty"): + hs.ensure_private_empty_dir(nonempty, label="cache") + + real = tmp_path / "real" + real.mkdir() + link = tmp_path / "link" + link.symlink_to(real, target_is_directory=True) + with pytest.raises(OutputSafetyError, match="symlink"): + hs.ensure_private_empty_dir(link, label="cache") + + +def test_ensure_private_empty_dir_creates_private_dir(tmp_path: Path): + out = hs.ensure_private_empty_dir(tmp_path / "new-cache", label="cache") + assert out.is_dir() + assert (out.stat().st_mode & 0o777) == 0o700 diff --git a/tests/test_manifest_safety.py b/tests/test_manifest_safety.py new file mode 100644 index 0000000..deb764d --- /dev/null +++ b/tests/test_manifest_safety.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from enroll.manifest_safety import ( + ArtifactSafetyError, + ManifestOutputError, + copy_safe_artifact_file, + iter_safe_artifact_files, + prepare_manifest_output_dir, + safe_artifact_file, + validate_site_fqdn, +) + + +def test_validate_site_fqdn_accepts_and_normalises_simple_values(): + assert validate_site_fqdn(None) is None + assert validate_site_fqdn(" ") is None + assert validate_site_fqdn("host_1.example") == "host_1.example" + + +@pytest.mark.parametrize( + "value", ["../host", "host/name", "host\\name", "host\nname", "-bad", ".", ".."] +) +def test_validate_site_fqdn_rejects_path_or_inventory_injection(value: str): + with pytest.raises(ManifestOutputError): + validate_site_fqdn(value) + + +def test_prepare_manifest_output_dir_allows_existing_clean_tree_in_site_mode( + tmp_path: Path, +): + out = tmp_path / "site" + out.mkdir() + (out / ".git").mkdir() + (out / ".git" / "ignored-link").symlink_to(tmp_path, target_is_directory=True) + + assert prepare_manifest_output_dir(out, allow_existing=True) == out + + +def test_prepare_manifest_output_dir_rejects_existing_tree_symlink(tmp_path: Path): + out = tmp_path / "site" + out.mkdir() + (out / "bad-link").symlink_to(tmp_path, target_is_directory=True) + + with pytest.raises(ManifestOutputError, match="contains a symlink"): + prepare_manifest_output_dir(out, allow_existing=True) + + +def test_safe_artifact_file_accepts_regular_file_and_copy(tmp_path: Path): + bundle = tmp_path / "bundle" + artifact = bundle / "artifacts" / "role" / "etc" / "app.conf" + artifact.parent.mkdir(parents=True) + artifact.write_text("managed=true\n", encoding="utf-8") + + assert safe_artifact_file(bundle, "role", "etc/app.conf") == artifact + + dst = tmp_path / "copy.conf" + copy_safe_artifact_file(artifact, dst) + assert dst.read_text(encoding="utf-8") == "managed=true\n" + + +def test_safe_artifact_file_rejects_unsafe_role_and_src(tmp_path: Path): + bundle = tmp_path / "bundle" + with pytest.raises(ArtifactSafetyError, match="must be relative"): + safe_artifact_file(bundle, "/role", "file") + with pytest.raises(ArtifactSafetyError, match="unsafe path component"): + safe_artifact_file(bundle, "role", "../file") + with pytest.raises(ArtifactSafetyError, match="NUL"): + safe_artifact_file(bundle, "role", "bad\x00file") + + +def test_safe_artifact_file_rejects_artifacts_symlink(tmp_path: Path): + bundle = tmp_path / "bundle" + bundle.mkdir() + (bundle / "artifacts").symlink_to(tmp_path, target_is_directory=True) + + with pytest.raises(ArtifactSafetyError, match="artifacts directory is a symlink"): + safe_artifact_file(bundle, "role", "file") + + +def test_safe_artifact_file_rejects_bad_artifact_kinds(tmp_path: Path): + bundle = tmp_path / "bundle" + role_dir = bundle / "artifacts" / "role" + role_dir.mkdir(parents=True) + + target = role_dir / "target" + target.write_text("x", encoding="utf-8") + (role_dir / "link").symlink_to(target) + with pytest.raises(ArtifactSafetyError, match="symlink"): + safe_artifact_file(bundle, "role", "link") + + (role_dir / "dir-artifact").mkdir() + with pytest.raises(ArtifactSafetyError, match="not a regular file"): + safe_artifact_file(bundle, "role", "dir-artifact") + + hardlink = role_dir / "hardlink" + os.link(target, hardlink) + with pytest.raises(ArtifactSafetyError, match="hardlinked"): + safe_artifact_file(bundle, "role", "target") + + +def test_iter_safe_artifact_files_handles_missing_and_bad_role_dirs(tmp_path: Path): + bundle = tmp_path / "bundle" + assert list(iter_safe_artifact_files(bundle, "missing")) == [] + + role_file = bundle / "artifacts" / "role" + role_file.parent.mkdir(parents=True) + role_file.write_text("not a dir", encoding="utf-8") + with pytest.raises(ArtifactSafetyError, match="not a directory"): + list(iter_safe_artifact_files(bundle, "role")) + + +def test_iter_safe_artifact_files_rejects_symlink_subdir(tmp_path: Path): + bundle = tmp_path / "bundle" + role_dir = bundle / "artifacts" / "role" + role_dir.mkdir(parents=True) + real = tmp_path / "real" + real.mkdir() + (role_dir / "linkdir").symlink_to(real, target_is_directory=True) + + with pytest.raises(ArtifactSafetyError, match="directory is a symlink"): + list(iter_safe_artifact_files(bundle, "role")) diff --git a/tests/test_package_hints.py b/tests/test_package_hints.py new file mode 100644 index 0000000..76dd835 --- /dev/null +++ b/tests/test_package_hints.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from enroll.package_hints import ( + add_pkgs_from_etc_topdirs, + hint_names, + maybe_add_specific_paths, + package_section_from_installations, + role_id, + role_name_from_pkg, + role_name_from_unit, + safe_name, +) + + +class _Backend: + def __init__(self, *, fail: bool = False): + self.fail = fail + + def specific_paths_for_hints(self, hints): + if self.fail: + raise RuntimeError("backend unavailable") + return [f"/backend/{h}" for h in sorted(hints)] + + +def test_role_name_helpers_sanitise_reserved_and_odd_names(): + assert safe_name("pkg.name+with-dash") == "pkg_name_with_dash" + assert role_id("123 Camel-Case!!") == "r_123_camel_case" + assert role_name_from_unit("class.service") == "class" + assert role_name_from_pkg("flatpak") == "package_flatpak" + + +def test_package_section_from_installations_filters_empty_and_unspecified(): + assert package_section_from_installations([]) is None + assert ( + package_section_from_installations([{"section": "none"}, {"group": ""}]) is None + ) + assert ( + package_section_from_installations([{"section": "z-utils"}, {"group": "admin"}]) + == "admin" + ) + + +def test_hint_names_expands_templates_packages_and_dot_prefixes(): + assert hint_names("postgresql@14-main.service", {"postgresql.14"}) == { + "postgresql@14-main", + "postgresql", + "postgresql.14", + } + + +def test_add_pkgs_from_etc_topdirs_skips_shared_dirs(): + pkgs: set[str] = set() + add_pkgs_from_etc_topdirs( + {"ssh", "nginx"}, + {"ssh": {"openssh-server"}, "nginx": {"nginx"}, "nginx.d": {"nginx-extra"}}, + pkgs, + ) + assert pkgs == {"nginx", "nginx-extra"} + + +def test_maybe_add_specific_paths_uses_backend_and_fallback(): + assert maybe_add_specific_paths({"a", "b"}, _Backend()) == [ + "/backend/a", + "/backend/b", + ] + assert maybe_add_specific_paths({"svc"}, _Backend(fail=True)) == [ + "/etc/default/svc", + "/etc/init.d/svc", + "/etc/sysctl.d/svc.conf", + ] diff --git a/tests/test_remote.py b/tests/test_remote.py index 1c0bfd0..fe54d33 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -166,6 +166,13 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): return (None, _Stdout(b"alice\n"), _Stderr()) if cmd == "mktemp -d": return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr()) + if cmd.startswith("sudo -n") and " mktemp -d" in cmd: + return (None, _Stdout(b"/tmp/enroll-root-123\n"), _Stderr()) + if ( + cmd.startswith("sudo -n") + and " chmod 700 -- /tmp/enroll-root-123" in cmd + ): + return (None, _Stdout(b""), _Stderr()) if cmd.startswith("chmod 700"): return (None, _Stdout(b""), _Stderr()) if cmd.startswith("sudo -n") and " harvest " in cmd: @@ -182,6 +189,8 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): msg = b"sudo: sorry, you must have a tty to run sudo\n" return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg)) return (None, _Stdout(b"", rc=0), _Stderr(b"")) + if cmd.startswith("sudo -n") and " rm -rf -- /tmp/enroll-root-123" in cmd: + return (None, _Stdout(b""), _Stderr()) if cmd.startswith("rm -rf"): return (None, _Stdout(b""), _Stderr()) @@ -223,6 +232,11 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): assert "--dangerous" in joined assert "--include-path" in joined assert "--exclude-path" in joined + assert "sudo -n -p '' -- mktemp -d" in joined + assert "--out /tmp/enroll-root-123/bundle" in joined + assert "--out /tmp/enroll-remote-123/bundle" not in joined + assert "chown -R -- alice /tmp/enroll-root-123" in joined + assert "tar -cz -C /tmp/enroll-root-123/bundle ." in joined # Ensure we fall back to PTY only when sudo reports it is required. assert any(c == "id -un" and pty is False for c, pty in calls) @@ -508,6 +522,13 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password( if cmd == "mktemp -d": return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr()) + if cmd.startswith("sudo -n") and " mktemp -d" in cmd: + return (_Stdin(cmd), _Stdout(b"/tmp/enroll-root-789\n"), _Stderr()) + if ( + cmd.startswith("sudo -n") + and " chmod 700 -- /tmp/enroll-root-789" in cmd + ): + return (_Stdin(cmd), _Stdout(b""), _Stderr()) if cmd.startswith("chmod 700"): return (_Stdin(cmd), _Stdout(b""), _Stderr()) @@ -527,6 +548,8 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password( if cmd.startswith("sudo -n") and " chown -R" in cmd: return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) + if cmd.startswith("sudo -n") and " rm -rf -- /tmp/enroll-root-789" in cmd: + return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) if cmd.startswith("rm -rf"): return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) @@ -563,6 +586,10 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password( sudo_s = [c for c, _pty in calls if c.startswith("sudo -S") and " harvest " in c] assert len(sudo_n) == 1 assert len(sudo_s) == 1 + joined = "\n".join([c for c, _pty in calls]) + assert "sudo -n -p '' -- mktemp -d" in joined + assert "--out /tmp/enroll-root-789/bundle" in joined + assert "--out /tmp/enroll-remote-789/bundle" not in joined # Ensure the password was written to stdin for the -S invocation. assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]