more test coverage
This commit is contained in:
parent
b25dd1e314
commit
1544dc0295
15 changed files with 3150 additions and 424 deletions
|
|
@ -565,3 +565,452 @@ def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
|
|||
|
||||
# Ensure the password was written to stdin for the -S invocation.
|
||||
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]
|
||||
|
||||
|
||||
def test_sudo_password_required_detection():
|
||||
from enroll.remote import _sudo_password_required
|
||||
|
||||
assert _sudo_password_required("", "a password is required") is True
|
||||
assert _sudo_password_required("", "password is required") is True
|
||||
assert (
|
||||
_sudo_password_required("", "a terminal is required to read the password")
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
_sudo_password_required("", "no tty present and no askpass program specified")
|
||||
is True
|
||||
)
|
||||
assert _sudo_password_required("", "must have a tty to run sudo") is True
|
||||
assert _sudo_password_required("", "sudo: sorry, you must have a tty") is True
|
||||
assert _sudo_password_required("", "askpass") is True
|
||||
assert _sudo_password_required("success", "") is False
|
||||
|
||||
|
||||
def test_sudo_not_permitted_detection():
|
||||
from enroll.remote import _sudo_not_permitted
|
||||
|
||||
assert _sudo_not_permitted("", "user is not in the sudoers file") is True
|
||||
assert _sudo_not_permitted("", "not allowed to execute") is True
|
||||
assert _sudo_not_permitted("", "may not run sudo") is True
|
||||
assert _sudo_not_permitted("", "sorry, user") is True
|
||||
assert _sudo_not_permitted("success", "") is False
|
||||
|
||||
|
||||
def test_sudo_tty_required_detection():
|
||||
from enroll.remote import _sudo_tty_required
|
||||
|
||||
assert _sudo_tty_required("", "must have a tty") is True
|
||||
assert _sudo_tty_required("", "sorry, you must have a tty") is True
|
||||
assert _sudo_tty_required("", "sudo: sorry, you must have a tty") is True
|
||||
assert _sudo_tty_required("", "must have a tty to run sudo") is True
|
||||
assert _sudo_tty_required("success", "") is False
|
||||
|
||||
|
||||
def test_resolve_become_password_prompts_when_asked(monkeypatch):
|
||||
from enroll.remote import _resolve_become_password
|
||||
|
||||
prompted = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompted.append(prompt)
|
||||
return "secret"
|
||||
|
||||
result = _resolve_become_password(
|
||||
True, prompt="sudo password: ", getpass_fn=fake_getpass
|
||||
)
|
||||
assert result == "secret"
|
||||
assert len(prompted) == 1
|
||||
|
||||
|
||||
def test_resolve_become_password_returns_none_when_not_asked():
|
||||
from enroll.remote import _resolve_become_password
|
||||
|
||||
result = _resolve_become_password(False)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_from_env(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
monkeypatch.setenv("SSH_KEY_PASS", "env_secret")
|
||||
|
||||
result = _resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
|
||||
assert result == "env_secret"
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_raises_when_env_not_set(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
monkeypatch.delenv("SSH_KEY_PASS", raising=False)
|
||||
|
||||
with pytest.raises(RuntimeError, match="SSH key passphrase environment variable"):
|
||||
_resolve_ssh_key_passphrase(False, env_var="SSH_KEY_PASS")
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_prompts_when_asked(monkeypatch):
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
prompted = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompted.append(prompt)
|
||||
return "prompt_secret"
|
||||
|
||||
result = _resolve_ssh_key_passphrase(
|
||||
True, prompt="SSH key passphrase: ", getpass_fn=fake_getpass
|
||||
)
|
||||
assert result == "prompt_secret"
|
||||
assert len(prompted) == 1
|
||||
|
||||
|
||||
def test_resolve_ssh_key_passphrase_returns_none_when_not_asked():
|
||||
from enroll.remote import _resolve_ssh_key_passphrase
|
||||
|
||||
result = _resolve_ssh_key_passphrase(False, env_var=None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_absolute_paths(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="/etc/passwd")
|
||||
ti.size = 1
|
||||
tf.addfile(ti, io.BytesIO(b"x"))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Unsafe tar member path"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_hardlinks(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="hardlink")
|
||||
ti.type = tarfile.LNKTYPE
|
||||
ti.linkname = "/etc/passwd"
|
||||
tf.addfile(ti)
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Refusing to extract"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_rejects_device_nodes(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="device")
|
||||
ti.type = tarfile.CHRTYPE
|
||||
tf.addfile(ti)
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
with pytest.raises(RuntimeError, match="Refusing to extract"):
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_accepts_dot_entry(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name=".")
|
||||
ti.size = 0
|
||||
tf.addfile(ti, io.BytesIO(b""))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
|
||||
def test_safe_extract_tar_accepts_valid_files(tmp_path: Path):
|
||||
from enroll.remote import _safe_extract_tar
|
||||
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
bio = io.BytesIO()
|
||||
with tarfile.open(fileobj=bio, mode="w:gz") as tf:
|
||||
ti = tarfile.TarInfo(name="foo/bar.txt")
|
||||
ti.size = 5
|
||||
tf.addfile(ti, io.BytesIO(b"hello"))
|
||||
|
||||
bio.seek(0)
|
||||
with tarfile.open(fileobj=bio, mode="r:gz") as tf:
|
||||
_safe_extract_tar(tf, tmp_path)
|
||||
|
||||
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"hello"
|
||||
|
||||
|
||||
def test_remote_harvest_ssh_key_passphrase_retry(monkeypatch, tmp_path: Path):
|
||||
import sys
|
||||
|
||||
import enroll.remote as r
|
||||
|
||||
monkeypatch.setattr(
|
||||
r,
|
||||
"_build_enroll_pyz",
|
||||
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
|
||||
or (Path(td) / "enroll.pyz"),
|
||||
)
|
||||
|
||||
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
|
||||
|
||||
class _Chan:
|
||||
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
|
||||
self._out = out
|
||||
self._err = err
|
||||
self._out_i = 0
|
||||
self._err_i = 0
|
||||
self._rc = rc
|
||||
self._closed = False
|
||||
|
||||
def recv_ready(self) -> bool:
|
||||
return (not self._closed) and self._out_i < len(self._out)
|
||||
|
||||
def recv(self, n: int) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
chunk = self._out[self._out_i : self._out_i + n]
|
||||
self._out_i += len(chunk)
|
||||
return chunk
|
||||
|
||||
def recv_stderr_ready(self) -> bool:
|
||||
return (not self._closed) and self._err_i < len(self._err)
|
||||
|
||||
def recv_stderr(self, n: int) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
chunk = self._err[self._err_i : self._err_i + n]
|
||||
self._err_i += len(chunk)
|
||||
return chunk
|
||||
|
||||
def exit_status_ready(self) -> bool:
|
||||
return self._closed or (
|
||||
self._out_i >= len(self._out) and self._err_i >= len(self._err)
|
||||
)
|
||||
|
||||
def recv_exit_status(self) -> int:
|
||||
return self._rc
|
||||
|
||||
def shutdown_write(self) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
class _Stdout:
|
||||
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
|
||||
self._bio = io.BytesIO(payload)
|
||||
self.channel = _Chan(out=payload, err=err, rc=rc)
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return self._bio.read(n)
|
||||
|
||||
class _Stderr:
|
||||
def __init__(self, payload: bytes = b""):
|
||||
self._bio = io.BytesIO(payload)
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return self._bio.read(n)
|
||||
|
||||
class _Stdin:
|
||||
def __init__(self, cmd: str):
|
||||
self._cmd = cmd
|
||||
|
||||
def write(self, s: str) -> None:
|
||||
pass
|
||||
|
||||
def flush(self) -> None:
|
||||
return
|
||||
|
||||
class _SFTP:
|
||||
def put(self, _local: str, _remote: str) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
class FakeSSH:
|
||||
def __init__(self):
|
||||
self._sftp = _SFTP()
|
||||
|
||||
def load_system_host_keys(self):
|
||||
return
|
||||
|
||||
def set_missing_host_key_policy(self, _policy):
|
||||
return
|
||||
|
||||
def connect(self, **_kwargs):
|
||||
return
|
||||
|
||||
def open_sftp(self):
|
||||
return self._sftp
|
||||
|
||||
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
|
||||
if cmd.startswith("tar -cz -C"):
|
||||
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
|
||||
if cmd == "mktemp -d":
|
||||
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
|
||||
if cmd.startswith("chmod 700"):
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
if " harvest " in cmd:
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
if cmd.startswith("rm -rf"):
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
RejectPolicy4 = type("RejectPolicy", (), {})
|
||||
|
||||
class FakeParamiko:
|
||||
SSHClient = FakeSSH
|
||||
RejectPolicy = RejectPolicy4 # type: ignore
|
||||
PasswordRequiredException = Exception # type: ignore
|
||||
|
||||
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
|
||||
|
||||
prompts = []
|
||||
|
||||
def fake_getpass(prompt):
|
||||
prompts.append(prompt)
|
||||
return "passphrase"
|
||||
|
||||
out_dir = tmp_path / "out"
|
||||
state_path = r.remote_harvest(
|
||||
ask_key_passphrase=True,
|
||||
getpass_fn=fake_getpass,
|
||||
local_out_dir=out_dir,
|
||||
remote_host="example.com",
|
||||
remote_user="alice",
|
||||
no_sudo=True,
|
||||
)
|
||||
|
||||
assert state_path.exists()
|
||||
assert len(prompts) == 1
|
||||
|
||||
|
||||
def test_remote_harvest_ssh_key_passphrase_raises_when_not_interactive(
|
||||
monkeypatch, tmp_path: Path
|
||||
):
|
||||
import sys
|
||||
|
||||
import enroll.remote as r
|
||||
|
||||
monkeypatch.setattr(
|
||||
r,
|
||||
"_build_enroll_pyz",
|
||||
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
|
||||
or (Path(td) / "enroll.pyz"),
|
||||
)
|
||||
|
||||
class _Chan:
|
||||
def __init__(self):
|
||||
self._closed = False
|
||||
|
||||
def recv_ready(self) -> bool:
|
||||
return False
|
||||
|
||||
def recv(self, n: int) -> bytes:
|
||||
return b""
|
||||
|
||||
def recv_stderr_ready(self) -> bool:
|
||||
return False
|
||||
|
||||
def recv_stderr(self, n: int) -> bytes:
|
||||
return b""
|
||||
|
||||
def exit_status_ready(self) -> bool:
|
||||
return True
|
||||
|
||||
def recv_exit_status(self) -> int:
|
||||
return 0
|
||||
|
||||
def shutdown_write(self) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
class _Stdout:
|
||||
def __init__(self):
|
||||
self.channel = _Chan()
|
||||
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return b""
|
||||
|
||||
class _Stderr:
|
||||
def read(self, n: int = -1) -> bytes:
|
||||
return b""
|
||||
|
||||
class _SFTP:
|
||||
def put(self, _local: str, _remote: str) -> None:
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
class FakeSSH:
|
||||
def __init__(self):
|
||||
self._sftp = _SFTP()
|
||||
|
||||
def load_system_host_keys(self):
|
||||
return
|
||||
|
||||
def set_missing_host_key_policy(self, _policy):
|
||||
return
|
||||
|
||||
def connect(self, **_kwargs):
|
||||
raise Exception("PasswordRequired")
|
||||
|
||||
def open_sftp(self):
|
||||
return self._sftp
|
||||
|
||||
def exec_command(self, cmd: str, **_kwargs):
|
||||
return (_Stdout(), _Stdout(), _Stderr())
|
||||
|
||||
def close(self):
|
||||
return
|
||||
|
||||
class RejectPolicy:
|
||||
pass
|
||||
|
||||
RejectPolicy3 = RejectPolicy
|
||||
|
||||
class FakeParamiko:
|
||||
SSHClient = FakeSSH
|
||||
RejectPolicy = RejectPolicy3 # type: ignore
|
||||
PasswordRequiredException = Exception # type: ignore
|
||||
|
||||
monkeypatch.setitem(sys.modules, "paramiko", FakeParamiko)
|
||||
|
||||
out_dir = tmp_path / "out"
|
||||
|
||||
with pytest.raises(RuntimeError, match="SSH private key is encrypted"):
|
||||
r.remote_harvest(
|
||||
ask_key_passphrase=False,
|
||||
local_out_dir=out_dir,
|
||||
remote_host="example.com",
|
||||
stdin=io.StringIO(),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue