Compare commits
No commits in common. "a2be708a315f39fdcd07f8251b5347072e4b614b" and "fd55bcde9b4bb561cfdbc39403c62742bc9beec0" have entirely different histories.
a2be708a31
...
fd55bcde9b
5 changed files with 31 additions and 683 deletions
|
|
@ -1,5 +0,0 @@
|
||||||
## Contributors
|
|
||||||
|
|
||||||
mig5 would like to thank the following people for their contributions to Enroll.
|
|
||||||
|
|
||||||
* [slhck](https://slhck.info/)
|
|
||||||
|
|
@ -13,7 +13,7 @@ from .cache import new_harvest_cache_dir
|
||||||
from .diff import compare_harvests, format_report, post_webhook, send_email
|
from .diff import compare_harvests, format_report, post_webhook, send_email
|
||||||
from .harvest import harvest
|
from .harvest import harvest
|
||||||
from .manifest import manifest
|
from .manifest import manifest
|
||||||
from .remote import remote_harvest, RemoteSudoPasswordRequired
|
from .remote import remote_harvest
|
||||||
from .sopsutil import SopsError, encrypt_file_binary
|
from .sopsutil import SopsError, encrypt_file_binary
|
||||||
from .version import get_enroll_version
|
from .version import get_enroll_version
|
||||||
|
|
||||||
|
|
@ -352,17 +352,6 @@ def _add_remote_args(p: argparse.ArgumentParser) -> None:
|
||||||
help="SSH username for --remote-host (default: local $USER).",
|
help="SSH username for --remote-host (default: local $USER).",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Align terminology with Ansible: "become" == sudo.
|
|
||||||
p.add_argument(
|
|
||||||
"--ask-become-pass",
|
|
||||||
"-K",
|
|
||||||
action="store_true",
|
|
||||||
help=(
|
|
||||||
"Prompt for the remote sudo (become) password when using --remote-host "
|
|
||||||
"(similar to ansible --ask-become-pass)."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
ap = argparse.ArgumentParser(prog="enroll")
|
ap = argparse.ArgumentParser(prog="enroll")
|
||||||
|
|
@ -634,7 +623,6 @@ def main() -> None:
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
remote_harvest(
|
remote_harvest(
|
||||||
ask_become_pass=args.ask_become_pass,
|
|
||||||
local_out_dir=tmp_bundle,
|
local_out_dir=tmp_bundle,
|
||||||
remote_host=args.remote_host,
|
remote_host=args.remote_host,
|
||||||
remote_port=int(args.remote_port),
|
remote_port=int(args.remote_port),
|
||||||
|
|
@ -655,7 +643,6 @@ def main() -> None:
|
||||||
else new_harvest_cache_dir(hint=args.remote_host).dir
|
else new_harvest_cache_dir(hint=args.remote_host).dir
|
||||||
)
|
)
|
||||||
state = remote_harvest(
|
state = remote_harvest(
|
||||||
ask_become_pass=args.ask_become_pass,
|
|
||||||
local_out_dir=out_dir,
|
local_out_dir=out_dir,
|
||||||
remote_host=args.remote_host,
|
remote_host=args.remote_host,
|
||||||
remote_port=int(args.remote_port),
|
remote_port=int(args.remote_port),
|
||||||
|
|
@ -782,7 +769,6 @@ def main() -> None:
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
remote_harvest(
|
remote_harvest(
|
||||||
ask_become_pass=args.ask_become_pass,
|
|
||||||
local_out_dir=tmp_bundle,
|
local_out_dir=tmp_bundle,
|
||||||
remote_host=args.remote_host,
|
remote_host=args.remote_host,
|
||||||
remote_port=int(args.remote_port),
|
remote_port=int(args.remote_port),
|
||||||
|
|
@ -812,7 +798,6 @@ def main() -> None:
|
||||||
else new_harvest_cache_dir(hint=args.remote_host).dir
|
else new_harvest_cache_dir(hint=args.remote_host).dir
|
||||||
)
|
)
|
||||||
remote_harvest(
|
remote_harvest(
|
||||||
ask_become_pass=args.ask_become_pass,
|
|
||||||
local_out_dir=harvest_dir,
|
local_out_dir=harvest_dir,
|
||||||
remote_host=args.remote_host,
|
remote_host=args.remote_host,
|
||||||
remote_port=int(args.remote_port),
|
remote_port=int(args.remote_port),
|
||||||
|
|
@ -927,11 +912,5 @@ def main() -> None:
|
||||||
|
|
||||||
if getattr(args, "exit_code", False) and has_changes:
|
if getattr(args, "exit_code", False) and has_changes:
|
||||||
raise SystemExit(2)
|
raise SystemExit(2)
|
||||||
except RemoteSudoPasswordRequired:
|
|
||||||
raise SystemExit(
|
|
||||||
"error: remote sudo requires a password. Re-run with --ask-become-pass."
|
|
||||||
) from None
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise SystemExit(f"error: {e}") from None
|
|
||||||
except SopsError as e:
|
except SopsError as e:
|
||||||
raise SystemExit(f"error: {e}") from None
|
raise SystemExit(f"error: {e}")
|
||||||
|
|
|
||||||
268
enroll/remote.py
268
enroll/remote.py
|
|
@ -1,117 +1,14 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import getpass
|
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import tarfile
|
import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipapp
|
import zipapp
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pathlib import PurePosixPath
|
from pathlib import PurePosixPath
|
||||||
from typing import Optional, Callable, TextIO
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class RemoteSudoPasswordRequired(RuntimeError):
|
|
||||||
"""Raised when sudo requires a password but none was provided."""
|
|
||||||
|
|
||||||
|
|
||||||
def _sudo_password_required(out: str, err: str) -> bool:
|
|
||||||
"""Return True if sudo output indicates it needs a password/TTY."""
|
|
||||||
blob = (out + "\n" + err).lower()
|
|
||||||
patterns = (
|
|
||||||
"a password is required",
|
|
||||||
"password is required",
|
|
||||||
"a terminal is required to read the password",
|
|
||||||
"no tty present and no askpass program specified",
|
|
||||||
"must have a tty to run sudo",
|
|
||||||
"sudo: sorry, you must have a tty",
|
|
||||||
"askpass",
|
|
||||||
)
|
|
||||||
return any(p in blob for p in patterns)
|
|
||||||
|
|
||||||
|
|
||||||
def _sudo_not_permitted(out: str, err: str) -> bool:
|
|
||||||
"""Return True if sudo output indicates the user cannot sudo at all."""
|
|
||||||
blob = (out + "\n" + err).lower()
|
|
||||||
patterns = (
|
|
||||||
"is not in the sudoers file",
|
|
||||||
"not allowed to execute",
|
|
||||||
"may not run sudo",
|
|
||||||
"sorry, user",
|
|
||||||
)
|
|
||||||
return any(p in blob for p in patterns)
|
|
||||||
|
|
||||||
|
|
||||||
def _sudo_tty_required(out: str, err: str) -> bool:
|
|
||||||
"""Return True if sudo output indicates it requires a TTY (sudoers requiretty)."""
|
|
||||||
blob = (out + "\n" + err).lower()
|
|
||||||
patterns = (
|
|
||||||
"must have a tty",
|
|
||||||
"sorry, you must have a tty",
|
|
||||||
"sudo: sorry, you must have a tty",
|
|
||||||
"must have a tty to run sudo",
|
|
||||||
)
|
|
||||||
return any(p in blob for p in patterns)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_become_password(
|
|
||||||
ask_become_pass: bool,
|
|
||||||
*,
|
|
||||||
prompt: str = "sudo password: ",
|
|
||||||
getpass_fn: Callable[[str], str] = getpass.getpass,
|
|
||||||
) -> Optional[str]:
|
|
||||||
if ask_become_pass:
|
|
||||||
return getpass_fn(prompt)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def remote_harvest(
|
|
||||||
*,
|
|
||||||
ask_become_pass: bool = False,
|
|
||||||
no_sudo: bool = False,
|
|
||||||
prompt: str = "sudo password: ",
|
|
||||||
getpass_fn: Optional[Callable[[str], str]] = None,
|
|
||||||
stdin: Optional[TextIO] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""Call _remote_harvest, with a safe sudo password fallback.
|
|
||||||
|
|
||||||
Behavior:
|
|
||||||
- Run without a password unless --ask-become-pass is set.
|
|
||||||
- If the remote sudo policy requires a password and none was provided,
|
|
||||||
prompt and retry when running interactively.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Resolve defaults at call time (easier to test/monkeypatch, and avoids capturing
|
|
||||||
# sys.stdin / getpass.getpass at import time).
|
|
||||||
if getpass_fn is None:
|
|
||||||
getpass_fn = getpass.getpass
|
|
||||||
if stdin is None:
|
|
||||||
stdin = sys.stdin
|
|
||||||
|
|
||||||
sudo_password = _resolve_become_password(
|
|
||||||
ask_become_pass and not no_sudo,
|
|
||||||
prompt=prompt,
|
|
||||||
getpass_fn=getpass_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return _remote_harvest(sudo_password=sudo_password, no_sudo=no_sudo, **kwargs)
|
|
||||||
except RemoteSudoPasswordRequired:
|
|
||||||
if sudo_password is not None:
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Fallback prompt if interactive
|
|
||||||
if stdin is not None and getattr(stdin, "isatty", lambda: False)():
|
|
||||||
pw = getpass_fn(prompt)
|
|
||||||
return _remote_harvest(sudo_password=pw, no_sudo=no_sudo, **kwargs)
|
|
||||||
|
|
||||||
raise RemoteSudoPasswordRequired(
|
|
||||||
"Remote sudo requires a password. Re-run with --ask-become-pass."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None:
|
def _safe_extract_tar(tar: tarfile.TarFile, dest: Path) -> None:
|
||||||
|
|
@ -182,14 +79,7 @@ def _build_enroll_pyz(tmpdir: Path) -> Path:
|
||||||
return pyz_path
|
return pyz_path
|
||||||
|
|
||||||
|
|
||||||
def _ssh_run(
|
def _ssh_run(ssh, cmd: str, *, get_pty: bool = False) -> tuple[int, str, str]:
|
||||||
ssh,
|
|
||||||
cmd: str,
|
|
||||||
*,
|
|
||||||
get_pty: bool = False,
|
|
||||||
stdin_text: Optional[str] = None,
|
|
||||||
close_stdin: bool = False,
|
|
||||||
) -> tuple[int, str, str]:
|
|
||||||
"""Run a command over a Paramiko SSHClient.
|
"""Run a command over a Paramiko SSHClient.
|
||||||
|
|
||||||
Paramiko's exec_command runs commands without a TTY by default.
|
Paramiko's exec_command runs commands without a TTY by default.
|
||||||
|
|
@ -200,133 +90,14 @@ def _ssh_run(
|
||||||
We do not request a PTY for commands that stream binary data
|
We do not request a PTY for commands that stream binary data
|
||||||
(e.g. tar/gzip output), as a PTY can corrupt the byte stream.
|
(e.g. tar/gzip output), as a PTY can corrupt the byte stream.
|
||||||
"""
|
"""
|
||||||
stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty)
|
_stdin, stdout, stderr = ssh.exec_command(cmd, get_pty=get_pty)
|
||||||
# All three file-like objects share the same underlying Channel.
|
out = stdout.read().decode("utf-8", errors="replace")
|
||||||
chan = stdout.channel
|
err = stderr.read().decode("utf-8", errors="replace")
|
||||||
|
rc = stdout.channel.recv_exit_status()
|
||||||
if stdin_text is not None and stdin is not None:
|
|
||||||
try:
|
|
||||||
stdin.write(stdin_text)
|
|
||||||
stdin.flush()
|
|
||||||
except Exception:
|
|
||||||
# If the remote side closed stdin early, ignore.
|
|
||||||
pass # nosec
|
|
||||||
finally:
|
|
||||||
if close_stdin:
|
|
||||||
# For sudo -S, a wrong password causes sudo to re-prompt and wait
|
|
||||||
# forever for more input. We try hard to deliver EOF so sudo can
|
|
||||||
# fail fast.
|
|
||||||
try:
|
|
||||||
chan.shutdown_write() # sends EOF to the remote process
|
|
||||||
except Exception:
|
|
||||||
pass # nosec
|
|
||||||
try:
|
|
||||||
stdin.close()
|
|
||||||
except Exception:
|
|
||||||
pass # nosec
|
|
||||||
|
|
||||||
# Read incrementally to avoid blocking forever on stdout.read()/stderr.read()
|
|
||||||
# if the remote process is waiting for more input (e.g. sudo password retry).
|
|
||||||
out_chunks: list[bytes] = []
|
|
||||||
err_chunks: list[bytes] = []
|
|
||||||
# Keep a small tail of stderr to detect sudo retry messages without
|
|
||||||
# repeatedly joining potentially large buffers.
|
|
||||||
err_tail = b""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
progressed = False
|
|
||||||
if chan.recv_ready():
|
|
||||||
out_chunks.append(chan.recv(1024 * 64))
|
|
||||||
progressed = True
|
|
||||||
if chan.recv_stderr_ready():
|
|
||||||
chunk = chan.recv_stderr(1024 * 64)
|
|
||||||
err_chunks.append(chunk)
|
|
||||||
err_tail = (err_tail + chunk)[-4096:]
|
|
||||||
progressed = True
|
|
||||||
|
|
||||||
# If we just attempted sudo -S with a single password line and sudo is
|
|
||||||
# asking again, detect it and stop waiting.
|
|
||||||
if close_stdin and stdin_text is not None:
|
|
||||||
blob = err_tail.lower()
|
|
||||||
if b"sorry, try again" in blob or b"incorrect password" in blob:
|
|
||||||
try:
|
|
||||||
chan.close()
|
|
||||||
except Exception:
|
|
||||||
pass # nosec
|
|
||||||
break
|
|
||||||
|
|
||||||
# Exit once the process has exited and we have drained the buffers.
|
|
||||||
if (
|
|
||||||
chan.exit_status_ready()
|
|
||||||
and not chan.recv_ready()
|
|
||||||
and not chan.recv_stderr_ready()
|
|
||||||
):
|
|
||||||
break
|
|
||||||
|
|
||||||
if not progressed:
|
|
||||||
time.sleep(0.05)
|
|
||||||
|
|
||||||
out = b"".join(out_chunks).decode("utf-8", errors="replace")
|
|
||||||
err = b"".join(err_chunks).decode("utf-8", errors="replace")
|
|
||||||
rc = chan.recv_exit_status() if chan.exit_status_ready() else 1
|
|
||||||
return rc, out, err
|
return rc, out, err
|
||||||
|
|
||||||
|
|
||||||
def _ssh_run_sudo(
|
def remote_harvest(
|
||||||
ssh,
|
|
||||||
cmd: str,
|
|
||||||
*,
|
|
||||||
sudo_password: Optional[str] = None,
|
|
||||||
get_pty: bool = True,
|
|
||||||
) -> tuple[int, str, str]:
|
|
||||||
"""Run cmd via sudo with a safe non-interactive-first strategy.
|
|
||||||
|
|
||||||
Strategy:
|
|
||||||
1) Try `sudo -n`.
|
|
||||||
2) If sudo reports a password is required and we have one, retry with
|
|
||||||
`sudo -S` and feed it via stdin.
|
|
||||||
3) If sudo reports a password is required and we *don't* have one, raise
|
|
||||||
RemoteSudoPasswordRequired.
|
|
||||||
|
|
||||||
We avoid requesting a PTY unless the remote sudo policy requires it.
|
|
||||||
This makes sudo -S behavior more reliable (wrong passwords fail fast
|
|
||||||
instead of blocking on a PTY).
|
|
||||||
"""
|
|
||||||
cmd_n = f"sudo -n -p '' -- {cmd}"
|
|
||||||
|
|
||||||
# First try: never prompt, and prefer no PTY.
|
|
||||||
rc, out, err = _ssh_run(ssh, cmd_n, get_pty=False)
|
|
||||||
need_pty = False
|
|
||||||
|
|
||||||
# Some sudoers configurations require a TTY even for passwordless sudo.
|
|
||||||
if get_pty and rc != 0 and _sudo_tty_required(out, err):
|
|
||||||
need_pty = True
|
|
||||||
rc, out, err = _ssh_run(ssh, cmd_n, get_pty=True)
|
|
||||||
|
|
||||||
if rc == 0:
|
|
||||||
return rc, out, err
|
|
||||||
|
|
||||||
if _sudo_not_permitted(out, err):
|
|
||||||
return rc, out, err
|
|
||||||
|
|
||||||
if _sudo_password_required(out, err):
|
|
||||||
if sudo_password is None:
|
|
||||||
raise RemoteSudoPasswordRequired(
|
|
||||||
"Remote sudo requires a password, but none was provided."
|
|
||||||
)
|
|
||||||
cmd_s = f"sudo -S -p '' -- {cmd}"
|
|
||||||
return _ssh_run(
|
|
||||||
ssh,
|
|
||||||
cmd_s,
|
|
||||||
get_pty=need_pty,
|
|
||||||
stdin_text=str(sudo_password) + "\n",
|
|
||||||
close_stdin=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return rc, out, err
|
|
||||||
|
|
||||||
|
|
||||||
def _remote_harvest(
|
|
||||||
*,
|
*,
|
||||||
local_out_dir: Path,
|
local_out_dir: Path,
|
||||||
remote_host: str,
|
remote_host: str,
|
||||||
|
|
@ -335,7 +106,6 @@ def _remote_harvest(
|
||||||
remote_python: str = "python3",
|
remote_python: str = "python3",
|
||||||
dangerous: bool = False,
|
dangerous: bool = False,
|
||||||
no_sudo: bool = False,
|
no_sudo: bool = False,
|
||||||
sudo_password: Optional[str] = None,
|
|
||||||
include_paths: Optional[list[str]] = None,
|
include_paths: Optional[list[str]] = None,
|
||||||
exclude_paths: Optional[list[str]] = None,
|
exclude_paths: Optional[list[str]] = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
|
|
@ -420,15 +190,10 @@ def _remote_harvest(
|
||||||
argv.extend(["--exclude-path", str(p)])
|
argv.extend(["--exclude-path", str(p)])
|
||||||
|
|
||||||
_cmd = " ".join(map(shlex.quote, argv))
|
_cmd = " ".join(map(shlex.quote, argv))
|
||||||
if not no_sudo:
|
cmd = f"sudo {_cmd}" if not no_sudo else _cmd
|
||||||
# Prefer non-interactive sudo first; retry with -S only when needed.
|
|
||||||
rc, out, err = _ssh_run_sudo(
|
# PTY for sudo commands (helps sudoers requiretty).
|
||||||
ssh, _cmd, sudo_password=sudo_password, get_pty=True
|
rc, out, err = _ssh_run(ssh, cmd, get_pty=(not no_sudo))
|
||||||
)
|
|
||||||
cmd = f"sudo {_cmd}"
|
|
||||||
else:
|
|
||||||
cmd = _cmd
|
|
||||||
rc, out, err = _ssh_run(ssh, cmd, get_pty=False)
|
|
||||||
if rc != 0:
|
if rc != 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Remote harvest failed.\n"
|
"Remote harvest failed.\n"
|
||||||
|
|
@ -445,17 +210,12 @@ def _remote_harvest(
|
||||||
"Unable to determine remote username for chown. "
|
"Unable to determine remote username for chown. "
|
||||||
"Pass --remote-user explicitly or use --no-sudo."
|
"Pass --remote-user explicitly or use --no-sudo."
|
||||||
)
|
)
|
||||||
chown_cmd = f"chown -R {resolved_user} {rbundle}"
|
cmd = f"sudo chown -R {resolved_user} {rbundle}"
|
||||||
rc, out, err = _ssh_run_sudo(
|
rc, out, err = _ssh_run(ssh, cmd, get_pty=True)
|
||||||
ssh,
|
|
||||||
chown_cmd,
|
|
||||||
sudo_password=sudo_password,
|
|
||||||
get_pty=True,
|
|
||||||
)
|
|
||||||
if rc != 0:
|
if rc != 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"chown of harvest failed.\n"
|
"chown of harvest failed.\n"
|
||||||
f"Command: sudo {chown_cmd}\n"
|
f"Command: {cmd}\n"
|
||||||
f"Exit code: {rc}\n"
|
f"Exit code: {rc}\n"
|
||||||
f"Stdout: {out.strip()}\n"
|
f"Stdout: {out.strip()}\n"
|
||||||
f"Stderr: {err.strip()}"
|
f"Stderr: {err.strip()}"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import enroll.cli as cli
|
import enroll.cli as cli
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -260,113 +258,6 @@ def test_cli_single_shot_remote_without_harvest_prints_state_path(
|
||||||
assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls
|
assert ("manifest", str(cache_dir), str(ansible_dir), "example.test") in calls
|
||||||
|
|
||||||
|
|
||||||
def test_cli_harvest_remote_ask_become_pass_prompts_and_passes_password(
|
|
||||||
monkeypatch, tmp_path
|
|
||||||
):
|
|
||||||
from enroll.cache import HarvestCache
|
|
||||||
import enroll.remote as r
|
|
||||||
|
|
||||||
cache_dir = tmp_path / "cache"
|
|
||||||
cache_dir.mkdir()
|
|
||||||
|
|
||||||
called = {}
|
|
||||||
|
|
||||||
def fake_cache_dir(*, hint=None):
|
|
||||||
return HarvestCache(dir=cache_dir)
|
|
||||||
|
|
||||||
def fake__remote_harvest(*, sudo_password=None, **kwargs):
|
|
||||||
called["sudo_password"] = sudo_password
|
|
||||||
return cache_dir / "state.json"
|
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
|
|
||||||
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
|
|
||||||
monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw123")
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
sys,
|
|
||||||
"argv",
|
|
||||||
[
|
|
||||||
"enroll",
|
|
||||||
"harvest",
|
|
||||||
"--remote-host",
|
|
||||||
"example.test",
|
|
||||||
"--ask-become-pass",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cli.main()
|
|
||||||
assert called["sudo_password"] == "pw123"
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_harvest_remote_password_required_fallback_prompts_and_retries(
|
|
||||||
monkeypatch, tmp_path
|
|
||||||
):
|
|
||||||
from enroll.cache import HarvestCache
|
|
||||||
import enroll.remote as r
|
|
||||||
|
|
||||||
cache_dir = tmp_path / "cache"
|
|
||||||
cache_dir.mkdir()
|
|
||||||
|
|
||||||
def fake_cache_dir(*, hint=None):
|
|
||||||
return HarvestCache(dir=cache_dir)
|
|
||||||
|
|
||||||
calls = []
|
|
||||||
|
|
||||||
def fake__remote_harvest(*, sudo_password=None, **kwargs):
|
|
||||||
calls.append(sudo_password)
|
|
||||||
if sudo_password is None:
|
|
||||||
raise r.RemoteSudoPasswordRequired("pw required")
|
|
||||||
return cache_dir / "state.json"
|
|
||||||
|
|
||||||
class _TTYStdin:
|
|
||||||
def isatty(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
|
|
||||||
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
|
|
||||||
monkeypatch.setattr(r.getpass, "getpass", lambda _prompt="": "pw456")
|
|
||||||
monkeypatch.setattr(sys, "stdin", _TTYStdin())
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"]
|
|
||||||
)
|
|
||||||
|
|
||||||
cli.main()
|
|
||||||
assert calls == [None, "pw456"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_harvest_remote_password_required_noninteractive_errors(
|
|
||||||
monkeypatch, tmp_path
|
|
||||||
):
|
|
||||||
from enroll.cache import HarvestCache
|
|
||||||
import enroll.remote as r
|
|
||||||
|
|
||||||
cache_dir = tmp_path / "cache"
|
|
||||||
cache_dir.mkdir()
|
|
||||||
|
|
||||||
def fake_cache_dir(*, hint=None):
|
|
||||||
return HarvestCache(dir=cache_dir)
|
|
||||||
|
|
||||||
def fake__remote_harvest(*, sudo_password=None, **kwargs):
|
|
||||||
raise r.RemoteSudoPasswordRequired("pw required")
|
|
||||||
|
|
||||||
class _NoTTYStdin:
|
|
||||||
def isatty(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
monkeypatch.setattr(cli, "new_harvest_cache_dir", fake_cache_dir)
|
|
||||||
monkeypatch.setattr(r, "_remote_harvest", fake__remote_harvest)
|
|
||||||
monkeypatch.setattr(sys, "stdin", _NoTTYStdin())
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
sys, "argv", ["enroll", "harvest", "--remote-host", "example.test"]
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(SystemExit) as e:
|
|
||||||
cli.main()
|
|
||||||
assert "--ask-become-pass" in str(e.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_manifest_common_args(monkeypatch, tmp_path):
|
def test_cli_manifest_common_args(monkeypatch, tmp_path):
|
||||||
"""Ensure --fqdn and jinjaturtle mode flags are forwarded correctly."""
|
"""Ensure --fqdn and jinjaturtle mode flags are forwarded correctly."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,53 +69,16 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
|
||||||
calls: list[tuple[str, bool]] = []
|
calls: list[tuple[str, bool]] = []
|
||||||
|
|
||||||
class _Chan:
|
class _Chan:
|
||||||
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
|
def __init__(self, rc: int = 0):
|
||||||
self._out = out
|
|
||||||
self._err = err
|
|
||||||
self._out_i = 0
|
|
||||||
self._err_i = 0
|
|
||||||
self._rc = rc
|
self._rc = rc
|
||||||
self._closed = False
|
|
||||||
|
|
||||||
def recv_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._out_i < len(self._out)
|
|
||||||
|
|
||||||
def recv(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._out[self._out_i : self._out_i + n]
|
|
||||||
self._out_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def recv_stderr_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._err_i < len(self._err)
|
|
||||||
|
|
||||||
def recv_stderr(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._err[self._err_i : self._err_i + n]
|
|
||||||
self._err_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def exit_status_ready(self) -> bool:
|
|
||||||
return self._closed or (
|
|
||||||
self._out_i >= len(self._out) and self._err_i >= len(self._err)
|
|
||||||
)
|
|
||||||
|
|
||||||
def recv_exit_status(self) -> int:
|
def recv_exit_status(self) -> int:
|
||||||
return self._rc
|
return self._rc
|
||||||
|
|
||||||
def shutdown_write(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
class _Stdout:
|
class _Stdout:
|
||||||
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
|
def __init__(self, payload: bytes = b"", rc: int = 0):
|
||||||
self._bio = io.BytesIO(payload)
|
self._bio = io.BytesIO(payload)
|
||||||
# _ssh_run reads stdout/stderr via the underlying channel.
|
self.channel = _Chan(rc)
|
||||||
self.channel = _Chan(out=payload, err=err, rc=rc)
|
|
||||||
|
|
||||||
def read(self, n: int = -1) -> bytes:
|
def read(self, n: int = -1) -> bytes:
|
||||||
return self._bio.read(n)
|
return self._bio.read(n)
|
||||||
|
|
@ -167,20 +130,10 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
|
||||||
return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr())
|
return (None, _Stdout(b"/tmp/enroll-remote-123\n"), _Stderr())
|
||||||
if cmd.startswith("chmod 700"):
|
if cmd.startswith("chmod 700"):
|
||||||
return (None, _Stdout(b""), _Stderr())
|
return (None, _Stdout(b""), _Stderr())
|
||||||
if cmd.startswith("sudo -n") and " harvest " in cmd:
|
|
||||||
if not get_pty:
|
|
||||||
msg = b"sudo: sorry, you must have a tty to run sudo\n"
|
|
||||||
return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg))
|
|
||||||
return (None, _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
if cmd.startswith("sudo -S") and " harvest " in cmd:
|
|
||||||
return (None, _Stdout(b""), _Stderr())
|
|
||||||
if " harvest " in cmd:
|
if " harvest " in cmd:
|
||||||
return (None, _Stdout(b""), _Stderr())
|
return (None, _Stdout(b""), _Stderr())
|
||||||
if cmd.startswith("sudo -n") and " chown -R" in cmd:
|
if cmd.startswith("sudo chown -R"):
|
||||||
if not get_pty:
|
return (None, _Stdout(b""), _Stderr())
|
||||||
msg = b"sudo: sorry, you must have a tty to run sudo\n"
|
|
||||||
return (None, _Stdout(b"", rc=1, err=msg), _Stderr(msg))
|
|
||||||
return (None, _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
if cmd.startswith("rm -rf"):
|
if cmd.startswith("rm -rf"):
|
||||||
return (None, _Stdout(b""), _Stderr())
|
return (None, _Stdout(b""), _Stderr())
|
||||||
|
|
||||||
|
|
@ -201,7 +154,6 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
|
||||||
|
|
||||||
out_dir = tmp_path / "out"
|
out_dir = tmp_path / "out"
|
||||||
state_path = r.remote_harvest(
|
state_path = r.remote_harvest(
|
||||||
ask_become_pass=False,
|
|
||||||
local_out_dir=out_dir,
|
local_out_dir=out_dir,
|
||||||
remote_host="example.com",
|
remote_host="example.com",
|
||||||
remote_port=2222,
|
remote_port=2222,
|
||||||
|
|
@ -223,21 +175,13 @@ def test_remote_harvest_happy_path(tmp_path: Path, monkeypatch):
|
||||||
assert "--include-path" in joined
|
assert "--include-path" in joined
|
||||||
assert "--exclude-path" in joined
|
assert "--exclude-path" in joined
|
||||||
|
|
||||||
# Ensure we fall back to PTY only when sudo reports it is required.
|
# Ensure PTY is used for sudo commands (sudoers requiretty) but not for tar.
|
||||||
assert any(c == "id -un" and pty is False for c, pty in calls)
|
pty_by_cmd = {c: pty for c, pty in calls}
|
||||||
|
assert pty_by_cmd.get("id -un") is False
|
||||||
sudo_harvest = [
|
assert any(
|
||||||
(c, pty) for c, pty in calls if c.startswith("sudo -n") and " harvest " in c
|
c.startswith("sudo") and " harvest " in c and pty is True for c, pty in calls
|
||||||
]
|
)
|
||||||
assert any(pty is False for _c, pty in sudo_harvest)
|
assert any(c.startswith("sudo chown -R") and pty is True for c, pty in calls)
|
||||||
assert any(pty is True for _c, pty in sudo_harvest)
|
|
||||||
|
|
||||||
sudo_chown = [
|
|
||||||
(c, pty) for c, pty in calls if c.startswith("sudo -n") and " chown -R" in c
|
|
||||||
]
|
|
||||||
assert any(pty is False for _c, pty in sudo_chown)
|
|
||||||
assert any(pty is True for _c, pty in sudo_chown)
|
|
||||||
|
|
||||||
assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls)
|
assert any(c.startswith("tar -cz -C") and pty is False for c, pty in calls)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -260,53 +204,16 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
|
||||||
calls: list[tuple[str, bool]] = []
|
calls: list[tuple[str, bool]] = []
|
||||||
|
|
||||||
class _Chan:
|
class _Chan:
|
||||||
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
|
def __init__(self, rc: int = 0):
|
||||||
self._out = out
|
|
||||||
self._err = err
|
|
||||||
self._out_i = 0
|
|
||||||
self._err_i = 0
|
|
||||||
self._rc = rc
|
self._rc = rc
|
||||||
self._closed = False
|
|
||||||
|
|
||||||
def recv_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._out_i < len(self._out)
|
|
||||||
|
|
||||||
def recv(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._out[self._out_i : self._out_i + n]
|
|
||||||
self._out_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def recv_stderr_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._err_i < len(self._err)
|
|
||||||
|
|
||||||
def recv_stderr(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._err[self._err_i : self._err_i + n]
|
|
||||||
self._err_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def exit_status_ready(self) -> bool:
|
|
||||||
return self._closed or (
|
|
||||||
self._out_i >= len(self._out) and self._err_i >= len(self._err)
|
|
||||||
)
|
|
||||||
|
|
||||||
def recv_exit_status(self) -> int:
|
def recv_exit_status(self) -> int:
|
||||||
return self._rc
|
return self._rc
|
||||||
|
|
||||||
def shutdown_write(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
class _Stdout:
|
class _Stdout:
|
||||||
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
|
def __init__(self, payload: bytes = b"", rc: int = 0):
|
||||||
self._bio = io.BytesIO(payload)
|
self._bio = io.BytesIO(payload)
|
||||||
# _ssh_run reads stdout/stderr via the underlying channel.
|
self.channel = _Chan(rc)
|
||||||
self.channel = _Chan(out=payload, err=err, rc=rc)
|
|
||||||
|
|
||||||
def read(self, n: int = -1) -> bytes:
|
def read(self, n: int = -1) -> bytes:
|
||||||
return self._bio.read(n)
|
return self._bio.read(n)
|
||||||
|
|
@ -371,7 +278,6 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
|
||||||
|
|
||||||
out_dir = tmp_path / "out"
|
out_dir = tmp_path / "out"
|
||||||
r.remote_harvest(
|
r.remote_harvest(
|
||||||
ask_become_pass=False,
|
|
||||||
local_out_dir=out_dir,
|
local_out_dir=out_dir,
|
||||||
remote_host="example.com",
|
remote_host="example.com",
|
||||||
remote_user="alice",
|
remote_user="alice",
|
||||||
|
|
@ -382,186 +288,3 @@ def test_remote_harvest_no_sudo_does_not_request_pty_or_chown(
|
||||||
assert "sudo" not in joined
|
assert "sudo" not in joined
|
||||||
assert "sudo chown" not in joined
|
assert "sudo chown" not in joined
|
||||||
assert any(" harvest " in c and pty is False for c, pty in calls)
|
assert any(" harvest " in c and pty is False for c, pty in calls)
|
||||||
|
|
||||||
|
|
||||||
def test_remote_harvest_sudo_password_retry_uses_sudo_s_and_writes_password(
|
|
||||||
tmp_path: Path, monkeypatch
|
|
||||||
):
|
|
||||||
"""If sudo requires a password, we should fall back from -n to -S and feed stdin."""
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
import enroll.remote as r
|
|
||||||
|
|
||||||
# Avoid building a real zipapp; just create a file.
|
|
||||||
monkeypatch.setattr(
|
|
||||||
r,
|
|
||||||
"_build_enroll_pyz",
|
|
||||||
lambda td: (Path(td) / "enroll.pyz").write_bytes(b"PYZ")
|
|
||||||
or (Path(td) / "enroll.pyz"),
|
|
||||||
)
|
|
||||||
|
|
||||||
tgz = _make_tgz_bytes({"state.json": b'{"ok": true}\n'})
|
|
||||||
calls: list[tuple[str, bool]] = []
|
|
||||||
stdin_by_cmd: dict[str, list[str]] = {}
|
|
||||||
|
|
||||||
class _Chan:
|
|
||||||
def __init__(self, out: bytes = b"", err: bytes = b"", rc: int = 0):
|
|
||||||
self._out = out
|
|
||||||
self._err = err
|
|
||||||
self._out_i = 0
|
|
||||||
self._err_i = 0
|
|
||||||
self._rc = rc
|
|
||||||
self._closed = False
|
|
||||||
|
|
||||||
def recv_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._out_i < len(self._out)
|
|
||||||
|
|
||||||
def recv(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._out[self._out_i : self._out_i + n]
|
|
||||||
self._out_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def recv_stderr_ready(self) -> bool:
|
|
||||||
return (not self._closed) and self._err_i < len(self._err)
|
|
||||||
|
|
||||||
def recv_stderr(self, n: int) -> bytes:
|
|
||||||
if self._closed:
|
|
||||||
return b""
|
|
||||||
chunk = self._err[self._err_i : self._err_i + n]
|
|
||||||
self._err_i += len(chunk)
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
def exit_status_ready(self) -> bool:
|
|
||||||
return self._closed or (
|
|
||||||
self._out_i >= len(self._out) and self._err_i >= len(self._err)
|
|
||||||
)
|
|
||||||
|
|
||||||
def recv_exit_status(self) -> int:
|
|
||||||
return self._rc
|
|
||||||
|
|
||||||
def shutdown_write(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
class _Stdout:
|
|
||||||
def __init__(self, payload: bytes = b"", rc: int = 0, err: bytes = b""):
|
|
||||||
self._bio = io.BytesIO(payload)
|
|
||||||
# _ssh_run reads stdout/stderr via the underlying channel.
|
|
||||||
self.channel = _Chan(out=payload, err=err, rc=rc)
|
|
||||||
|
|
||||||
def read(self, n: int = -1) -> bytes:
|
|
||||||
return self._bio.read(n)
|
|
||||||
|
|
||||||
class _Stderr:
|
|
||||||
def __init__(self, payload: bytes = b""):
|
|
||||||
self._bio = io.BytesIO(payload)
|
|
||||||
|
|
||||||
def read(self, n: int = -1) -> bytes:
|
|
||||||
return self._bio.read(n)
|
|
||||||
|
|
||||||
class _Stdin:
|
|
||||||
def __init__(self, cmd: str):
|
|
||||||
self._cmd = cmd
|
|
||||||
stdin_by_cmd.setdefault(cmd, [])
|
|
||||||
|
|
||||||
def write(self, s: str) -> None:
|
|
||||||
stdin_by_cmd[self._cmd].append(s)
|
|
||||||
|
|
||||||
def flush(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
class _SFTP:
|
|
||||||
def put(self, _local: str, _remote: str) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
class FakeSSH:
|
|
||||||
def __init__(self):
|
|
||||||
self._sftp = _SFTP()
|
|
||||||
|
|
||||||
def load_system_host_keys(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
def set_missing_host_key_policy(self, _policy):
|
|
||||||
return
|
|
||||||
|
|
||||||
def connect(self, **_kwargs):
|
|
||||||
return
|
|
||||||
|
|
||||||
def open_sftp(self):
|
|
||||||
return self._sftp
|
|
||||||
|
|
||||||
def exec_command(self, cmd: str, *, get_pty: bool = False, **_kwargs):
|
|
||||||
calls.append((cmd, bool(get_pty)))
|
|
||||||
|
|
||||||
# Tar stream
|
|
||||||
if cmd.startswith("tar -cz -C"):
|
|
||||||
return (_Stdin(cmd), _Stdout(tgz, rc=0), _Stderr(b""))
|
|
||||||
|
|
||||||
if cmd == "mktemp -d":
|
|
||||||
return (_Stdin(cmd), _Stdout(b"/tmp/enroll-remote-789\n"), _Stderr())
|
|
||||||
if cmd.startswith("chmod 700"):
|
|
||||||
return (_Stdin(cmd), _Stdout(b""), _Stderr())
|
|
||||||
|
|
||||||
# First attempt: sudo -n fails, prompting is not allowed.
|
|
||||||
if cmd.startswith("sudo -n") and " harvest " in cmd:
|
|
||||||
return (
|
|
||||||
_Stdin(cmd),
|
|
||||||
_Stdout(b"", rc=1, err=b"sudo: a password is required\n"),
|
|
||||||
_Stderr(b"sudo: a password is required\n"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retry: sudo -S succeeds and should have been fed the password via stdin.
|
|
||||||
if cmd.startswith("sudo -S") and " harvest " in cmd:
|
|
||||||
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
|
|
||||||
# chown succeeds passwordlessly (e.g., sudo timestamp is warm).
|
|
||||||
if cmd.startswith("sudo -n") and " chown -R" in cmd:
|
|
||||||
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
|
|
||||||
if cmd.startswith("rm -rf"):
|
|
||||||
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
|
|
||||||
# Fallback for unexpected commands.
|
|
||||||
return (_Stdin(cmd), _Stdout(b"", rc=0), _Stderr(b""))
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
class RejectPolicy:
|
|
||||||
pass
|
|
||||||
|
|
||||||
monkeypatch.setitem(
|
|
||||||
sys.modules,
|
|
||||||
"paramiko",
|
|
||||||
types.SimpleNamespace(SSHClient=FakeSSH, RejectPolicy=RejectPolicy),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = tmp_path / "out"
|
|
||||||
state_path = r.remote_harvest(
|
|
||||||
ask_become_pass=True,
|
|
||||||
getpass_fn=lambda _prompt="": "s3cr3t",
|
|
||||||
local_out_dir=out_dir,
|
|
||||||
remote_host="example.com",
|
|
||||||
remote_user="alice",
|
|
||||||
no_sudo=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert state_path.exists()
|
|
||||||
assert b"ok" in state_path.read_bytes()
|
|
||||||
|
|
||||||
# Ensure we attempted with sudo -n first, then sudo -S.
|
|
||||||
sudo_n = [c for c, _pty in calls if c.startswith("sudo -n") and " harvest " in c]
|
|
||||||
sudo_s = [c for c, _pty in calls if c.startswith("sudo -S") and " harvest " in c]
|
|
||||||
assert len(sudo_n) == 1
|
|
||||||
assert len(sudo_s) == 1
|
|
||||||
|
|
||||||
# Ensure the password was written to stdin for the -S invocation.
|
|
||||||
assert stdin_by_cmd.get(sudo_s[0]) == ["s3cr3t\n"]
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue