diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md deleted file mode 100644 index a411acc..0000000 --- a/CONTRIBUTORS.md +++ /dev/null @@ -1,5 +0,0 @@ -## Contributors - -mig5 would like to thank the following people for their contributions to Enroll. - - * [slhck](https://slhck.info/) diff --git a/enroll/cli.py b/enroll/cli.py index 55fdd0b..bb4d3f1 100644 --- a/enroll/cli.py +++ b/enroll/cli.py @@ -13,7 +13,7 @@ from .cache import new_harvest_cache_dir from .diff import compare_harvests, format_report, post_webhook, send_email from .harvest import harvest from .manifest import manifest -from .remote import remote_harvest, RemoteSudoPasswordRequired +from .remote import remote_harvest from .sopsutil import SopsError, encrypt_file_binary from .version import get_enroll_version @@ -352,17 +352,6 @@ def _add_remote_args(p: argparse.ArgumentParser) -> None: help="SSH username for --remote-host (default: local $USER).", ) - # Align terminology with Ansible: "become" == sudo. - p.add_argument( - "--ask-become-pass", - "-K", - action="store_true", - help=( - "Prompt for the remote sudo (become) password when using --remote-host " - "(similar to ansible --ask-become-pass)." - ), - ) - def main() -> None: ap = argparse.ArgumentParser(prog="enroll") @@ -634,7 +623,6 @@ def main() -> None: except OSError: pass remote_harvest( - ask_become_pass=args.ask_become_pass, local_out_dir=tmp_bundle, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -655,7 +643,6 @@ def main() -> None: else new_harvest_cache_dir(hint=args.remote_host).dir ) state = remote_harvest( - ask_become_pass=args.ask_become_pass, local_out_dir=out_dir, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -782,7 +769,6 @@ def main() -> None: except OSError: pass remote_harvest( - ask_become_pass=args.ask_become_pass, local_out_dir=tmp_bundle, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -812,7 +798,6 @@ def main() -> None: else new_harvest_cache_dir(hint=args.remote_host).dir ) remote_harvest( - ask_become_pass=args.ask_become_pass, local_out_dir=harvest_dir, remote_host=args.remote_host, remote_port=int(args.remote_port), @@ -927,11 +912,5 @@ def main() -> None: if getattr(args, "exit_code", False) and has_changes: raise SystemExit(2) - except RemoteSudoPasswordRequired: - raise SystemExit( - "error: remote sudo requires a password. Re-run with --ask-become-pass." - ) from None - except RuntimeError as e: - raise SystemExit(f"error: {e}") from None except SopsError as e: - raise SystemExit(f"error: {e}") from None + raise SystemExit(f"error: {e}") diff --git a/enroll/remote.py b/enroll/remote.py index 93cee74..b86cd08 100644 --- a/enroll/remote.py +++ b/enroll/remote.py @@ -1,117 +1,14 @@ from __future__ import annotations -import getpass import os import shlex import shutil -import sys -import time import tarfile import tempfile import zipapp from pathlib import Path from pathlib import PurePosixPath -from typing import Optional, Callable, TextIO - - -class RemoteSudoPasswordRequired(RuntimeError): - """Raised when sudo requires a password but none was provided.""" - - -def _sudo_password_required(out: str, err: str) -> bool: - """Return True if sudo output indicates it needs a password/TTY.""" - blob = (out + "\n" + err).lower() - patterns = ( - "a password is required", - "password is required", - "a terminal is required to read the password", - "no tty present and no askpass program specified", - "must have a tty to run sudo", - "sudo: sorry, you must have a tty", - "askpass", - ) - return any(p in blob for p in patterns) - - -def _sudo_not_permitted(out: str, err: str) -> bool: - """Return True if sudo output indicates the user cannot sudo at all.""" - blob = (out + "\n" + err).lower() - patterns = ( - "is not in the sudoers file", - "not allowed to execute", - "may not run sudo", - "sorry, user", - ) - return any(p in blob for p in patterns) - - -def _sudo_tty_required(out: str, err: str) -> bool: - """Return True if sudo output indicates it requires a TTY (sudoers requiretty).""" - blob = (out + "\n" + err).lower() - patterns = ( - "must have a tty", - "sorry, you must have a tty", - "sudo: sorry, you must have a tty", - "must have a tty to run sudo", - ) - return any(p in blob for p in patterns) - - -def _resolve_become_password( - ask_become_pass: bool, - *, - prompt: str = "sudo password: ", - getpass_fn: Callable[[str], str] = getpass.getpass, -) -> Optional[str]: - if ask_become_pass: - return getpass_fn(prompt) - return None - - -def remote_harvest( - *, - ask_become_pass: bool = False, - no_sudo: bool = False, - prompt: str = "sudo password: ", - getpass_fn: Optional[Callable[[str], str]] = None, - stdin: Optional[TextIO] = None, - **kwargs, -): - """Call _remote_harvest, with a safe sudo password fallback. - - Behavior: - - Run without a password unless --ask-become-pass is set. - - If the remote sudo policy requires a password and none was provided, - prompt and retry when running interactively. - """ - - # Resolve defaults at call time (easier to test/monkeypatch, and avoids capturing - # sys.stdin / getpass.getpass at import time). - if getpass_fn is None: - getpass_fn = getpass.getpass - if stdin is None: - stdin = sys.stdin - - sudo_password = _resolve_become_password( - ask_become_pass and not no_sudo, - prompt=prompt, - getpass_fn=getpass_fn, - ) - - try: - return _remote_harvest(sudo_password=sudo_password, no_sudo=no_sudo, **kwargs) - except RemoteSudoPasswordRequired: - if sudo_password is not None: - raise - - # Fallback prompt if interactive - if stdin is not None and getattr(stdin, "isatty", lambda: False)(): - pw = getpass_fn(prompt) - return _remote_harvest(sudo_password=pw, no_sudo=no_sudo, **kwargs) - - raise RemoteSudoPasswordRequired( - "Remote sudo requires a password. Re-run with --ask-become-pass." - ) +from typing import Optional def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None: @@ -182,14 +79,7 @@ def _build_enroll_pyz(tmpdir: Path) -> Path: return pyz_path -def _ssh_run( - ssh, - cmd: str, - *, - get_pty: bool = False, - stdin_text: Optional[str] = None, - close_stdin: bool = False, -) -> tuple[int, str, str]: +def _ssh_run(ssh, cmd: str, *, get_pty: bool = False) -> tuple[int, str, str]: """Run a command over a Paramiko SSHClient. Paramiko's exec_command runs commands without a TTY by default. @@ -200,133 +90,14 @@ def _ssh_run( We do not request a PTY for commands that stream binary data (e.g. tar/gzip output), as a PTY can corrupt the byte stream. """ - stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty) - # All three file-like objects share the same underlying Channel. - chan = stdout.channel - - if stdin_text is not None and stdin is not None: - try: - stdin.write(stdin_text) - stdin.flush() - except Exception: - # If the remote side closed stdin early, ignore. - pass # nosec - finally: - if close_stdin: - # For sudo -S, a wrong password causes sudo to re-prompt and wait - # forever for more input. We try hard to deliver EOF so sudo can - # fail fast. - try: - chan.shutdown_write() # sends EOF to the remote process - except Exception: - pass # nosec - try: - stdin.close() - except Exception: - pass # nosec - - # Read incrementally to avoid blocking forever on stdout.read()/stderr.read() - # if the remote process is waiting for more input (e.g. sudo password retry). - out_chunks: list[bytes] = [] - err_chunks: list[bytes] = [] - # Keep a small tail of stderr to detect sudo retry messages without - # repeatedly joining potentially large buffers. - err_tail = b"" - - while True: - progressed = False - if chan.recv_ready(): - out_chunks.append(chan.recv(1024 * 64)) - progressed = True - if chan.recv_stderr_ready(): - chunk = chan.recv_stderr(1024 * 64) - err_chunks.append(chunk) - err_tail = (err_tail + chunk)[-4096:] - progressed = True - - # If we just attempted sudo -S with a single password line and sudo is - # asking again, detect it and stop waiting. - if close_stdin and stdin_text is not None: - blob = err_tail.lower() - if b"sorry, try again" in blob or b"incorrect password" in blob: - try: - chan.close() - except Exception: - pass # nosec - break - - # Exit once the process has exited and we have drained the buffers. - if ( - chan.exit_status_ready() - and not chan.recv_ready() - and not chan.recv_stderr_ready() - ): - break - - if not progressed: - time.sleep(0.05) - - out = b"".join(out_chunks).decode("utf-8", errors="replace") - err = b"".join(err_chunks).decode("utf-8", errors="replace") - rc = chan.recv_exit_status() if chan.exit_status_ready() else 1 + _stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty) + out = stdout.read().decode("utf-8", errors="replace") + err = stderr.read().decode("utf-8", errors="replace") + rc = stdout.channel.recv_exit_status() return rc, out, err -def _ssh_run_sudo( - ssh, - cmd: str, - *, - sudo_password: Optional[str] = None, - get_pty: bool = True, -) -> tuple[int, str, str]: - """Run cmd via sudo with a safe non-interactive-first strategy. - - Strategy: - 1) Try `sudo -n`. - 2) If sudo reports a password is required and we have one, retry with - `sudo -S` and feed it via stdin. - 3) If sudo reports a password is required and we *don't* have one, raise - RemoteSudoPasswordRequired. - - We avoid requesting a PTY unless the remote sudo policy requires it. - This makes sudo -S behavior more reliable (wrong passwords fail fast - instead of blocking on a PTY). - """ - cmd_n = f"sudo -n -p '' -- {cmd}" - - # First try: never prompt, and prefer no PTY. - rc, out, err = _ssh_run(ssh, cmd_n, get_pty=False) - need_pty = False - - # Some sudoers configurations require a TTY even for passwordless sudo. - if get_pty and rc != 0 and _sudo_tty_required(out, err): - need_pty = True - rc, out, err = _ssh_run(ssh, cmd_n, get_pty=True) - - if rc == 0: - return rc, out, err - - if _sudo_not_permitted(out, err): - return rc, out, err - - if _sudo_password_required(out, err): - if sudo_password is None: - raise RemoteSudoPasswordRequired( - "Remote sudo requires a password, but none was provided." - ) - cmd_s = f"sudo -S -p '' -- {cmd}" - return _ssh_run( - ssh, - cmd_s, - get_pty=need_pty, - stdin_text=str(sudo_password) + "\n", - close_stdin=True, - ) - - return rc, out, err - - -def _remote_harvest( +def remote_harvest( *, local_out_dir: Path, remote_host: str, @@ -335,7 +106,6 @@ def _remote_harvest( remote_python: str = "python3", dangerous: bool = False, no_sudo: bool = False, - sudo_password: Optional[str] = None, include_paths: Optional[list[str]] = None, exclude_paths: Optional[list[str]] = None, ) -> Path: @@ -420,15 +190,10 @@ def _remote_harvest( argv.extend(["--exclude-path", str(p)]) _cmd = " ".join(map(shlex.quote, argv)) - if not no_sudo: - # Prefer non-interactive sudo first; retry with -S only when needed. - rc, out, err = _ssh_run_sudo( - ssh, _cmd, sudo_password=sudo_password, get_pty=True - ) - cmd = f"sudo {_cmd}" - else: - cmd = _cmd - rc, out, err = _ssh_run(ssh, cmd, get_pty=False) + cmd = f"sudo {_cmd}" if not no_sudo else _cmd + + # PTY for sudo commands (helps sudoers requiretty). + rc, out, err = _ssh_run(ssh, cmd, get_pty=(not no_sudo)) if rc != 0: raise RuntimeError( "Remote harvest failed.\n" @@ -445,17 +210,12 @@ def _remote_harvest( "Unable to determine remote username for chown. " "Pass --remote-user explicitly or use --no-sudo." ) - chown_cmd = f"chown -R {resolved_user} {rbundle}" - rc, out, err = _ssh_run_sudo( - ssh, - chown_cmd, - sudo_password=sudo_password, - get_pty=True, - ) + cmd = f"sudo chown -R {resolved_user} {rbundle}" + rc, out, err = _ssh_run(ssh, cmd, get_pty=True) if rc != 0: raise RuntimeError( "chown of harvest failed.\n" - f"Command: sudo {chown_cmd}\n" + f"Command: {cmd}\n" f"Exit code: {rc}\n" f"Stdout: {out.strip()}\n" f"Stderr: {err.strip()}" diff --git a/tests/test_cli.py b/tests/test_cli.py index 5fc9a66..4477b24 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,5 @@ import sys -import pytest - import enroll.cli as cli @@ -260,113 +258,6 @@ def test_cli_single_shot_remote_without_harvest_prints_state_path( assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls -def test_cli_harvest_remote_ask_become_pass_prompts_and_passes_password( - monkeypatch, tmp_path -): - from enroll.cache import HarvestCache - import enroll.remote as r - - cache_dir = tmp_path / "cache" - cache_dir.mkdir() - - called = {} - - def fake_cache_dir(*, hint=None): - return HarvestCache(dir=cache_dir) - - def fake__remote_harvest(*, sudo_password=None, **kwargs): - called["sudo_password"] = sudo_password - return cache_dir / "state.json" - - monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) - monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) - monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw123") - - monkeypatch.setattr( - sys, - "argv", - [ - "enroll", - "harvest", - "--remote-host", - "example.test", - "--ask-become-pass", - ], - ) - - cli.main() - assert called["sudo_password"] == "pw123" - - -def test_cli_harvest_remote_password_required_fallback_prompts_and_retries( - monkeypatch, tmp_path -): - from enroll.cache import HarvestCache - import enroll.remote as r - - cache_dir = tmp_path / "cache" - cache_dir.mkdir() - - def fake_cache_dir(*, hint=None): - return HarvestCache(dir=cache_dir) - - calls = [] - - def fake__remote_harvest(*, sudo_password=None, **kwargs): - calls.append(sudo_password) - if sudo_password is None: - raise r.RemoteSudoPasswordRequired("pw required") - return cache_dir / "state.json" - - class _TTYStdin: - def isatty(self): - return True - - monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) - monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) - monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw456") - monkeypatch.setattr(sys, "stdin", _TTYStdin()) - - monkeypatch.setattr( - sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"] - ) - - cli.main() - assert calls == [None, "pw456"] - - -def test_cli_harvest_remote_password_required_noninteractive_errors( - monkeypatch, tmp_path -): - from enroll.cache import HarvestCache - import enroll.remote as r - - cache_dir = tmp_path / "cache" - cache_dir.mkdir() - - def fake_cache_dir(*, hint=None): - return HarvestCache(dir=cache_dir) - - def fake__remote_harvest(*, sudo_password=None, **kwargs): - raise r.RemoteSudoPasswordRequired("pw required") - - class _NoTTYStdin: - def isatty(self): - return False - - monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir) - monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest) - monkeypatch.setattr(sys, "stdin", _NoTTYStdin()) - - monkeypatch.setattr( - sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"] - ) - - with pytest.raises(SystemExit) as e: - cli.main() - assert "--ask-become-pass" in str(e.value) - - def test_cli_manifest_common_args(monkeypatch, tmp_path): """Ensure --fqdn and jinjaturtle mode flags are forwarded correctly.""" diff --git a/tests/test_remote.py b/tests/test_remote.py index 6b4ab01..1f9c89b 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -69,53 +69,16 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): calls: list[tuple[str, bool]] = [] class _Chan: - def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): - self._out = out - self._err = err - self._out_i = 0 - self._err_i = 0 + def __init__(self, rc: int = 0): self._rc = rc - self._closed = False - - def recv_ready(self) -> bool: - return (not self._closed) and self._out_i < len(self._out) - - def recv(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._out[self._out_i : self._out_i + n] - self._out_i += len(chunk) - return chunk - - def recv_stderr_ready(self) -> bool: - return (not self._closed) and self._err_i < len(self._err) - - def recv_stderr(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._err[self._err_i : self._err_i + n] - self._err_i += len(chunk) - return chunk - - def exit_status_ready(self) -> bool: - return self._closed or ( - self._out_i >= len(self._out) and self._err_i >= len(self._err) - ) def recv_exit_status(self) -> int: return self._rc - def shutdown_write(self) -> None: - return - - def close(self) -> None: - self._closed = True - class _Stdout: - def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): + def __init__(self, payload: bytes = b"", rc: int = 0): self._bio = io.BytesIO(payload) - # _ssh_run reads stdout/stderr via the underlying channel. - self.channel = _Chan(out=payload, err=err, rc=rc) + self.channel = _Chan(rc) def read(self, n: int = -1) -> bytes: return self._bio.read(n) @@ -167,20 +130,10 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr()) if cmd.startswith("chmod 700"): return (None, _Stdout(b""), _Stderr()) - if cmd.startswith("sudo -n") and " harvest " in cmd: - if not get_pty: - msg = b"sudo: sorry, you must have a tty to run sudo\n" - return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg)) - return (None, _Stdout(b"", rc=0), _Stderr(b"")) - if cmd.startswith("sudo -S") and " harvest " in cmd: - return (None, _Stdout(b""), _Stderr()) if " harvest " in cmd: return (None, _Stdout(b""), _Stderr()) - if cmd.startswith("sudo -n") and " chown -R" in cmd: - if not get_pty: - msg = b"sudo: sorry, you must have a tty to run sudo\n" - return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg)) - return (None, _Stdout(b"", rc=0), _Stderr(b"")) + if cmd.startswith("sudo chown -R"): + return (None, _Stdout(b""), _Stderr()) if cmd.startswith("rm -rf"): return (None, _Stdout(b""), _Stderr()) @@ -201,7 +154,6 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): out_dir = tmp_path / "out" state_path = r.remote_harvest( - ask_become_pass=False, local_out_dir=out_dir, remote_host="example.com", remote_port=2222, @@ -223,21 +175,13 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch): assert "--include-path" in joined assert "--exclude-path" in joined - # Ensure we fall back to PTY only when sudo reports it is required. - assert any(c == "id -un" and pty is False for c, pty in calls) - - sudo_harvest = [ - (c, pty) for c, pty in calls if c.startswith("sudo -n") and " harvest " in c - ] - assert any(pty is False for _c, pty in sudo_harvest) - assert any(pty is True for _c, pty in sudo_harvest) - - sudo_chown = [ - (c, pty) for c, pty in calls if c.startswith("sudo -n") and " chown -R" in c - ] - assert any(pty is False for _c, pty in sudo_chown) - assert any(pty is True for _c, pty in sudo_chown) - + # Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar. + pty_by_cmd = {c: pty for c, pty in calls} + assert pty_by_cmd.get("id -un") is False + assert any( + c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls + ) + assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls) assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls) @@ -260,53 +204,16 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( calls: list[tuple[str, bool]] = [] class _Chan: - def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): - self._out = out - self._err = err - self._out_i = 0 - self._err_i = 0 + def __init__(self, rc: int = 0): self._rc = rc - self._closed = False - - def recv_ready(self) -> bool: - return (not self._closed) and self._out_i < len(self._out) - - def recv(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._out[self._out_i : self._out_i + n] - self._out_i += len(chunk) - return chunk - - def recv_stderr_ready(self) -> bool: - return (not self._closed) and self._err_i < len(self._err) - - def recv_stderr(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._err[self._err_i : self._err_i + n] - self._err_i += len(chunk) - return chunk - - def exit_status_ready(self) -> bool: - return self._closed or ( - self._out_i >= len(self._out) and self._err_i >= len(self._err) - ) def recv_exit_status(self) -> int: return self._rc - def shutdown_write(self) -> None: - return - - def close(self) -> None: - self._closed = True - class _Stdout: - def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): + def __init__(self, payload: bytes = b"", rc: int = 0): self._bio = io.BytesIO(payload) - # _ssh_run reads stdout/stderr via the underlying channel. - self.channel = _Chan(out=payload, err=err, rc=rc) + self.channel = _Chan(rc) def read(self, n: int = -1) -> bytes: return self._bio.read(n) @@ -371,7 +278,6 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( out_dir = tmp_path / "out" r.remote_harvest( - ask_become_pass=False, local_out_dir=out_dir, remote_host="example.com", remote_user="alice", @@ -382,186 +288,3 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown( assert "sudo" not in joined assert "sudo chown" not in joined assert any(" harvest " in c and pty is False for c, pty in calls) - - -def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password( - tmp_path: Path, monkeypatch -): - """If sudo requires a password, we should fall back from -n to -S and feed stdin.""" - import sys - import types - - import enroll.remote as r - - # Avoid building a real zipapp; just create a file. - monkeypatch.setattr( - r, - "_build_enroll_pyz", - lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ") - or (Path(td) / "enroll.pyz"), - ) - - tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'}) - calls: list[tuple[str, bool]] = [] - stdin_by_cmd: dict[str, list[str]] = {} - - class _Chan: - def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): - self._out = out - self._err = err - self._out_i = 0 - self._err_i = 0 - self._rc = rc - self._closed = False - - def recv_ready(self) -> bool: - return (not self._closed) and self._out_i < len(self._out) - - def recv(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._out[self._out_i : self._out_i + n] - self._out_i += len(chunk) - return chunk - - def recv_stderr_ready(self) -> bool: - return (not self._closed) and self._err_i < len(self._err) - - def recv_stderr(self, n: int) -> bytes: - if self._closed: - return b"" - chunk = self._err[self._err_i : self._err_i + n] - self._err_i += len(chunk) - return chunk - - def exit_status_ready(self) -> bool: - return self._closed or ( - self._out_i >= len(self._out) and self._err_i >= len(self._err) - ) - - def recv_exit_status(self) -> int: - return self._rc - - def shutdown_write(self) -> None: - return - - def close(self) -> None: - self._closed = True - - class _Stdout: - def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): - self._bio = io.BytesIO(payload) - # _ssh_run reads stdout/stderr via the underlying channel. - self.channel = _Chan(out=payload, err=err, rc=rc) - - def read(self, n: int = -1) -> bytes: - return self._bio.read(n) - - class _Stderr: - def __init__(self, payload: bytes = b""): - self._bio = io.BytesIO(payload) - - def read(self, n: int = -1) -> bytes: - return self._bio.read(n) - - class _Stdin: - def __init__(self, cmd: str): - self._cmd = cmd - stdin_by_cmd.setdefault(cmd, []) - - def write(self, s: str) -> None: - stdin_by_cmd[self._cmd].append(s) - - def flush(self) -> None: - return - - class _SFTP: - def put(self, _local: str, _remote: str) -> None: - return - - def close(self) -> None: - return - - class FakeSSH: - def __init__(self): - self._sftp = _SFTP() - - def load_system_host_keys(self): - return - - def set_missing_host_key_policy(self, _policy): - return - - def connect(self, **_kwargs): - return - - def open_sftp(self): - return self._sftp - - def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs): - calls.append((cmd, bool(get_pty))) - - # Tar stream - if cmd.startswith("tar -cz -C"): - return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b"")) - - if cmd == "mktemp -d": - return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr()) - if cmd.startswith("chmod 700"): - return (_Stdin(cmd), _Stdout(b""), _Stderr()) - - # First attempt: sudo -n fails, prompting is not allowed. - if cmd.startswith("sudo -n") and " harvest " in cmd: - return ( - _Stdin(cmd), - _Stdout(b"", rc=1, err=b"sudo: a password is required\n"), - _Stderr(b"sudo: a password is required\n"), - ) - - # Retry: sudo -S succeeds and should have been fed the password via stdin. - if cmd.startswith("sudo -S") and " harvest " in cmd: - return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) - - # chown succeeds passwordlessly (e.g., sudo timestamp is warm). - if cmd.startswith("sudo -n") and " chown -R" in cmd: - return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) - - if cmd.startswith("rm -rf"): - return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) - - # Fallback for unexpected commands. - return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b"")) - - def close(self): - return - - class RejectPolicy: - pass - - monkeypatch.setitem( - sys.modules, - "paramiko", - types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy), - ) - - out_dir = tmp_path / "out" - state_path = r.remote_harvest( - ask_become_pass=True, - getpass_fn=lambda _prompt="": "s3cr3t", - local_out_dir=out_dir, - remote_host="example.com", - remote_user="alice", - no_sudo=False, - ) - - assert state_path.exists() - assert b"ok" in state_path.read_bytes() - - # Ensure we attempted with sudo -n first, then sudo -S. - sudo_n = [c for c, _pty in calls if c.startswith("sudo -n") and " harvest " in c] - sudo_s = [c for c, _pty in calls if c.startswith("sudo -S") and " harvest " in c] - assert len(sudo_n) == 1 - assert len(sudo_s) == 1 - - # Ensure the password was written to stdin for the -S invocation. - assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]