diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..a411acc --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,5 @@ +## Contributors + +mig5 would like to thank the following people for their contributions to Enroll. + + * [slhck](https://slhck.info/) diff --git a/enroll/cli.py b/enroll/cli.py index bb4d3f1..55fdd0b 100644 --- a/enroll/cli.py +++ b/enroll/cli.py @@ -13,7 +13,7 @@ from .cache import new_harvest_cache_dir from .diff import compare_harvests, format_report, post_webhook, send_email from .harvest import harvest from .manifest import manifest -from .remote import remote_harvest +from .remote import remote_harvest, RemoteSudoPasswordRequired from .sopsutil import SopsError, encrypt_file_binary from .version import get_enroll_version @@ -352,6 +352,17 @@ def _add_remote_args(p: argparse.ArgumentParser) -> None: help="SSH username for --remote-host (default: local $USER).", ) + # Align terminology with Ansible: "become" == sudo. + p.add_argument( + "--ask-become-pass", + "-K", + action="store_true", + help=( + "Prompt for the remote sudo (become) password when using --remote-host " + "(similar to ansible --ask-become-pass)." + ), + ) + def main() -> None: ap = argparse.ArgumentParser(prog="enroll") @@ -623,6 +634,7 @@ def main() -> None: except OSError: pass remote_harvest( + ask_become_pass=args.ask_become_pass, local_out_dir=tmp_bundle, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -643,6 +655,7 @@ def main() -> None: else new_harvest_cache_dir(hint=args.remote_host).dir ) state = remote_harvest( + ask_become_pass=args.ask_become_pass, local_out_dir=out_dir, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -769,6 +782,7 @@ def main() -> None: except OSError: pass remote_harvest( + ask_become_pass=args.ask_become_pass, local_out_dir=tmp_bundle, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -798,6 +812,7 @@ def main() -> None: else new_harvest_cache_dir(hint=args.remote_host).dir ) remote_harvest( + ask_become_pass=args.ask_become_pass, local_out_dir=harvest_dir, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -912,5 +927,11 @@ def main() -> None: if getattr(args, "exit_code", False) and has_changes: raise SystemExit(2) + except RemoteSudoPasswordRequired: + raise SystemExit( + "error: remote sudo requires a password. Re-run with --ask-become-pass." + ) from None + except RuntimeError as e: + raise SystemExit(f"error: {e}") from None except SopsError as e: - raise SystemExit(f"error: {e}") + raise SystemExit(f"error: {e}") from None diff --git a/enroll/remote.py b/enroll/remote.py index b86cd08..93cee74 100644 --- a/enroll/remote.py +++ b/enroll/remote.py @@ -1,14 +1,117 @@ from __future__ import annotations +import getpass import os import shlex import shutil +import sys +import time import tarfile import tempfile import zipapp from pathlib import Path from pathlib import PurePosixPath -from typing import Optional +from typing import Optional, Callable, TextIO + + +class RemoteSudoPasswordRequired(RuntimeError): + """Raised when sudo requires a password but none was provided.""" + + +def _sudo_password_required(out: str, err: str) -> bool: + """Return True if sudo output indicates it needs a password/TTY.""" + blob = (out + "\n" + err).lower() + patterns = ( + "a password is required", + "password is required", + "a terminal is required to read the password", + "no tty present and no askpass program specified", + "must have a tty to run sudo", + "sudo: sorry, you must have a tty", + "askpass", + ) + return any(p in blob for p in patterns) + + +def _sudo_not_permitted(out: str, err: str) -> bool: + """Return True if sudo output indicates the user cannot sudo at all.""" + blob = (out + "\n" + err).lower() + patterns = ( + "is not in the sudoers file", + "not allowed to execute", + "may not run sudo", + "sorry, user", + ) + return any(p in blob for p in patterns) + + +def _sudo_tty_required(out: str, err: str) -> bool: + """Return True if sudo output indicates it requires a TTY (sudoers requiretty).""" + blob = (out + "\n" + err).lower() + patterns = ( + "must have a tty", + "sorry, you must have a tty", + "sudo: sorry, you must have a tty", + "must have a tty to run sudo", + ) + return any(p in blob for p in patterns) + + +def _resolve_become_password( + ask_become_pass: bool, + *, + prompt: str = "sudo password: ", + getpass_fn: Callable[[str], str] = getpass.getpass, +) -> Optional[str]: + if ask_become_pass: + return getpass_fn(prompt) + return None + + +def remote_harvest( + *, + ask_become_pass: bool = False, + no_sudo: bool = False, + prompt: str = "sudo password: ", + getpass_fn: Optional[Callable[[str], str]] = None, + stdin: Optional[TextIO] = None, + **kwargs, +): + """Call _remote_harvest, with a safe sudo password fallback. + + Behavior: + - Run without a password unless --ask-become-pass is set. + - If the remote sudo policy requires a password and none was provided, + prompt and retry when running interactively. + """ + + # Resolve defaults at call time (easier to test/monkeypatch, and avoids capturing + # sys.stdin / getpass.getpass at import time). + if getpass_fn is None: + getpass_fn = getpass.getpass + if stdin is None: + stdin = sys.stdin + + sudo_password = _resolve_become_password( + ask_become_pass and not no_sudo, + prompt=prompt, + getpass_fn=getpass_fn, + ) + + try: + return _remote_harvest(sudo_password=sudo_password, no_sudo=no_sudo, **kwargs) + except RemoteSudoPasswordRequired: + if sudo_password is not None: + raise + + # Fallback prompt if interactive + if stdin is not None and getattr(stdin, "isatty", lambda: False)(): + pw = getpass_fn(prompt) + return _remote_harvest(sudo_password=pw, no_sudo=no_sudo, **kwargs) + + raise RemoteSudoPasswordRequired( + "Remote sudo requires a password. Re-run with --ask-become-pass." + ) def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None: @@ -79,7 +182,14 @@ def _build_enroll_pyz(tmpdir: Path) -> Path: return pyz_path -def _ssh_run(ssh, cmd: str, *, get_pty: bool = False) -> tuple[int, str, str]: +def _ssh_run( + ssh, + cmd: str, + *, + get_pty: bool = False, + stdin_text: Optional[str] = None, + close_stdin: bool = False, +) -> tuple[int, str, str]: """Run a command over a Paramiko SSHClient. Paramiko's exec_command runs commands without a TTY by default. @@ -90,14 +200,133 @@ def _ssh_run(ssh, cmd: str, *, get_pty: bool = False) -> tuple[int, str, str]: We do not request a PTY for commands that stream binary data (e.g. tar/gzip output), as a PTY can corrupt the byte stream. """ - _stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty) - out = stdout.read().decode("utf-8", errors="replace") - err = stderr.read().decode("utf-8", errors="replace") - rc = stdout.channel.recv_exit_status() + stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty) + # All three file-like objects share the same underlying Channel. + chan = stdout.channel + + if stdin_text is not None and stdin is not None: + try: + stdin.write(stdin_text) + stdin.flush() + except Exception: + # If the remote side closed stdin early, ignore. + pass # nosec + finally: + if close_stdin: + # For sudo -S, a wrong password causes sudo to re-prompt and wait + # forever for more input. We try hard to deliver EOF so sudo can + # fail fast. + try: + chan.shutdown_write() # sends EOF to the remote process + except Exception: + pass # nosec + try: + stdin.close() + except Exception: + pass # nosec + + # Read incrementally to avoid blocking forever on stdout.read()/stderr.read() + # if the remote process is waiting for more input (e.g. sudo password retry). + out_chunks: list[bytes] = [] + err_chunks: list[bytes] = [] + # Keep a small tail of stderr to detect sudo retry messages without + # repeatedly joining potentially large buffers. + err_tail = b"" + + while True: + progressed = False + if chan.recv_ready(): + out_chunks.append(chan.recv(1024 * 64)) + progressed = True + if chan.recv_stderr_ready(): + chunk = chan.recv_stderr(1024 * 64) + err_chunks.append(chunk) + err_tail = (err_tail + chunk)[-4096:] + progressed = True + + # If we just attempted sudo -S with a single password line and sudo is + # asking again, detect it and stop waiting. + if close_stdin and stdin_text is not None: + blob = err_tail.lower() + if b"sorry, try again" in blob or b"incorrect password" in blob: + try: + chan.close() + except Exception: + pass # nosec + break + + # Exit once the process has exited and we have drained the buffers. + if ( + chan.exit_status_ready() + and not chan.recv_ready() + and not chan.recv_stderr_ready() + ): + break + + if not progressed: + time.sleep(0.05) + + out = b"".join(out_chunks).decode("utf-8", errors="replace") + err = b"".join(err_chunks).decode("utf-8", errors="replace") + rc = chan.recv_exit_status() if chan.exit_status_ready() else 1 return rc, out, err -def remote_harvest( +def _ssh_run_sudo( + ssh, + cmd: str, + *, + sudo_password: Optional[str] = None, + get_pty: bool = True, +) -> tuple[int, str, str]: + """Run cmd via sudo with a safe non-interactive-first strategy. + + Strategy: + 1) Try `sudo -n`. + 2) If sudo reports a password is required and we have one, retry with + `sudo -S` and feed it via stdin. + 3) If sudo reports a password is required and we *don't* have one, raise + RemoteSudoPasswordRequired. + + We avoid requesting a PTY unless the remote sudo policy requires it. + This makes sudo -S behavior more reliable (wrong passwords fail fast + instead of blocking on a PTY). + """ + cmd_n = f"sudo -n -p '' -- {cmd}" + + # First try: never prompt, and prefer no PTY. + rc, out, err = _ssh_run(ssh, cmd_n, get_pty=False) + need_pty = False + + # Some sudoers configurations require a TTY even for passwordless sudo. + if get_pty and rc != 0 and _sudo_tty_required(out, err): + need_pty = True + rc, out, err = _ssh_run(ssh, cmd_n, get_pty=True) + + if rc == 0: + return rc, out, err + + if _sudo_not_permitted(out, err): + return rc, out, err + + if _sudo_password_required(out, err): + if sudo_password is None: + raise RemoteSudoPasswordRequired( + "Remote sudo requires a password, but none was provided." + ) + cmd_s = f"sudo -S -p '' -- {cmd}" + return _ssh_run( + ssh, + cmd_s, + get_pty=need_pty, + stdin_text=str(sudo_password) + "\n", + close_stdin=True, + ) + + return rc, out, err + + +def _remote_harvest( *, local_out_dir: Path, remote_host: str, @@ -106,6 +335,7 @@ def remote_harvest( remote_python: str = "python3", dangerous: bool = False, no_sudo: bool = False, + sudo_password: Optional[str] = None, include_paths: Optional[list[str]] = None, exclude_paths: Optional[list[str]] = None, ) -> Path: @@ -190,10 +420,15 @@ def remote_harvest( argv.extend(["--exclude-path", str(p)]) _cmd = " ".join(map(shlex.quote, argv)) - cmd = f"sudo {_cmd}" if not no_sudo else _cmd - - # PTY for sudo commands (helps sudoers requiretty). - rc, out, err = _ssh_run(ssh, cmd, get_pty=(not no_sudo)) + if not no_sudo: + # Prefer non-interactive sudo first; retry with -S only when needed. + rc, out, err = _ssh_run_sudo( + ssh, _cmd, sudo_password=sudo_password, get_pty=True + ) + cmd = f"sudo {_cmd}" + else: + cmd = _cmd + rc, out, err = _ssh_run(ssh, cmd, get_pty=False) if rc != 0: raise RuntimeError( "Remote harvest failed.\n" @@ -210,12 +445,17 @@ def remote_harvest( "Unable to determine remote username for chown. " "Pass --remote-user explicitly or use --no-sudo." ) - cmd = f"sudo chown -R {resolved_user} {rbundle}" - rc, out, err = _ssh_run(ssh, cmd, get_pty=True) + chown_cmd = f"chown -R {resolved_user} {rbundle}" + rc, out, err = _ssh_run_sudo( + ssh, + chown_cmd, + sudo_password=sudo_password, + get_pty=True, + ) if rc != 0: raise RuntimeError( "chown of harvest failed.\n" - f"Command: {cmd}\n" + f"Command: sudo {chown_cmd}\n" f"Exit code: {rc}\n" f"Stdout: {out.strip()}\n" f"Stderr: {err.strip()}" diff --git a/tests/test_cli.py b/tests/test_cli.py index 4477b24..5fc9a66 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,7 @@ import sys +import pytest + import enroll.cli as cli @@ -258,6 +260,113 @@ def test_cli_single_shot_remote_without_harvest_prints_state_path( assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls +def test_cli_harvest_remote_ask_become_pass_prompts_and_passes_password( + monkeypatch, tmp_path +): + from enroll.cache import HarvestCache + import enroll.remote as r + + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + called = {} + + def fake_cache_dir(*, hint=None): + return HarvestCache(dir=cache_dir) + + def fake__remote_harvest(*, sudo_password=None, **kwargs): + called["sudo_password"] = sudo_password + return cache_dir / "state.json" + + monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) + monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) + monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw123") + + monkeypatch.setattr( + sys, + "argv", + [ + "enroll", + "harvest", + "--remote-host", + "example.test", + "--ask-become-pass", + ], + ) + + cli.main() + assert called["sudo_password"] == "pw123" + + +def test_cli_harvest_remote_password_required_fallback_prompts_and_retries( + monkeypatch, tmp_path +): + from enroll.cache import HarvestCache + import enroll.remote as r + + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + def fake_cache_dir(*, hint=None): + return HarvestCache(dir=cache_dir) + + calls = [] + + def fake__remote_harvest(*, sudo_password=None, **kwargs): + calls.append(sudo_password) + if sudo_password is None: + raise r.RemoteSudoPasswordRequired("pw required") + return cache_dir / "state.json" + + class _TTYStdin: + def isatty(self): + return True + + monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) + monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) + monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw456") + monkeypatch.setattr(sys, "stdin", _TTYStdin()) + + monkeypatch.setattr( + sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"] + ) + + cli.main() + assert calls == [None, "pw456"] + + +def test_cli_harvest_remote_password_required_noninteractive_errors( + monkeypatch, tmp_path +): + from enroll.cache import HarvestCache + import enroll.remote as r + + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + + def fake_cache_dir(*, hint=None): + return HarvestCache(dir=cache_dir) + + def fake__remote_harvest(*, sudo_password=None, **kwargs): + raise r.RemoteSudoPasswordRequired("pw required") + + class _NoTTYStdin: + def isatty(self): + return False + + monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) + monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) + monkeypatch.setattr(sys, "stdin", _NoTTYStdin()) + + monkeypatch.setattr( + sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"] + ) + + with pytest.raises(SystemExit) as e: + cli.main() + assert "--ask-become-pass" in str(e.value) + + def test_cli_manifest_common_args(monkeypatch, tmp_path): """Ensure --fqdn and jinjaturtle mode flags are forwarded correctly.""" diff --git a/tests/test_remote.py b/tests/test_remote.py index 1f9c89b..6b4ab01 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -69,16 +69,53 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): calls: list[tuple[str, bool]] = [] class _Chan: - def __init__(self, rc: int = 0): + def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): + self._out = out + self._err = err + self._out_i = 0 + self._err_i = 0 self._rc = rc + self._closed = False + + def recv_ready(self) -> bool: + return (not self._closed) and self._out_i < len(self._out) + + def recv(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._out[self._out_i : self._out_i + n] + self._out_i += len(chunk) + return chunk + + def recv_stderr_ready(self) -> bool: + return (not self._closed) and self._err_i < len(self._err) + + def recv_stderr(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._err[self._err_i : self._err_i + n] + self._err_i += len(chunk) + return chunk + + def exit_status_ready(self) -> bool: + return self._closed or ( + self._out_i >= len(self._out) and self._err_i >= len(self._err) + ) def recv_exit_status(self) -> int: return self._rc + def shutdown_write(self) -> None: + return + + def close(self) -> None: + self._closed = True + class _Stdout: - def __init__(self, payload: bytes = b"", rc: int = 0): + def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): self._bio = io.BytesIO(payload) - self.channel = _Chan(rc) + # _ssh_run reads stdout/stderr via the underlying channel. + self.channel = _Chan(out=payload, err=err, rc=rc) def read(self, n: int = -1) -> bytes: return self._bio.read(n) @@ -130,10 +167,20 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr()) if cmd.startswith("chmod 700"): return (None, _Stdout(b""), _Stderr()) + if cmd.startswith("sudo -n") and " harvest " in cmd: + if not get_pty: + msg = b"sudo: sorry, you must have a tty to run sudo\n" + return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg)) + return (None, _Stdout(b"", rc=0), _Stderr(b"")) + if cmd.startswith("sudo -S") and " harvest " in cmd: + return (None, _Stdout(b""), _Stderr()) if " harvest " in cmd: return (None, _Stdout(b""), _Stderr()) - if cmd.startswith("sudo chown -R"): - return (None, _Stdout(b""), _Stderr()) + if cmd.startswith("sudo -n") and " chown -R" in cmd: + if not get_pty: + msg = b"sudo: sorry, you must have a tty to run sudo\n" + return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg)) + return (None, _Stdout(b"", rc=0), _Stderr(b"")) if cmd.startswith("rm -rf"): return (None, _Stdout(b""), _Stderr()) @@ -154,6 +201,7 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): out_dir = tmp_path / "out" state_path = r.remote_harvest( + ask_become_pass=False, local_out_dir=out_dir, remote_host="example.com", remote_port=2222, @@ -175,13 +223,21 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): assert "--include-path" in joined assert "--exclude-path" in joined - # Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar. - pty_by_cmd = {c: pty for c, pty in calls} - assert pty_by_cmd.get("id -un") is False - assert any( - c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls - ) - assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls) + # Ensure we fall back to PTY only when sudo reports it is required. + assert any(c == "id -un" and pty is False for c, pty in calls) + + sudo_harvest = [ + (c, pty) for c, pty in calls if c.startswith("sudo -n") and " harvest " in c + ] + assert any(pty is False for _c, pty in sudo_harvest) + assert any(pty is True for _c, pty in sudo_harvest) + + sudo_chown = [ + (c, pty) for c, pty in calls if c.startswith("sudo -n") and " chown -R" in c + ] + assert any(pty is False for _c, pty in sudo_chown) + assert any(pty is True for _c, pty in sudo_chown) + assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls) @@ -204,16 +260,53 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( calls: list[tuple[str, bool]] = [] class _Chan: - def __init__(self, rc: int = 0): + def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): + self._out = out + self._err = err + self._out_i = 0 + self._err_i = 0 self._rc = rc + self._closed = False + + def recv_ready(self) -> bool: + return (not self._closed) and self._out_i < len(self._out) + + def recv(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._out[self._out_i : self._out_i + n] + self._out_i += len(chunk) + return chunk + + def recv_stderr_ready(self) -> bool: + return (not self._closed) and self._err_i < len(self._err) + + def recv_stderr(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._err[self._err_i : self._err_i + n] + self._err_i += len(chunk) + return chunk + + def exit_status_ready(self) -> bool: + return self._closed or ( + self._out_i >= len(self._out) and self._err_i >= len(self._err) + ) def recv_exit_status(self) -> int: return self._rc + def shutdown_write(self) -> None: + return + + def close(self) -> None: + self._closed = True + class _Stdout: - def __init__(self, payload: bytes = b"", rc: int = 0): + def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): self._bio = io.BytesIO(payload) - self.channel = _Chan(rc) + # _ssh_run reads stdout/stderr via the underlying channel. + self.channel = _Chan(out=payload, err=err, rc=rc) def read(self, n: int = -1) -> bytes: return self._bio.read(n) @@ -278,6 +371,7 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( out_dir = tmp_path / "out" r.remote_harvest( + ask_become_pass=False, local_out_dir=out_dir, remote_host="example.com", remote_user="alice", @@ -288,3 +382,186 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( assert "sudo" not in joined assert "sudo chown" not in joined assert any(" harvest " in c and pty is False for c, pty in calls) + + +def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password( + tmp_path: Path, monkeypatch +): + """If sudo requires a password, we should fall back from -n to -S and feed stdin.""" + import sys + import types + + import enroll.remote as r + + # Avoid building a real zipapp; just create a file. + monkeypatch.setattr( + r, + "_build_enroll_pyz", + lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ") + or (Path(td) / "enroll.pyz"), + ) + + tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'}) + calls: list[tuple[str, bool]] = [] + stdin_by_cmd: dict[str, list[str]] = {} + + class _Chan: + def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): + self._out = out + self._err = err + self._out_i = 0 + self._err_i = 0 + self._rc = rc + self._closed = False + + def recv_ready(self) -> bool: + return (not self._closed) and self._out_i < len(self._out) + + def recv(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._out[self._out_i : self._out_i + n] + self._out_i += len(chunk) + return chunk + + def recv_stderr_ready(self) -> bool: + return (not self._closed) and self._err_i < len(self._err) + + def recv_stderr(self, n: int) -> bytes: + if self._closed: + return b"" + chunk = self._err[self._err_i : self._err_i + n] + self._err_i += len(chunk) + return chunk + + def exit_status_ready(self) -> bool: + return self._closed or ( + self._out_i >= len(self._out) and self._err_i >= len(self._err) + ) + + def recv_exit_status(self) -> int: + return self._rc + + def shutdown_write(self) -> None: + return + + def close(self) -> None: + self._closed = True + + class _Stdout: + def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): + self._bio = io.BytesIO(payload) + # _ssh_run reads stdout/stderr via the underlying channel. + self.channel = _Chan(out=payload, err=err, rc=rc) + + def read(self, n: int = -1) -> bytes: + return self._bio.read(n) + + class _Stderr: + def __init__(self, payload: bytes = b""): + self._bio = io.BytesIO(payload) + + def read(self, n: int = -1) -> bytes: + return self._bio.read(n) + + class _Stdin: + def __init__(self, cmd: str): + self._cmd = cmd + stdin_by_cmd.setdefault(cmd, []) + + def write(self, s: str) -> None: + stdin_by_cmd[self._cmd].append(s) + + def flush(self) -> None: + return + + class _SFTP: + def put(self, _local: str, _remote: str) -> None: + return + + def close(self) -> None: + return + + class FakeSSH: + def __init__(self): + self._sftp = _SFTP() + + def load_system_host_keys(self): + return + + def set_missing_host_key_policy(self, _policy): + return + + def connect(self, **_kwargs): + return + + def open_sftp(self): + return self._sftp + + def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs): + calls.append((cmd, bool(get_pty))) + + # Tar stream + if cmd.startswith("tar -cz -C"): + return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b"")) + + if cmd == "mktemp -d": + return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr()) + if cmd.startswith("chmod 700"): + return (_Stdin(cmd), _Stdout(b""), _Stderr()) + + # First attempt: sudo -n fails, prompting is not allowed. + if cmd.startswith("sudo -n") and " harvest " in cmd: + return ( + _Stdin(cmd), + _Stdout(b"", rc=1, err=b"sudo: a password is required\n"), + _Stderr(b"sudo: a password is required\n"), + ) + + # Retry: sudo -S succeeds and should have been fed the password via stdin. + if cmd.startswith("sudo -S") and " harvest " in cmd: + return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) + + # chown succeeds passwordlessly (e.g., sudo timestamp is warm). + if cmd.startswith("sudo -n") and " chown -R" in cmd: + return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) + + if cmd.startswith("rm -rf"): + return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) + + # Fallback for unexpected commands. + return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) + + def close(self): + return + + class RejectPolicy: + pass + + monkeypatch.setitem( + sys.modules, + "paramiko", + types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy), + ) + + out_dir = tmp_path / "out" + state_path = r.remote_harvest( + ask_become_pass=True, + getpass_fn=lambda _prompt="": "s3cr3t", + local_out_dir=out_dir, + remote_host="example.com", + remote_user="alice", + no_sudo=False, + ) + + assert state_path.exists() + assert b"ok" in state_path.read_bytes() + + # Ensure we attempted with sudo -n first, then sudo -S. + sudo_n = [c for c, _pty in calls if c.startswith("sudo -n") and " harvest " in c] + sudo_s = [c for c, _pty in calls if c.startswith("sudo -S") and " harvest " in c] + assert len(sudo_n) == 1 + assert len(sudo_s) == 1 + + # Ensure the password was written to stdin for the -S invocation. + assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]