More test coverage (71%)
Some checks failed
Lint / test (push) Waiting to run
Trivy / test (push) Waiting to run
CI / test (push) Has been cancelled

This commit is contained in:
Miguel Jacq 2026-01-03 12:34:39 +11:00
parent 9a2516d858
commit f82fd894ca
Signed by: mig5
GPG key ID: 59B3F0C24135C6A9
8 changed files with 605 additions and 10 deletions

View file

@ -49,7 +49,7 @@ def test_safe_extract_tar_rejects_symlinks(tmp_path: Path):
_safe_extract_tar(tf, tmp_path)
def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeypatch):
def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
import sys
import enroll.remote as r
@ -65,6 +65,7 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
# Prepare a tiny harvest bundle tar stream from the "remote".
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
# Track each SSH exec_command call with whether a PTY was requested.
calls: list[tuple[str, bool]] = []
class _Chan:
@ -116,9 +117,8 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, get_pty: bool = False):
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
calls.append((cmd, bool(get_pty)))
# The tar stream uses exec_command directly.
if cmd.startswith("tar -cz -C"):
return (None, _Stdout(tgz, rc=0), _Stderr(b""))
@ -169,15 +169,122 @@ def test_remote_harvest_happy_path_requests_pty_for_sudo(tmp_path: Path, monkeyp
assert b"ok" in state_path.read_bytes()
# Ensure we attempted remote harvest with sudo and passed include/exclude and dangerous.
joined = "\n".join([c for c, _ in calls])
joined = "\n".join([c for c, _pty in calls])
assert "sudo" in joined
assert "--dangerous" in joined
assert "--include-path" in joined
assert "--exclude-path" in joined
# Assert PTY is requested for sudo commands (harvest & chown), not for tar streaming.
sudo_cmds = [(c, pty) for c, pty in calls if c.startswith("sudo ")]
assert sudo_cmds, "expected at least one sudo command"
assert all(pty for _, pty in sudo_cmds)
tar_cmds = [(c, pty) for c, pty in calls if c.startswith("tar -cz -C")]
assert tar_cmds and all(not pty for _, pty in tar_cmds)
# Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar.
pty_by_cmd = {c: pty for c, pty in calls}
assert pty_by_cmd.get("id -un") is False
assert any(
c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls
)
assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls)
assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls)
def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
tmp_path: Path, monkeypatch
):
"""When --no-sudo is used we should not request a PTY nor run sudo chown."""
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
tgz = _make_tgz_bytes({"state.json": b"{}"})
calls: list[tuple[str, bool]] = []
class _Chan:
def __init__(self, rc: int = 0):
self._rc = rc
def recv_exit_status(self) -> int:
return self._rc
class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0):
self._bio = io.BytesIO(payload)
self.channel = _Chan(rc)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stderr:
def __init__(self, payload: bytes = b""):
self._bio = io.BytesIO(payload)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
return
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
calls.append((cmd, bool(get_pty)))
if cmd == "mktemp -d":
return (None, _Stdout(b"/tmp/enroll-remote-456\n"), _Stderr())
if cmd.startswith("chmod 700"):
return (None, _Stdout(b""), _Stderr())
if cmd.startswith("tar -cz -C"):
return (None, _Stdout(tgz, rc=0), _Stderr())
if " harvest " in cmd:
return (None, _Stdout(b""), _Stderr())
if cmd.startswith("rm -rf"):
return (None, _Stdout(b""), _Stderr())
return (None, _Stdout(b""), _Stderr())
def close(self):
return
import types
class RejectPolicy:
pass
monkeypatch.setitem(
sys.modules,
"paramiko",
types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy),
)
out_dir = tmp_path / "out"
r.remote_harvest(
local_out_dir=out_dir,
remote_host="example.com",
remote_user="alice",
no_sudo=True,
)
joined = "\n".join([c for c, _pty in calls])
assert "sudo" not in joined
assert "sudo chown" not in joined
assert any(" harvest " in c and pty is False for c, pty in calls)