Avoid TOCTOU issues, stronger perms on manifest dir, don't allow harvesting to existing dir by default, scan whole file for potential secrets
All checks were successful
CI / test (push) Successful in 48s
CI / test (almalinux, docker.io/library/almalinux:9, python3.11) (push) Successful in 11m19s
CI / test (debian, docker.io/library/debian:13, python3) (push) Successful in 20m40s
Lint / test (push) Successful in 48s

This commit is contained in:
Miguel Jacq 2026-06-22 11:41:11 +10:00
parent c7a6bfe979
commit e78f61c5ed
Signed by: mig5
GPG key ID: 03906B4110AAD3B8
12 changed files with 490 additions and 56 deletions

View file

@ -1507,8 +1507,6 @@ Discovery order:
1. `--no-config` disables config loading, 1. `--no-config` disables config loading,
2. `--config PATH` or `-c PATH`, 2. `--config PATH` or `-c PATH`,
3. `$ENROLL_CONFIG`, 3. `$ENROLL_CONFIG`,
4. `./enroll.ini`,
5. `./.enroll.ini`,
6. `$XDG_CONFIG_HOME/enroll/enroll.ini`, 6. `$XDG_CONFIG_HOME/enroll/enroll.ini`,
7. `~/.config/enroll/enroll.ini`. 7. `~/.config/enroll/enroll.ini`.

View file

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
import os import os
import shutil import errno
import stat import stat
from typing import List, Optional, Set from typing import List, Optional, Set
from .fsutil import stat_triplet from .fsutil import stat_triplet, stat_triplet_from_stat
from .harvest_types import ExcludedFile, ManagedFile, ManagedLink from .harvest_types import ExcludedFile, ManagedFile, ManagedLink
from .ignore import IgnorePolicy from .ignore import IgnorePolicy
from .pathfilter import PathFilter from .pathfilter import PathFilter
@ -54,12 +54,69 @@ def files_differ(a: str, b: str, *, max_bytes: int = 2_000_000) -> bool:
return True return True
def copy_into_bundle( def _open_no_follow_write(path: str, mode: int = 0o600) -> int:
bundle_dir: str, role_name: str, abs_path: str, src_rel: str flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_CLOEXEC", 0)
if hasattr(os, "O_NOFOLLOW"):
flags |= os.O_NOFOLLOW
return os.open(path, flags, mode)
def write_bytes_into_bundle(
bundle_dir: str, role_name: str, src_rel: str, data: bytes
) -> None: ) -> None:
dst = os.path.join(bundle_dir, "artifacts", role_name, src_rel) dst = os.path.join(bundle_dir, "artifacts", role_name, src_rel)
os.makedirs(os.path.dirname(dst), exist_ok=True) os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy2(abs_path, dst)
fd = -1
try:
fd = _open_no_follow_write(dst, 0o600)
with os.fdopen(fd, "wb") as f:
fd = -1
f.write(data)
try:
os.chmod(dst, 0o600)
except OSError:
pass
finally:
if fd >= 0:
os.close(fd)
def copy_into_bundle(
    bundle_dir: str, role_name: str, abs_path: str, src_rel: str
) -> None:
    """Legacy safe copy helper used by tests and non-IgnorePolicy callers.

    Real harvests using IgnorePolicy copy the exact bytes read from the safely
    opened source file in capture_file(). This helper still refuses source
    symlinks at copy time and refuses destination symlink overwrites.

    Args:
        bundle_dir: Root directory of the harvest bundle.
        role_name: Role subdirectory under ``artifacts`` to copy into.
        abs_path: Source file to read.
        src_rel: Destination path relative to the role's artifact directory.

    Raises:
        OSError: if the source is a symlink, is not a regular file, cannot be
            opened or read, or the destination cannot be written.
    """
    # O_NONBLOCK keeps open() from blocking forever on a FIFO that has no
    # writer; the fstat() check below then rejects it. It has no effect on
    # reads from regular files.
    flags = os.O_RDONLY | getattr(os, "O_CLOEXEC", 0) | getattr(os, "O_NONBLOCK", 0)
    if hasattr(os, "O_NOFOLLOW"):
        flags |= os.O_NOFOLLOW

    try:
        fd = os.open(abs_path, flags)
    except OSError as e:
        # With O_NOFOLLOW, a symlink as the final component fails with ELOOP.
        if e.errno in {errno.ELOOP, errno.ENOTDIR}:
            raise OSError("refusing to copy symlink source") from e
        raise

    try:
        # Check the type on the descriptor we actually opened, not the path,
        # so the file cannot be swapped between the check and the read.
        st = os.fstat(fd)
        if not stat.S_ISREG(st.st_mode):
            raise OSError("refusing to copy non-regular source")

        chunks: list[bytes] = []
        while True:
            chunk = os.read(fd, 1024 * 1024)
            if not chunk:
                break
            chunks.append(chunk)
        write_bytes_into_bundle(bundle_dir, role_name, src_rel, b"".join(chunks))
    finally:
        os.close(fd)
def capture_file( def capture_file(
@ -99,16 +156,31 @@ def capture_file(
_mark_seen() _mark_seen()
return False return False
deny = policy.deny_reason(abs_path) inspection = None
inspect_file = getattr(policy, "inspect_file", None)
if callable(inspect_file):
inspected = inspect_file(abs_path)
if isinstance(inspected, tuple) and len(inspected) == 2:
deny, inspection = inspected
else:
# Some tests and third-party callers use MagicMock/spec policies that
# expose inspect_file but have not configured it. Fall back to the
# legacy deny_reason/copy path for those non-real policies.
deny = policy.deny_reason(abs_path)
else:
deny = policy.deny_reason(abs_path)
if deny: if deny:
excluded_out.append(ExcludedFile(path=abs_path, reason=deny)) excluded_out.append(ExcludedFile(path=abs_path, reason=deny))
_mark_seen() _mark_seen()
return False return False
try: try:
owner, group, mode = ( if metadata is not None:
metadata if metadata is not None else stat_triplet(abs_path) owner, group, mode = metadata
) elif inspection is not None:
owner, group, mode = stat_triplet_from_stat(inspection.stat_result)
else:
owner, group, mode = stat_triplet(abs_path)
except OSError: except OSError:
excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable")) excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable"))
_mark_seen() _mark_seen()
@ -116,7 +188,10 @@ def capture_file(
src_rel = abs_path.lstrip("/") src_rel = abs_path.lstrip("/")
try: try:
copy_into_bundle(bundle_dir, role_name, abs_path, src_rel) if inspection is not None:
write_bytes_into_bundle(bundle_dir, role_name, src_rel, inspection.data)
else:
copy_into_bundle(bundle_dir, role_name, abs_path, src_rel)
except OSError: except OSError:
excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable")) excluded_out.append(ExcludedFile(path=abs_path, reason="unreadable"))
_mark_seen() _mark_seen()

View file

@ -928,6 +928,7 @@ def main() -> None:
no_sudo=bool(args.no_sudo), no_sudo=bool(args.no_sudo),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=True,
) )
_encrypt_harvest_dir_to_sops( _encrypt_harvest_dir_to_sops(
tmp_bundle, out_file, list(sops_fps) tmp_bundle, out_file, list(sops_fps)
@ -954,6 +955,7 @@ def main() -> None:
no_sudo=bool(args.no_sudo), no_sudo=bool(args.no_sudo),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=not bool(args.out),
) )
print(str(state)) print(str(state))
else: else:
@ -971,6 +973,7 @@ def main() -> None:
dangerous=bool(args.dangerous), dangerous=bool(args.dangerous),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=True,
) )
_encrypt_harvest_dir_to_sops( _encrypt_harvest_dir_to_sops(
tmp_bundle, out_file, list(sops_fps) tmp_bundle, out_file, list(sops_fps)
@ -990,6 +993,7 @@ def main() -> None:
dangerous=bool(args.dangerous), dangerous=bool(args.dangerous),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=not bool(args.out),
) )
print(path) print(path)
elif args.cmd == "explain": elif args.cmd == "explain":
@ -1164,6 +1168,7 @@ def main() -> None:
no_sudo=bool(args.no_sudo), no_sudo=bool(args.no_sudo),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=True,
) )
_encrypt_harvest_dir_to_sops( _encrypt_harvest_dir_to_sops(
tmp_bundle, out_file, list(sops_fps) tmp_bundle, out_file, list(sops_fps)
@ -1201,6 +1206,7 @@ def main() -> None:
no_sudo=bool(args.no_sudo), no_sudo=bool(args.no_sudo),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=not bool(args.harvest),
) )
manifest( manifest(
str(harvest_dir), str(harvest_dir),
@ -1228,6 +1234,7 @@ def main() -> None:
dangerous=bool(args.dangerous), dangerous=bool(args.dangerous),
include_paths=list(getattr(args, "include_path", []) or []), include_paths=list(getattr(args, "include_path", []) or []),
exclude_paths=list(getattr(args, "exclude_path", []) or []), exclude_paths=list(getattr(args, "exclude_path", []) or []),
allow_existing_output=True,
) )
_encrypt_harvest_dir_to_sops( _encrypt_harvest_dir_to_sops(
tmp_bundle, out_file, list(sops_fps) tmp_bundle, out_file, list(sops_fps)

View file

@ -5,6 +5,25 @@ import os
from typing import Tuple from typing import Tuple
def stat_triplet_from_stat(st: os.stat_result) -> Tuple[str, str, str]:
    """Return (owner, group, mode) for an existing stat result.

    owner and group are resolved to names through the passwd/group databases
    and fall back to the numeric id (as a string) when no entry exists. mode
    is the permission bits (including setuid/setgid/sticky) as a zero-padded
    four-digit octal string.
    """
    mode = format(st.st_mode & 0o7777, "04o")

    # Imported lazily, as in the rest of this module.
    import grp
    import pwd

    uid, gid = st.st_uid, st.st_gid

    try:
        owner = pwd.getpwuid(uid).pw_name
    except KeyError:
        owner = str(uid)

    try:
        group = grp.getgrgid(gid).gr_name
    except KeyError:
        group = str(gid)

    return owner, group, mode
def file_md5(path: str) -> str: def file_md5(path: str) -> str:
"""Return hex MD5 of a file. """Return hex MD5 of a file.
@ -23,18 +42,4 @@ def stat_triplet(path: str) -> Tuple[str, str, str]:
owner/group are usernames/group names when resolvable, otherwise numeric ids. owner/group are usernames/group names when resolvable, otherwise numeric ids.
mode is a zero-padded octal string (e.g. "0644"). mode is a zero-padded octal string (e.g. "0644").
""" """
st = os.stat(path, follow_symlinks=True) return stat_triplet_from_stat(os.stat(path, follow_symlinks=True))
mode = oct(st.st_mode & 0o7777)[2:].zfill(4)
import grp
import pwd
try:
owner = pwd.getpwuid(st.st_uid).pw_name
except KeyError:
owner = str(st.st_uid)
try:
group = grp.getgrgid(st.st_gid).gr_name
except KeyError:
group = str(st.st_gid)
return owner, group, mode

View file

@ -15,6 +15,7 @@ from . import systemd as _systemd
from .fsutil import stat_triplet from .fsutil import stat_triplet
from .platform import detect_platform, get_backend from .platform import detect_platform, get_backend
from .ignore import IgnorePolicy from .ignore import IgnorePolicy
from .harvest_safety import ensure_private_empty_dir, prepare_new_private_dir
from .pathfilter import PathFilter from .pathfilter import PathFilter
from .version import get_enroll_version from .version import get_enroll_version
from .state import write_state from .state import write_state
@ -527,6 +528,7 @@ def harvest(
dangerous: bool = False, dangerous: bool = False,
include_paths: Optional[List[str]] = None, include_paths: Optional[List[str]] = None,
exclude_paths: Optional[List[str]] = None, exclude_paths: Optional[List[str]] = None,
allow_existing_output: bool = False,
) -> str: ) -> str:
# If a policy is not supplied, build one. `--dangerous` relaxes secret # If a policy is not supplied, build one. `--dangerous` relaxes secret
# detection and deny-glob skipping. # detection and deny-glob skipping.
@ -536,7 +538,12 @@ def harvest(
# If callers explicitly provided a policy but also requested # If callers explicitly provided a policy but also requested
# dangerous behaviour, honour the CLI intent. # dangerous behaviour, honour the CLI intent.
policy.dangerous = True policy.dangerous = True
os.makedirs(bundle_dir, exist_ok=True) bundle_path = (
ensure_private_empty_dir(bundle_dir, label="harvest output")
if allow_existing_output
else prepare_new_private_dir(bundle_dir, label="harvest output")
)
bundle_dir = str(bundle_path)
# User-provided includes/excludes. Excludes apply to all harvesting; # User-provided includes/excludes. Excludes apply to all harvesting;
# includes are harvested into an extra role. # includes are harvested into an extra role.

104
enroll/harvest_safety.py Normal file
View file

@ -0,0 +1,104 @@
from __future__ import annotations
import os
import stat
from pathlib import Path
class OutputSafetyError(RuntimeError):
    """Signals that an output path is unsafe for root-run plaintext output."""
def _chmod_private(path: Path) -> None:
    """Best-effort attempt to set *path* to mode 0o700; failures are ignored."""
    try:
        os.chmod(path, 0o700)
    except OSError:
        # Deliberately swallowed: on normal filesystems callers already get a
        # private directory from mkdir(mode=0o700).
        return
def _assert_no_existing_symlink_components(path: Path, *, label: str) -> None:
    """Reject symlinks among the already-existing parents of an output path.

    Walks the parent components of *path* from the root (or from the current
    working directory for a relative path), excluding the final component.
    The walk stops at the first component that does not exist yet.

    Raises:
        OutputSafetyError: if an existing parent component is a symlink or
            cannot be inspected.
    """
    components = path.parts
    if not components:
        return

    if path.is_absolute():
        cursor, parents = Path(components[0]), components[1:-1]
    else:
        cursor, parents = Path.cwd(), components[:-1]

    for name in parents:
        cursor = cursor / name
        if not os.path.lexists(cursor):
            # Nothing below this point exists yet, so nothing can be a symlink.
            return
        try:
            info = cursor.lstat()
        except OSError as e:
            raise OutputSafetyError(f"unable to inspect {label} parent: {cursor}") from e
        if stat.S_ISLNK(info.st_mode):
            raise OutputSafetyError(
                f"{label} parent path contains a symlink; refusing: {cursor}"
            )
def prepare_new_private_dir(path: str | Path, *, label: str = "output") -> Path:
    """Create a brand-new private output directory.

    Refuse existing paths, including symlinks. This prevents root-run harvests
    from writing into attacker-precreated directories in shared locations such
    as /tmp, and keeps plaintext bundles private by default.

    Args:
        path: Directory to create; a leading ``~`` is expanded.
        label: Human-readable name for the path, used in error messages.

    Returns:
        The newly created directory.

    Raises:
        OutputSafetyError: if a parent component is a symlink, the path
            already exists (or appears while it is being created), or the
            result is not a real directory.
    """
    out = Path(path).expanduser()
    _assert_no_existing_symlink_components(out, label=label)

    if os.path.lexists(out):
        raise OutputSafetyError(
            f"{label} path already exists; refusing to overwrite or merge: {out}"
        )

    try:
        out.mkdir(parents=True, exist_ok=False, mode=0o700)
    except FileExistsError as e:
        # Something created the path between the lexists() check and mkdir().
        # Report it like any other pre-existing path instead of leaking a raw
        # OSError to callers that handle OutputSafetyError.
        raise OutputSafetyError(
            f"{label} path already exists; refusing to overwrite or merge: {out}"
        ) from e
    # mkdir()'s mode is subject to the umask; tighten explicitly.
    _chmod_private(out)

    try:
        st = out.lstat()
    except OSError as e:
        raise OutputSafetyError(f"unable to inspect {label} path: {out}") from e
    if stat.S_ISLNK(st.st_mode) or not stat.S_ISDIR(st.st_mode):
        raise OutputSafetyError(f"{label} path is not a real directory: {out}")
    return out
def ensure_private_empty_dir(path: str | Path, *, label: str = "output") -> Path:
    """Create or validate a private empty directory.

    This is for internally-generated random cache/temp directories. User-facing
    --out paths should normally use prepare_new_private_dir() instead.

    Args:
        path: Directory to create or validate; a leading ``~`` is expanded.
        label: Human-readable name for the path, used in error messages.

    Returns:
        The directory, freshly created or validated as an existing empty
        directory, with a best-effort chmod to 0o700 applied.

    Raises:
        OutputSafetyError: if a parent component is a symlink, or the path is
            a symlink, is not a directory, is not empty, cannot be inspected,
            or appears while it is being created.
    """
    out = Path(path).expanduser()
    _assert_no_existing_symlink_components(out, label=label)

    if not os.path.lexists(out):
        try:
            out.mkdir(parents=True, exist_ok=False, mode=0o700)
        except FileExistsError as e:
            # The path appeared between the lexists() check and mkdir(). It
            # has not been validated, so refuse it rather than leak a raw
            # OSError or silently accept it.
            raise OutputSafetyError(
                f"{label} path was created concurrently; refusing: {out}"
            ) from e
        _chmod_private(out)
        return out

    try:
        st = out.lstat()
    except OSError as e:
        raise OutputSafetyError(f"unable to inspect {label} path: {out}") from e
    if stat.S_ISLNK(st.st_mode):
        raise OutputSafetyError(f"{label} path is a symlink; refusing: {out}")
    if not stat.S_ISDIR(st.st_mode):
        raise OutputSafetyError(
            f"{label} path exists but is not a directory: {out}"
        )
    if any(out.iterdir()):
        raise OutputSafetyError(
            f"{label} path is not empty; refusing to merge: {out}"
        )
    _chmod_private(out)
    return out

View file

@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import fnmatch import fnmatch
import errno
import os import os
import re import re
import stat
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@ -57,7 +59,13 @@ DEFAULT_ALLOW_BINARY_GLOBS = [
# aws_secret_access_key = ... # aws_secret_access_key = ...
# GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json # GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
SENSITIVE_CONTENT_PATTERNS = [ SENSITIVE_CONTENT_PATTERNS = [
re.compile(rb"-----BEGIN (RSA |EC |OPENSSH |DSA |)PRIVATE KEY-----"), re.compile(
rb"-----BEGIN (?:RSA |EC |OPENSSH |DSA |ENCRYPTED |PGP )?PRIVATE KEY(?: BLOCK)?-----"
),
re.compile(rb"(?i)-----BEGIN OPENSSH PRIVATE KEY-----"),
re.compile(rb"(?i)AGE-SECRET-KEY-[A-Z0-9]+"),
re.compile(rb"(?i)OPENSSH PRIVATE KEY"),
re.compile(rb"(?i)PGP PRIVATE KEY BLOCK"),
re.compile( re.compile(
rb"""(?ix) rb"""(?ix)
(^|[^A-Za-z0-9]) (^|[^A-Za-z0-9])
@ -89,6 +97,14 @@ BLOCK_START = b"/*"
BLOCK_END = b"*/" BLOCK_END = b"*/"
@dataclass(frozen=True)
class FileInspection:
    """Immutable pair of content bytes and metadata from one safely-opened source file."""

    # Bytes read from the source file.
    data: bytes
    # stat result describing the source file.
    stat_result: os.stat_result
@dataclass @dataclass
class IgnorePolicy: class IgnorePolicy:
deny_globs: Optional[list[str]] = None deny_globs: Optional[list[str]] = None
@ -128,7 +144,7 @@ class IgnorePolicy:
yield raw yield raw
def deny_reason(self, path: str) -> Optional[str]: def _path_deny_reason(self, path: str) -> Optional[str]:
# Always ignore plain *.log files (rarely useful as config, often noisy). # Always ignore plain *.log files (rarely useful as config, often noisy).
if path.endswith(".log"): if path.endswith(".log"):
return "log_file" return "log_file"
@ -143,24 +159,9 @@ class IgnorePolicy:
for g in self.deny_globs or []: for g in self.deny_globs or []:
if fnmatch.fnmatch(path, g): if fnmatch.fnmatch(path, g):
return "denied_path" return "denied_path"
return None
try: def _content_deny_reason(self, path: str, data: bytes) -> Optional[str]:
st = os.stat(path, follow_symlinks=True)
except OSError:
return "unreadable"
if st.st_size > self.max_file_bytes:
return "too_large"
if not os.path.isfile(path) or os.path.islink(path):
return "not_regular_file"
try:
with open(path, "rb") as f:
data = f.read(min(self.sample_bytes, st.st_size))
except OSError:
return "unreadable"
if b"\x00" in data: if b"\x00" in data:
for g in self.allow_binary_globs or []: for g in self.allow_binary_globs or []:
if fnmatch.fnmatch(path, g): if fnmatch.fnmatch(path, g):
@ -176,6 +177,67 @@ class IgnorePolicy:
return None return None
def inspect_file(self, path: str) -> tuple[Optional[str], Optional[FileInspection]]:
    """Safely inspect a regular file and return the exact bytes to copy.

    The source is opened with O_NOFOLLOW where available, fstat() is taken
    from that file descriptor, and the whole file is read only after the
    size cap passes. With the default 256 KiB cap this avoids a memory DoS
    while ensuring secret scanning covers every byte that may be copied.

    Returns:
        ``(reason, None)`` when the file must not be harvested, where reason
        is a path- or content-based deny reason, "not_regular_file",
        "too_large" or "unreadable"; otherwise ``(None, FileInspection)``
        carrying the bytes read and the fstat() result.
    """
    deny = self._path_deny_reason(path)
    if deny:
        return deny, None

    # O_NONBLOCK keeps open() from blocking forever on a FIFO that has no
    # writer; the S_ISREG check below then rejects it. It has no effect on
    # reads from regular files.
    flags = os.O_RDONLY | getattr(os, "O_CLOEXEC", 0) | getattr(os, "O_NONBLOCK", 0)
    if hasattr(os, "O_NOFOLLOW"):
        flags |= os.O_NOFOLLOW

    try:
        fd = os.open(path, flags)
    except OSError as e:
        # With O_NOFOLLOW, a symlink as the final component fails with ELOOP.
        if e.errno in {errno.ELOOP, errno.ENOTDIR}:
            return "not_regular_file", None
        return "unreadable", None

    try:
        try:
            st = os.fstat(fd)
        except OSError:
            return "unreadable", None

        if not stat.S_ISREG(st.st_mode):
            return "not_regular_file", None
        if st.st_size > self.max_file_bytes:
            return "too_large", None

        # Read at most st_size bytes so a file that grows after fstat()
        # cannot push the read past the size cap.
        chunks: list[bytes] = []
        remaining = int(st.st_size)
        try:
            while remaining > 0:
                chunk = os.read(fd, min(1024 * 1024, remaining))
                if not chunk:
                    break
                chunks.append(chunk)
                remaining -= len(chunk)
        except OSError:
            # A failed read is reported as a deny reason, not a crash.
            return "unreadable", None
        data = b"".join(chunks)

        deny = self._content_deny_reason(path, data)
        if deny:
            return deny, None
        return None, FileInspection(data=data, stat_result=st)
    finally:
        try:
            os.close(fd)
        except OSError:
            pass
def deny_reason(self, path: str) -> Optional[str]:
    """Return only the deny reason from inspect_file(), discarding the inspection."""
    reason, _unused = self.inspect_file(path)
    return reason
def deny_reason_dir(self, path: str) -> Optional[str]: def deny_reason_dir(self, path: str) -> Optional[str]:
"""Directory-specific deny logic. """Directory-specific deny logic.

View file

@ -131,7 +131,11 @@ def prepare_manifest_output_dir(
) )
_assert_no_output_symlinks(out) _assert_no_output_symlinks(out)
return out return out
out.mkdir(parents=True, exist_ok=False) out.mkdir(parents=True, exist_ok=False, mode=0o700)
try:
os.chmod(out, 0o700)
except OSError:
pass
return out return out

View file

@ -13,6 +13,8 @@ from pathlib import Path
from pathlib import PurePosixPath from pathlib import PurePosixPath
from typing import Optional, Callable, TextIO from typing import Optional, Callable, TextIO
from .harvest_safety import ensure_private_empty_dir, prepare_new_private_dir
class RemoteSudoPasswordRequired(RuntimeError): class RemoteSudoPasswordRequired(RuntimeError):
"""Raised when sudo requires a password but none was provided.""" """Raised when sudo requires a password but none was provided."""
@ -139,12 +141,16 @@ def remote_harvest(
getpass_fn=getpass_fn, getpass_fn=getpass_fn,
) )
allow_existing_output = bool(kwargs.pop("allow_existing_output", False))
output_prepared = False
while True: while True:
try: try:
return _remote_harvest( return _remote_harvest(
sudo_password=sudo_password, sudo_password=sudo_password,
no_sudo=no_sudo, no_sudo=no_sudo,
ssh_key_passphrase=ssh_key_passphrase, ssh_key_passphrase=ssh_key_passphrase,
allow_existing_output=allow_existing_output or output_prepared,
**kwargs, **kwargs,
) )
except RemoteSSHKeyPassphraseRequired: except RemoteSSHKeyPassphraseRequired:
@ -158,6 +164,7 @@ def remote_harvest(
# Fallback prompt if interactive. # Fallback prompt if interactive.
if stdin is not None and getattr(stdin, "isatty", lambda: False)(): if stdin is not None and getattr(stdin, "isatty", lambda: False)():
ssh_key_passphrase = getpass_fn(key_prompt) ssh_key_passphrase = getpass_fn(key_prompt)
output_prepared = True
continue continue
raise RemoteSSHKeyPassphraseRequired( raise RemoteSSHKeyPassphraseRequired(
@ -173,6 +180,7 @@ def remote_harvest(
# Fallback prompt if interactive. # Fallback prompt if interactive.
if stdin is not None and getattr(stdin, "isatty", lambda: False)(): if stdin is not None and getattr(stdin, "isatty", lambda: False)():
sudo_password = getpass_fn(prompt) sudo_password = getpass_fn(prompt)
output_prepared = True
continue continue
raise RemoteSudoPasswordRequired( raise RemoteSudoPasswordRequired(
@ -413,6 +421,7 @@ def _remote_harvest(
ssh_key_passphrase: Optional[str] = None, ssh_key_passphrase: Optional[str] = None,
include_paths: Optional[list[str]] = None, include_paths: Optional[list[str]] = None,
exclude_paths: Optional[list[str]] = None, exclude_paths: Optional[list[str]] = None,
allow_existing_output: bool = False,
) -> Path: ) -> Path:
"""Run enroll harvest on a remote host via SSH and pull the bundle locally. """Run enroll harvest on a remote host via SSH and pull the bundle locally.
@ -426,12 +435,11 @@ def _remote_harvest(
"Install it with: pip install paramiko" "Install it with: pip install paramiko"
) from e ) from e
local_out_dir = Path(local_out_dir) local_out_dir = (
local_out_dir.mkdir(parents=True, exist_ok=True) ensure_private_empty_dir(local_out_dir, label="remote harvest output")
try: if allow_existing_output
os.chmod(local_out_dir, 0o700) else prepare_new_private_dir(local_out_dir, label="remote harvest output")
except OSError: )
pass
# Build a zipapp locally and upload it to the remote. # Build a zipapp locally and upload it to the remote.
with tempfile.TemporaryDirectory(prefix="enroll-remote-") as td: with tempfile.TemporaryDirectory(prefix="enroll-remote-") as td:

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Mapping, Union from typing import Any, Dict, Mapping, Union
@ -31,8 +33,34 @@ def write_state(
"""Write state.json to a harvest bundle directory and return its path.""" """Write state.json to a harvest bundle directory and return its path."""
path = state_path(bundle_dir) path = state_path(bundle_dir)
with open(path, "w", encoding="utf-8") as f: path.parent.mkdir(parents=True, exist_ok=True)
json.dump(state, f, indent=indent, sort_keys=sort_keys)
fd = -1
tmp_name = ""
try:
fd, tmp_name = tempfile.mkstemp(
prefix=f".{path.name}.", suffix=".tmp", dir=str(path.parent), text=True
)
try:
os.fchmod(fd, 0o600)
except OSError:
pass
with os.fdopen(fd, "w", encoding="utf-8") as f:
fd = -1
json.dump(state, f, indent=indent, sort_keys=sort_keys)
os.replace(tmp_name, path)
try:
os.chmod(path, 0o600)
except OSError:
pass
finally:
if fd >= 0:
os.close(fd)
if tmp_name:
try:
os.unlink(tmp_name)
except FileNotFoundError:
pass
return path return path

View file

@ -0,0 +1,112 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
from enroll.capture import capture_file
from enroll.harvest import harvest
from enroll.harvest_types import ExcludedFile, ManagedFile
from enroll.ignore import FileInspection, IgnorePolicy
from enroll.manifest_safety import prepare_manifest_output_dir
from enroll.harvest_safety import OutputSafetyError, prepare_new_private_dir
from enroll.pathfilter import PathFilter
class _RacePolicy(IgnorePolicy):
    """Policy that rewrites the source file right after inspecting it."""

    def inspect_file(self, path: str):
        # Snapshot the content and metadata through a single descriptor.
        handle = os.open(path, os.O_RDONLY | getattr(os, "O_CLOEXEC", 0))
        try:
            st = os.fstat(handle)
            snapshot = os.read(handle, st.st_size)
        finally:
            os.close(handle)

        # Simulate the source changing between inspection and copy.
        Path(path).write_bytes(b"changed-after-inspection")
        return None, FileInspection(data=snapshot, stat_result=st)
def test_prepare_new_private_dir_refuses_existing_path(tmp_path: Path):
    """An already-existing directory is refused."""
    existing = tmp_path / "bundle"
    existing.mkdir()

    with pytest.raises(OutputSafetyError, match="already exists"):
        prepare_new_private_dir(existing, label="harvest output")
def test_prepare_new_private_dir_creates_0700(tmp_path: Path):
    """A freshly created output directory exists and has mode 0700."""
    created = prepare_new_private_dir(tmp_path / "bundle", label="harvest output")

    assert created.exists()
    permissions = created.stat().st_mode & 0o777
    assert permissions == 0o700
def test_harvest_refuses_existing_plaintext_output_dir(tmp_path: Path):
    """harvest() with default arguments refuses an existing output directory."""
    existing = tmp_path / "bundle"
    existing.mkdir()

    with pytest.raises(OutputSafetyError, match="already exists"):
        harvest(str(existing))
def test_manifest_output_dir_is_private_by_default(tmp_path: Path):
    """A newly created manifest output directory has mode 0700."""
    created = prepare_manifest_output_dir(tmp_path / "manifest")

    permissions = created.stat().st_mode & 0o777
    assert permissions == 0o700
def test_capture_file_writes_inspected_bytes_not_later_source(tmp_path: Path):
    """The bundled artifact holds the inspected bytes even if the source changes afterwards."""
    src = tmp_path / "source.conf"
    src.write_bytes(b"safe-original")
    bundle_root = tmp_path / "bundle"
    bundle_root.mkdir()

    captured: list[ManagedFile] = []
    skipped: list[ExcludedFile] = []

    result = capture_file(
        bundle_dir=str(bundle_root),
        role_name="role",
        abs_path=str(src),
        reason="test",
        policy=_RacePolicy(),
        path_filter=PathFilter(),
        managed_out=captured,
        excluded_out=skipped,
    )

    assert result is True
    stored = bundle_root / "artifacts" / "role" / str(src).lstrip("/")
    assert stored.read_bytes() == b"safe-original"
    # _RacePolicy rewrote the source after inspecting it.
    assert src.read_bytes() == b"changed-after-inspection"
def test_capture_file_rejects_symlink_source_with_ignore_policy(tmp_path: Path):
    """A symlink source is excluded as not_regular_file and nothing is captured."""
    real_file = tmp_path / "target.conf"
    real_file.write_text("safe=true\n", encoding="utf-8")
    symlink = tmp_path / "link.conf"
    symlink.symlink_to(real_file)
    bundle_root = tmp_path / "bundle"
    bundle_root.mkdir()

    captured: list[ManagedFile] = []
    skipped: list[ExcludedFile] = []

    result = capture_file(
        bundle_dir=str(bundle_root),
        role_name="role",
        abs_path=str(symlink),
        reason="test",
        policy=IgnorePolicy(),
        path_filter=PathFilter(),
        managed_out=captured,
        excluded_out=skipped,
    )

    assert result is False
    assert captured == []
    assert skipped and skipped[0].reason == "not_regular_file"
def test_prepare_new_private_dir_rejects_symlink_parent(tmp_path: Path):
    """An output path whose parent is a symlink is refused."""
    real_dir = tmp_path / "real"
    real_dir.mkdir()
    symlinked_parent = tmp_path / "link"
    symlinked_parent.symlink_to(real_dir, target_is_directory=True)

    with pytest.raises(OutputSafetyError, match="parent path contains a symlink"):
        prepare_new_private_dir(symlinked_parent / "bundle", label="harvest output")

View file

@ -282,3 +282,27 @@ def test_deny_reason_shadow_backup():
pol = IgnorePolicy() pol = IgnorePolicy()
assert pol.deny_reason("/etc/shadow-") == "backup_file" assert pol.deny_reason("/etc/shadow-") == "backup_file"
assert pol.deny_reason("/etc/passwd-") == "backup_file" assert pol.deny_reason("/etc/passwd-") == "backup_file"
def test_detects_encrypted_private_key_marker(tmp_path):
    """A file with an ENCRYPTED PRIVATE KEY PEM header is denied as sensitive."""
    key_file = tmp_path / "key.pem"
    key_file.write_text(
        "-----BEGIN ENCRYPTED PRIVATE KEY-----\nabc\n-----END ENCRYPTED PRIVATE KEY-----\n",
        encoding="utf-8",
    )

    verdict = IgnorePolicy().deny_reason(str(key_file))
    assert verdict == "sensitive_content"
def test_detects_pgp_private_key_marker(tmp_path):
    """A file with a PGP PRIVATE KEY BLOCK armor header is denied as sensitive."""
    armored = tmp_path / "pgp.asc"
    armored.write_text(
        "-----BEGIN PGP PRIVATE KEY BLOCK-----\nabc\n-----END PGP PRIVATE KEY BLOCK-----\n",
        encoding="utf-8",
    )

    verdict = IgnorePolicy().deny_reason(str(armored))
    assert verdict == "sensitive_content"
def test_secret_scan_reads_whole_file_under_size_cap(tmp_path):
    """A secret placed after 70,000 bytes of padding is still detected."""
    conf = tmp_path / "large.conf"
    padding = b"A" * 70_000
    conf.write_bytes(padding + b"\nlate_token = abc123\n")

    verdict = IgnorePolicy().deny_reason(str(conf))
    assert verdict == "sensitive_content"