Compare commits

..

No commits in common. "a2be708a315f39fdcd07f8251b5347072e4b614b" and "fd55bcde9b4bb561cfdbc39403c62742bc9beec0" have entirely different histories.

5 changed files with 31 additions and 683 deletions

View file

@ -1,5 +0,0 @@
## Contributors
mig5 would like to thank the following people for their contributions to Enroll.
* [slhck](https://slhck.info/)

View file

@ -13,7 +13,7 @@ from .cache import new_harvest_cache_dir
from .diff import compare_harvests, format_report, post_webhook, send_email from .diff import compare_harvests, format_report, post_webhook, send_email
from .harvest import harvest from .harvest import harvest
from .manifest import manifest from .manifest import manifest
from .remote import remote_harvest, RemoteSudoPasswordRequired from .remote import remote_harvest
from .sopsutil import SopsError, encrypt_file_binary from .sopsutil import SopsError, encrypt_file_binary
from .version import get_enroll_version from .version import get_enroll_version
@ -352,17 +352,6 @@ def _add_remote_args(p: argparse.ArgumentParser) -> None:
help="SSH username for --remote-host (default: local $USER).", help="SSH username for --remote-host (default: local $USER).",
) )
# Align terminology with Ansible: "become" == sudo.
p.add_argument(
"--ask-become-pass",
"-K",
action="store_true",
help=(
"Prompt for the remote sudo (become) password when using --remote-host "
"(similar to ansible --ask-become-pass)."
),
)
def main() -> None: def main() -> None:
ap = argparse.ArgumentParser(prog="enroll") ap = argparse.ArgumentParser(prog="enroll")
@ -634,7 +623,6 @@ def main() -> None:
except OSError: except OSError:
pass pass
remote_harvest( remote_harvest(
ask_become_pass=args.ask_become_pass,
local_out_dir=tmp_bundle, local_out_dir=tmp_bundle,
remote_host=args.remote_host, remote_host=args.remote_host,
remote_port=int(args.remote_port), remote_port=int(args.remote_port),
@ -655,7 +643,6 @@ def main() -> None:
else new_harvest_cache_dir(hint=args.remote_host).dir else new_harvest_cache_dir(hint=args.remote_host).dir
) )
state = remote_harvest( state = remote_harvest(
ask_become_pass=args.ask_become_pass,
local_out_dir=out_dir, local_out_dir=out_dir,
remote_host=args.remote_host, remote_host=args.remote_host,
remote_port=int(args.remote_port), remote_port=int(args.remote_port),
@ -782,7 +769,6 @@ def main() -> None:
except OSError: except OSError:
pass pass
remote_harvest( remote_harvest(
ask_become_pass=args.ask_become_pass,
local_out_dir=tmp_bundle, local_out_dir=tmp_bundle,
remote_host=args.remote_host, remote_host=args.remote_host,
remote_port=int(args.remote_port), remote_port=int(args.remote_port),
@ -812,7 +798,6 @@ def main() -> None:
else new_harvest_cache_dir(hint=args.remote_host).dir else new_harvest_cache_dir(hint=args.remote_host).dir
) )
remote_harvest( remote_harvest(
ask_become_pass=args.ask_become_pass,
local_out_dir=harvest_dir, local_out_dir=harvest_dir,
remote_host=args.remote_host, remote_host=args.remote_host,
remote_port=int(args.remote_port), remote_port=int(args.remote_port),
@ -927,11 +912,5 @@ def main() -> None:
if getattr(args, "exit_code", False) and has_changes: if getattr(args, "exit_code", False) and has_changes:
raise SystemExit(2) raise SystemExit(2)
except RemoteSudoPasswordRequired:
raise SystemExit(
"error: remote sudo requires a password. Re-run with --ask-become-pass."
) from None
except RuntimeError as e:
raise SystemExit(f"error: {e}") from None
except SopsError as e: except SopsError as e:
raise SystemExit(f"error: {e}") from None raise SystemExit(f"error: {e}")

View file

@ -1,117 +1,14 @@
from __future__ import annotations from __future__ import annotations
import getpass
import os import os
import shlex import shlex
import shutil import shutil
import sys
import time
import tarfile import tarfile
import tempfile import tempfile
import zipapp import zipapp
from pathlib import Path from pathlib import Path
from pathlib import PurePosixPath from pathlib import PurePosixPath
from typing import Optional, Callable, TextIO from typing import Optional
class RemoteSudoPasswordRequired(RuntimeError):
"""Raised when sudo requires a password but none was provided."""
def _sudo_password_required(out: str, err: str) -> bool:
"""Return True if sudo output indicates it needs a password/TTY."""
blob = (out + "\n" + err).lower()
patterns = (
"a password is required",
"password is required",
"a terminal is required to read the password",
"no tty present and no askpass program specified",
"must have a tty to run sudo",
"sudo: sorry, you must have a tty",
"askpass",
)
return any(p in blob for p in patterns)
def _sudo_not_permitted(out: str, err: str) -> bool:
"""Return True if sudo output indicates the user cannot sudo at all."""
blob = (out + "\n" + err).lower()
patterns = (
"is not in the sudoers file",
"not allowed to execute",
"may not run sudo",
"sorry, user",
)
return any(p in blob for p in patterns)
def _sudo_tty_required(out: str, err: str) -> bool:
"""Return True if sudo output indicates it requires a TTY (sudoers requiretty)."""
blob = (out + "\n" + err).lower()
patterns = (
"must have a tty",
"sorry, you must have a tty",
"sudo: sorry, you must have a tty",
"must have a tty to run sudo",
)
return any(p in blob for p in patterns)
def _resolve_become_password(
ask_become_pass: bool,
*,
prompt: str = "sudo password: ",
getpass_fn: Callable[[str], str] = getpass.getpass,
) -> Optional[str]:
if ask_become_pass:
return getpass_fn(prompt)
return None
def remote_harvest(
*,
ask_become_pass: bool = False,
no_sudo: bool = False,
prompt: str = "sudo password: ",
getpass_fn: Optional[Callable[[str], str]] = None,
stdin: Optional[TextIO] = None,
**kwargs,
):
"""Call _remote_harvest, with a safe sudo password fallback.
Behavior:
- Run without a password unless --ask-become-pass is set.
- If the remote sudo policy requires a password and none was provided,
prompt and retry when running interactively.
"""
# Resolve defaults at call time (easier to test/monkeypatch, and avoids capturing
# sys.stdin / getpass.getpass at import time).
if getpass_fn is None:
getpass_fn = getpass.getpass
if stdin is None:
stdin = sys.stdin
sudo_password = _resolve_become_password(
ask_become_pass and not no_sudo,
prompt=prompt,
getpass_fn=getpass_fn,
)
try:
return _remote_harvest(sudo_password=sudo_password, no_sudo=no_sudo, **kwargs)
except RemoteSudoPasswordRequired:
if sudo_password is not None:
raise
# Fallback prompt if interactive
if stdin is not None and getattr(stdin, "isatty", lambda: False)():
pw = getpass_fn(prompt)
return _remote_harvest(sudo_password=pw, no_sudo=no_sudo, **kwargs)
raise RemoteSudoPasswordRequired(
"Remote sudo requires a password. Re-run with --ask-become-pass."
)
def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None: def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None:
@ -182,14 +79,7 @@ def _build_enroll_pyz(tmpdir: Path) -> Path:
return pyz_path return pyz_path
def _ssh_run( def _ssh_run(ssh, cmd: str, *, get_pty: bool = False) -> tuple[int, str, str]:
ssh,
cmd: str,
*,
get_pty: bool = False,
stdin_text: Optional[str] = None,
close_stdin: bool = False,
) -> tuple[int, str, str]:
"""Run a command over a Paramiko SSHClient. """Run a command over a Paramiko SSHClient.
Paramiko's exec_command runs commands without a TTY by default. Paramiko's exec_command runs commands without a TTY by default.
@ -200,133 +90,14 @@ def _ssh_run(
We do not request a PTY for commands that stream binary data We do not request a PTY for commands that stream binary data
(e.g. tar/gzip output), as a PTY can corrupt the byte stream. (e.g. tar/gzip output), as a PTY can corrupt the byte stream.
""" """
stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty) _stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty)
# All three file-like objects share the same underlying Channel. out = stdout.read().decode("utf-8", errors="replace")
chan = stdout.channel err = stderr.read().decode("utf-8", errors="replace")
rc = stdout.channel.recv_exit_status()
if stdin_text is not None and stdin is not None:
try:
stdin.write(stdin_text)
stdin.flush()
except Exception:
# If the remote side closed stdin early, ignore.
pass # nosec
finally:
if close_stdin:
# For sudo -S, a wrong password causes sudo to re-prompt and wait
# forever for more input. We try hard to deliver EOF so sudo can
# fail fast.
try:
chan.shutdown_write() # sends EOF to the remote process
except Exception:
pass # nosec
try:
stdin.close()
except Exception:
pass # nosec
# Read incrementally to avoid blocking forever on stdout.read()/stderr.read()
# if the remote process is waiting for more input (e.g. sudo password retry).
out_chunks: list[bytes] = []
err_chunks: list[bytes] = []
# Keep a small tail of stderr to detect sudo retry messages without
# repeatedly joining potentially large buffers.
err_tail = b""
while True:
progressed = False
if chan.recv_ready():
out_chunks.append(chan.recv(1024 * 64))
progressed = True
if chan.recv_stderr_ready():
chunk = chan.recv_stderr(1024 * 64)
err_chunks.append(chunk)
err_tail = (err_tail + chunk)[-4096:]
progressed = True
# If we just attempted sudo -S with a single password line and sudo is
# asking again, detect it and stop waiting.
if close_stdin and stdin_text is not None:
blob = err_tail.lower()
if b"sorry, try again" in blob or b"incorrect password" in blob:
try:
chan.close()
except Exception:
pass # nosec
break
# Exit once the process has exited and we have drained the buffers.
if (
chan.exit_status_ready()
and not chan.recv_ready()
and not chan.recv_stderr_ready()
):
break
if not progressed:
time.sleep(0.05)
out = b"".join(out_chunks).decode("utf-8", errors="replace")
err = b"".join(err_chunks).decode("utf-8", errors="replace")
rc = chan.recv_exit_status() if chan.exit_status_ready() else 1
return rc, out, err return rc, out, err
def _ssh_run_sudo( def remote_harvest(
ssh,
cmd: str,
*,
sudo_password: Optional[str] = None,
get_pty: bool = True,
) -> tuple[int, str, str]:
"""Run cmd via sudo with a safe non-interactive-first strategy.
Strategy:
1) Try `sudo -n`.
2) If sudo reports a password is required and we have one, retry with
`sudo -S` and feed it via stdin.
3) If sudo reports a password is required and we *don't* have one, raise
RemoteSudoPasswordRequired.
We avoid requesting a PTY unless the remote sudo policy requires it.
This makes sudo -S behavior more reliable (wrong passwords fail fast
instead of blocking on a PTY).
"""
cmd_n = f"sudo -n -p '' -- {cmd}"
# First try: never prompt, and prefer no PTY.
rc, out, err = _ssh_run(ssh, cmd_n, get_pty=False)
need_pty = False
# Some sudoers configurations require a TTY even for passwordless sudo.
if get_pty and rc != 0 and _sudo_tty_required(out, err):
need_pty = True
rc, out, err = _ssh_run(ssh, cmd_n, get_pty=True)
if rc == 0:
return rc, out, err
if _sudo_not_permitted(out, err):
return rc, out, err
if _sudo_password_required(out, err):
if sudo_password is None:
raise RemoteSudoPasswordRequired(
"Remote sudo requires a password, but none was provided."
)
cmd_s = f"sudo -S -p '' -- {cmd}"
return _ssh_run(
ssh,
cmd_s,
get_pty=need_pty,
stdin_text=str(sudo_password) + "\n",
close_stdin=True,
)
return rc, out, err
def _remote_harvest(
*, *,
local_out_dir: Path, local_out_dir: Path,
remote_host: str, remote_host: str,
@ -335,7 +106,6 @@ def _remote_harvest(
remote_python: str = "python3", remote_python: str = "python3",
dangerous: bool = False, dangerous: bool = False,
no_sudo: bool = False, no_sudo: bool = False,
sudo_password: Optional[str] = None,
include_paths: Optional[list[str]] = None, include_paths: Optional[list[str]] = None,
exclude_paths: Optional[list[str]] = None, exclude_paths: Optional[list[str]] = None,
) -> Path: ) -> Path:
@ -420,15 +190,10 @@ def _remote_harvest(
argv.extend(["--exclude-path", str(p)]) argv.extend(["--exclude-path", str(p)])
_cmd = " ".join(map(shlex.quote, argv)) _cmd = " ".join(map(shlex.quote, argv))
if not no_sudo: cmd = f"sudo {_cmd}" if not no_sudo else _cmd
# Prefer non-interactive sudo first; retry with -S only when needed.
rc, out, err = _ssh_run_sudo( # PTY for sudo commands (helps sudoers requiretty).
ssh, _cmd, sudo_password=sudo_password, get_pty=True rc, out, err = _ssh_run(ssh, cmd, get_pty=(not no_sudo))
)
cmd = f"sudo {_cmd}"
else:
cmd = _cmd
rc, out, err = _ssh_run(ssh, cmd, get_pty=False)
if rc != 0: if rc != 0:
raise RuntimeError( raise RuntimeError(
"Remote harvest failed.\n" "Remote harvest failed.\n"
@ -445,17 +210,12 @@ def _remote_harvest(
"Unable to determine remote username for chown. " "Unable to determine remote username for chown. "
"Pass --remote-user explicitly or use --no-sudo." "Pass --remote-user explicitly or use --no-sudo."
) )
chown_cmd = f"chown -R {resolved_user} {rbundle}" cmd = f"sudo chown -R {resolved_user} {rbundle}"
rc, out, err = _ssh_run_sudo( rc, out, err = _ssh_run(ssh, cmd, get_pty=True)
ssh,
chown_cmd,
sudo_password=sudo_password,
get_pty=True,
)
if rc != 0: if rc != 0:
raise RuntimeError( raise RuntimeError(
"chown of harvest failed.\n" "chown of harvest failed.\n"
f"Command: sudo {chown_cmd}\n" f"Command: {cmd}\n"
f"Exit code: {rc}\n" f"Exit code: {rc}\n"
f"Stdout: {out.strip()}\n" f"Stdout: {out.strip()}\n"
f"Stderr: {err.strip()}" f"Stderr: {err.strip()}"

View file

@ -1,7 +1,5 @@
import sys import sys
import pytest
import enroll.cli as cli import enroll.cli as cli
@ -260,113 +258,6 @@ def test_cli_single_shot_remote_without_harvest_prints_state_path(
assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls
def test_cli_harvest_remote_ask_become_pass_prompts_and_passes_password(
monkeypatch, tmp_path
):
from enroll.cache import HarvestCache
import enroll.remote as r
cache_dir = tmp_path / "cache"
cache_dir.mkdir()
called = {}
def fake_cache_dir(*, hint=None):
return HarvestCache(dir=cache_dir)
def fake__remote_harvest(*, sudo_password=None, **kwargs):
called["sudo_password"] = sudo_password
return cache_dir / "state.json"
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw123")
monkeypatch.setattr(
sys,
"argv",
[
"enroll",
"harvest",
"--remote-host",
"example.test",
"--ask-become-pass",
],
)
cli.main()
assert called["sudo_password"] == "pw123"
def test_cli_harvest_remote_password_required_fallback_prompts_and_retries(
monkeypatch, tmp_path
):
from enroll.cache import HarvestCache
import enroll.remote as r
cache_dir = tmp_path / "cache"
cache_dir.mkdir()
def fake_cache_dir(*, hint=None):
return HarvestCache(dir=cache_dir)
calls = []
def fake__remote_harvest(*, sudo_password=None, **kwargs):
calls.append(sudo_password)
if sudo_password is None:
raise r.RemoteSudoPasswordRequired("pw required")
return cache_dir / "state.json"
class _TTYStdin:
def isatty(self):
return True
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw456")
monkeypatch.setattr(sys, "stdin", _TTYStdin())
monkeypatch.setattr(
sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"]
)
cli.main()
assert calls == [None, "pw456"]
def test_cli_harvest_remote_password_required_noninteractive_errors(
monkeypatch, tmp_path
):
from enroll.cache import HarvestCache
import enroll.remote as r
cache_dir = tmp_path / "cache"
cache_dir.mkdir()
def fake_cache_dir(*, hint=None):
return HarvestCache(dir=cache_dir)
def fake__remote_harvest(*, sudo_password=None, **kwargs):
raise r.RemoteSudoPasswordRequired("pw required")
class _NoTTYStdin:
def isatty(self):
return False
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
monkeypatch.setattr(sys, "stdin", _NoTTYStdin())
monkeypatch.setattr(
sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"]
)
with pytest.raises(SystemExit) as e:
cli.main()
assert "--ask-become-pass" in str(e.value)
def test_cli_manifest_common_args(monkeypatch, tmp_path): def test_cli_manifest_common_args(monkeypatch, tmp_path):
"""Ensure --fqdn and jinjaturtle mode flags are forwarded correctly.""" """Ensure --fqdn and jinjaturtle mode flags are forwarded correctly."""

View file

@ -69,53 +69,16 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
calls: list[tuple[str, bool]] = [] calls: list[tuple[str, bool]] = []
class _Chan: class _Chan:
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): def __init__(self, rc: int = 0):
self._out = out
self._err = err
self._out_i = 0
self._err_i = 0
self._rc = rc self._rc = rc
self._closed = False
def recv_ready(self) -> bool:
return (not self._closed) and self._out_i < len(self._out)
def recv(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._out[self._out_i : self._out_i + n]
self._out_i += len(chunk)
return chunk
def recv_stderr_ready(self) -> bool:
return (not self._closed) and self._err_i < len(self._err)
def recv_stderr(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._err[self._err_i : self._err_i + n]
self._err_i += len(chunk)
return chunk
def exit_status_ready(self) -> bool:
return self._closed or (
self._out_i >= len(self._out) and self._err_i >= len(self._err)
)
def recv_exit_status(self) -> int: def recv_exit_status(self) -> int:
return self._rc return self._rc
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout: class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): def __init__(self, payload: bytes = b"", rc: int = 0):
self._bio = io.BytesIO(payload) self._bio = io.BytesIO(payload)
# _ssh_run reads stdout/stderr via the underlying channel. self.channel = _Chan(rc)
self.channel = _Chan(out=payload, err=err, rc=rc)
def read(self, n: int = -1) -> bytes: def read(self, n: int = -1) -> bytes:
return self._bio.read(n) return self._bio.read(n)
@ -167,20 +130,10 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr()) return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr())
if cmd.startswith("chmod 700"): if cmd.startswith("chmod 700"):
return (None, _Stdout(b""), _Stderr()) return (None, _Stdout(b""), _Stderr())
if cmd.startswith("sudo -n") and " harvest " in cmd:
if not get_pty:
msg = b"sudo: sorry, you must have a tty to run sudo\n"
return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg))
return (None, _Stdout(b"", rc=0), _Stderr(b""))
if cmd.startswith("sudo -S") and " harvest " in cmd:
return (None, _Stdout(b""), _Stderr())
if " harvest " in cmd: if " harvest " in cmd:
return (None, _Stdout(b""), _Stderr()) return (None, _Stdout(b""), _Stderr())
if cmd.startswith("sudo -n") and " chown -R" in cmd: if cmd.startswith("sudo chown -R"):
if not get_pty: return (None, _Stdout(b""), _Stderr())
msg = b"sudo: sorry, you must have a tty to run sudo\n"
return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg))
return (None, _Stdout(b"", rc=0), _Stderr(b""))
if cmd.startswith("rm -rf"): if cmd.startswith("rm -rf"):
return (None, _Stdout(b""), _Stderr()) return (None, _Stdout(b""), _Stderr())
@ -201,7 +154,6 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
out_dir = tmp_path / "out" out_dir = tmp_path / "out"
state_path = r.remote_harvest( state_path = r.remote_harvest(
ask_become_pass=False,
local_out_dir=out_dir, local_out_dir=out_dir,
remote_host="example.com", remote_host="example.com",
remote_port=2222, remote_port=2222,
@ -223,21 +175,13 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
assert "--include-path" in joined assert "--include-path" in joined
assert "--exclude-path" in joined assert "--exclude-path" in joined
# Ensure we fall back to PTY only when sudo reports it is required. # Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar.
assert any(c == "id -un" and pty is False for c, pty in calls) pty_by_cmd = {c: pty for c, pty in calls}
assert pty_by_cmd.get("id -un") is False
sudo_harvest = [ assert any(
(c, pty) for c, pty in calls if c.startswith("sudo -n") and " harvest " in c c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls
] )
assert any(pty is False for _c, pty in sudo_harvest) assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls)
assert any(pty is True for _c, pty in sudo_harvest)
sudo_chown = [
(c, pty) for c, pty in calls if c.startswith("sudo -n") and " chown -R" in c
]
assert any(pty is False for _c, pty in sudo_chown)
assert any(pty is True for _c, pty in sudo_chown)
assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls) assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls)
@ -260,53 +204,16 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
calls: list[tuple[str, bool]] = [] calls: list[tuple[str, bool]] = []
class _Chan: class _Chan:
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0): def __init__(self, rc: int = 0):
self._out = out
self._err = err
self._out_i = 0
self._err_i = 0
self._rc = rc self._rc = rc
self._closed = False
def recv_ready(self) -> bool:
return (not self._closed) and self._out_i < len(self._out)
def recv(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._out[self._out_i : self._out_i + n]
self._out_i += len(chunk)
return chunk
def recv_stderr_ready(self) -> bool:
return (not self._closed) and self._err_i < len(self._err)
def recv_stderr(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._err[self._err_i : self._err_i + n]
self._err_i += len(chunk)
return chunk
def exit_status_ready(self) -> bool:
return self._closed or (
self._out_i >= len(self._out) and self._err_i >= len(self._err)
)
def recv_exit_status(self) -> int: def recv_exit_status(self) -> int:
return self._rc return self._rc
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout: class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""): def __init__(self, payload: bytes = b"", rc: int = 0):
self._bio = io.BytesIO(payload) self._bio = io.BytesIO(payload)
# _ssh_run reads stdout/stderr via the underlying channel. self.channel = _Chan(rc)
self.channel = _Chan(out=payload, err=err, rc=rc)
def read(self, n: int = -1) -> bytes: def read(self, n: int = -1) -> bytes:
return self._bio.read(n) return self._bio.read(n)
@ -371,7 +278,6 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
out_dir = tmp_path / "out" out_dir = tmp_path / "out"
r.remote_harvest( r.remote_harvest(
ask_become_pass=False,
local_out_dir=out_dir, local_out_dir=out_dir,
remote_host="example.com", remote_host="example.com",
remote_user="alice", remote_user="alice",
@ -382,186 +288,3 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
assert "sudo" not in joined assert "sudo" not in joined
assert "sudo chown" not in joined assert "sudo chown" not in joined
assert any(" harvest " in c and pty is False for c, pty in calls) assert any(" harvest " in c and pty is False for c, pty in calls)
def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
tmp_path: Path, monkeypatch
):
"""If sudo requires a password, we should fall back from -n to -S and feed stdin."""
import sys
import types
import enroll.remote as r
# Avoid building a real zipapp; just create a file.
monkeypatch.setattr(
r,
"_build_enroll_pyz",
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
or (Path(td) / "enroll.pyz"),
)
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
calls: list[tuple[str, bool]] = []
stdin_by_cmd: dict[str, list[str]] = {}
class _Chan:
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
self._out = out
self._err = err
self._out_i = 0
self._err_i = 0
self._rc = rc
self._closed = False
def recv_ready(self) -> bool:
return (not self._closed) and self._out_i < len(self._out)
def recv(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._out[self._out_i : self._out_i + n]
self._out_i += len(chunk)
return chunk
def recv_stderr_ready(self) -> bool:
return (not self._closed) and self._err_i < len(self._err)
def recv_stderr(self, n: int) -> bytes:
if self._closed:
return b""
chunk = self._err[self._err_i : self._err_i + n]
self._err_i += len(chunk)
return chunk
def exit_status_ready(self) -> bool:
return self._closed or (
self._out_i >= len(self._out) and self._err_i >= len(self._err)
)
def recv_exit_status(self) -> int:
return self._rc
def shutdown_write(self) -> None:
return
def close(self) -> None:
self._closed = True
class _Stdout:
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
self._bio = io.BytesIO(payload)
# _ssh_run reads stdout/stderr via the underlying channel.
self.channel = _Chan(out=payload, err=err, rc=rc)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stderr:
def __init__(self, payload: bytes = b""):
self._bio = io.BytesIO(payload)
def read(self, n: int = -1) -> bytes:
return self._bio.read(n)
class _Stdin:
def __init__(self, cmd: str):
self._cmd = cmd
stdin_by_cmd.setdefault(cmd, [])
def write(self, s: str) -> None:
stdin_by_cmd[self._cmd].append(s)
def flush(self) -> None:
return
class _SFTP:
def put(self, _local: str, _remote: str) -> None:
return
def close(self) -> None:
return
class FakeSSH:
def __init__(self):
self._sftp = _SFTP()
def load_system_host_keys(self):
return
def set_missing_host_key_policy(self, _policy):
return
def connect(self, **_kwargs):
return
def open_sftp(self):
return self._sftp
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
calls.append((cmd, bool(get_pty)))
# Tar stream
if cmd.startswith("tar -cz -C"):
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
if cmd == "mktemp -d":
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
if cmd.startswith("chmod 700"):
return (_Stdin(cmd), _Stdout(b""), _Stderr())
# First attempt: sudo -n fails, prompting is not allowed.
if cmd.startswith("sudo -n") and " harvest " in cmd:
return (
_Stdin(cmd),
_Stdout(b"", rc=1, err=b"sudo: a password is required\n"),
_Stderr(b"sudo: a password is required\n"),
)
# Retry: sudo -S succeeds and should have been fed the password via stdin.
if cmd.startswith("sudo -S") and " harvest " in cmd:
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
# chown succeeds passwordlessly (e.g., sudo timestamp is warm).
if cmd.startswith("sudo -n") and " chown -R" in cmd:
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
if cmd.startswith("rm -rf"):
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
# Fallback for unexpected commands.
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
def close(self):
return
class RejectPolicy:
pass
monkeypatch.setitem(
sys.modules,
"paramiko",
types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy),
)
out_dir = tmp_path / "out"
state_path = r.remote_harvest(
ask_become_pass=True,
getpass_fn=lambda _prompt="": "s3cr3t",
local_out_dir=out_dir,
remote_host="example.com",
remote_user="alice",
no_sudo=False,
)
assert state_path.exists()
assert b"ok" in state_path.read_bytes()
# Ensure we attempted with sudo -n first, then sudo -S.
sudo_n = [c for c, _pty in calls if c.startswith("sudo -n") and " harvest " in c]
sudo_s = [c for c, _pty in calls if c.startswith("sudo -S") and " harvest " in c]
assert len(sudo_n) == 1
assert len(sudo_s) == 1
# Ensure the password was written to stdin for the -S invocation.
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]