more test coverage

This commit is contained in:
Miguel Jacq 2026-05-31 16:50:57 +10:00
parent b25dd1e314
commit 1544dc0295
Signed by: mig5
GPG key ID: 03906B4110AAD3B8
15 changed files with 3150 additions and 424 deletions

View file

@ -565,3 +565,452 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
# Ensure the password was written to stdin for the -S invocation.
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]
def test_sudo_password_required_detection():
from enroll.remote import _sudo_password_required
assert _sudo_password_required("", "a password is required") is True
assert _sudo_password_required("", "password is required") is True
assert (
_sudo_password_required("", "a terminal is required to read the password")
is True
)
assert (
_sudo_password_required("", "no tty present and no askpass program specified")
is True
)
assert _sudo_password_required("", "must have a tty to run sudo") is True
assert _sudo_password_required("", "sudo: sorry, you must have a tty") is True
assert _sudo_password_required("", "askpass") is True
assert _sudo_password_required("success", "") is False
def test_sudo_not_permitted_detection():
from enroll.remote import _sudo_not_permitted
assert _sudo_not_permitted("", "user is not in the sudoers file") is True
assert _sudo_not_permitted("", "not allowed to execute") is True
assert _sudo_not_permitted("", "may not run sudo") is True
assert _sudo_not_permitted("", "sorry, user") is True
assert _sudo_not_permitted("success", "") is False
def test_sudo_tty_required_detection():
from enroll.remote import _sudo_tty_required
assert _sudo_tty_required("", "must have a tty") is True
assert _sudo_tty_required("", "sorry, you must have a tty") is True
assert _sudo_tty_required("", "sudo: sorry, you must have a tty") is True
assert _sudo_tty_required("", "must have a tty to run sudo") is True
assert _sudo_tty_required("success", "") is False
def test_resolve_become_password_prompts_when_asked(monkeypatch):
from enroll.remote import _resolve_become_password
prompted = []
def fake_getpass(prompt):
prompted.append(prompt)
return "secret"
result = _resolve_become_password(
True, prompt="sudo password: ", getpass_fn=fake_getpass
)
assert result == "secret"
assert len(prompted) == 1
def test_resolve_become_password_returns_none_when_not_asked():
from enroll.remote import _resolve_become_password
result = _resolve_become_password(False)
assert result is None
def test_resolve_ssh_key_passphrase_from_env(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
monkeypatch.setenv("SSH_KEY_PASS", "env_secret")
result = _resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
assert result == "env_secret"
def test_resolve_ssh_key_passphrase_raises_when_env_not_set(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
monkeypatch.delenv("SSH_KEY_PASS", raising=False)
with pytest.raises(RuntimeError, match="SSH key passphrase environment variable"):
_resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
def test_resolve_ssh_key_passphrase_prompts_when_asked(monkeypatch):
from enroll.remote import _resolve_ssh_key_passphrase
prompted = []
def fake_getpass(prompt):
prompted.append(prompt)
return "prompt_secret"
result = _resolve_ssh_key_passphrase(
True, prompt="SSH key passphrase: ", getpass_fn=fake_getpass
)
assert result == "prompt_secret"
assert len(prompted) == 1
def test_resolve_ssh_key_passphrase_returns_none_when_not_asked():
from enroll.remote import _resolve_ssh_key_passphrase
result = _resolve_ssh_key_passphrase(False, env_var=None)
assert result is None
def test_safe_extract_tar_rejects_absolute_paths(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="/etc/passwd")
ti.size = 1
tf.addfile(ti, io.BytesIO(b"x"))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Unsafe tar member path"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_rejects_hardlinks(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="hardlink")
ti.type = tarfile.LNKTYPE
ti.linkname = "/etc/passwd"
tf.addfile(ti)
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Refusing to extract"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_rejects_device_nodes(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="device")
ti.type = tarfile.CHRTYPE
tf.addfile(ti)
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
with pytest.raises(RuntimeError, match="Refusing to extract"):
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_accepts_dot_entry(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name=".")
ti.size = 0
tf.addfile(ti, io.BytesIO(b""))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
_safe_extract_tar(tf, tmp_path)
def test_safe_extract_tar_accepts_valid_files(tmp_path: Path):
from enroll.remote import _safe_extract_tar
import io
import tarfile
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
ti = tarfile.TarInfo(name="foo/bar.txt")
ti.size = 5
tf.addfile(ti, io.BytesIO(b"hello"))
bio.seek(0)
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"hello"
def test_remote_harvest_ssh_key_passphrase_retry(monkeypatch, tmp_path: Path):
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
class _Chan:
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
self._out = out
self._err = err
self._out_i = 0
self._err_i = 0
self._rc = rc
self._closed = False
def recv_ready(self) -> bool:
return (not self._closed) and self._out_i < len(self._out)
def recv(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._out[self._out_i : self._out_i + n]
self._out_i += len(chunk)
return chunk
def recv_stderr_ready(self) -> bool:
return (not self._closed) and self._err_i < len(self._err)
def recv_stderr(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._err[self._err_i : self._err_i + n]
self._err_i += len(chunk)
return chunk
def exit_status_ready(self) -> bool:
return self._closed or (
self._out_i >= len(self._out) and self._err_i >= len(self._err)
)
def recv_exit_status(self) -> int:
return self._rc
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
self._bio = io.BytesIO(payload)
self.channel = _Chan(out=payload, err=err, rc=rc)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stderr:
def __init__(self, payload: bytes = b""):
self._bio = io.BytesIO(payload)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stdin:
def __init__(self, cmd: str):
self._cmd = cmd
def write(self, s: str) -> None:
pass
def flush(self) -> None:
return
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
return
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
if cmd.startswith("tar -cz -C"):
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
if cmd == "mktemp -d":
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
if cmd.startswith("chmod 700"):
return (_Stdin(cmd), _Stdout(b""), _Stderr())
if " harvest " in cmd:
return (_Stdin(cmd), _Stdout(b""), _Stderr())
if cmd.startswith("rm -rf"):
return (_Stdin(cmd), _Stdout(b""), _Stderr())
return (_Stdin(cmd), _Stdout(b""), _Stderr())
def close(self):
return
RejectPolicy4 = type("RejectPolicy", (), {})
class FakeParamiko:
SSHClient = FakeSSH
RejectPolicy = RejectPolicy4 # type: ignore
PasswordRequiredException = Exception # type: ignore
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
prompts = []
def fake_getpass(prompt):
prompts.append(prompt)
return "passphrase"
out_dir = tmp_path / "out"
state_path = r.remote_harvest(
ask_key_passphrase=True,
getpass_fn=fake_getpass,
local_out_dir=out_dir,
remote_host="example.com",
remote_user="alice",
no_sudo=True,
)
assert state_path.exists()
assert len(prompts) == 1
def test_remote_harvest_ssh_key_passphrase_raises_when_not_interactive(
monkeypatch, tmp_path: Path
):
import sys
import enroll.remote as r
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
class _Chan:
def __init__(self):
self._closed = False
def recv_ready(self) -> bool:
return False
def recv(self, n: int) -> bytes:
return b""
def recv_stderr_ready(self) -> bool:
return False
def recv_stderr(self, n: int) -> bytes:
return b""
def exit_status_ready(self) -> bool:
return True
def recv_exit_status(self) -> int:
return 0
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout:
def __init__(self):
self.channel = _Chan()
def read(self, n: int = -1) -> bytes:
return b""
class _Stderr:
def read(self, n: int = -1) -> bytes:
return b""
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
raise Exception("PasswordRequired")
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, **_kwargs):
return (_Stdout(), _Stdout(), _Stderr())
def close(self):
return
class RejectPolicy:
pass
RejectPolicy3 = RejectPolicy
class FakeParamiko:
SSHClient = FakeSSH
RejectPolicy = RejectPolicy3 # type: ignore
PasswordRequiredException = Exception # type: ignore
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
out_dir = tmp_path / "out"
with pytest.raises(RuntimeError, match="SSH private key is encrypted"):
r.remote_harvest(
ask_key_passphrase=False,
local_out_dir=out_dir,
remote_host="example.com",
stdin=io.StringIO(),
)