Initial pass at an --enforce mode for enroll diff, to manifest and restore state of old harvest if ansible is on the PATH
All checks were successful
CI / test (push) Successful in 8m13s
Lint / test (push) Successful in 33s
Trivy / test (push) Successful in 23s

This commit is contained in:
Miguel Jacq 2026-01-10 09:50:28 +11:00
parent 9749190cd8
commit 9a249cc973
Signed by: mig5
GPG key ID: 59B3F0C24135C6A9
2 changed files with 266 additions and 1 deletions

View file

@ -11,7 +11,14 @@ from pathlib import Path
from typing import Optional
from .cache import new_harvest_cache_dir
from .diff import compare_harvests, format_report, post_webhook, send_email
from .diff import (
compare_harvests,
enforce_old_harvest,
format_report,
has_enforceable_drift,
post_webhook,
send_email,
)
from .explain import explain_state
from .harvest import harvest
from .manifest import manifest
@ -560,6 +567,15 @@ def main() -> None:
"This affects file drift reporting only (added/removed/changed files), not package/service/user diffs."
),
)
d.add_argument(
"--enforce",
action="store_true",
help=(
"If differences are detected, attempt to enforce the old harvest state locally by generating a manifest and "
"running ansible-playbook. Requires ansible-playbook on PATH. "
"Enroll does not attempt to downgrade packages; if the only drift is package version upgrades (or newly installed packages), enforcement is skipped."
),
)
d.add_argument(
"--out",
help="Write the report to this file instead of stdout.",
@ -840,6 +856,40 @@ def main() -> None:
exclude_paths=list(getattr(args, "exclude_path", []) or []),
)
# Optional enforcement: if drift is detected, attempt to restore the
# system to the *old* (baseline) state using ansible-playbook.
if bool(getattr(args, "enforce", False)):
if has_changes:
if not has_enforceable_drift(report):
report["enforcement"] = {
"requested": True,
"status": "skipped",
"reason": (
"no enforceable drift detected (only package additions and/or version changes); "
"enroll does not attempt to downgrade packages"
),
}
else:
try:
info = enforce_old_harvest(
args.old,
sops_mode=bool(getattr(args, "sops", False)),
)
except Exception as e:
raise SystemExit(
f"error: could not enforce old harvest state: {e}"
) from e
report["enforcement"] = {
"requested": True,
**(info or {}),
}
else:
report["enforcement"] = {
"requested": True,
"status": "skipped",
"reason": "no differences detected",
}
txt = format_report(report, fmt=str(getattr(args, "format", "text")))
out_path = getattr(args, "out", None)
if out_path:

View file

@ -529,6 +529,162 @@ def compare_harvests(
return report, has_changes
def _tail_text(s: str, *, max_chars: int = 4000) -> str:
s = s or ""
if len(s) <= max_chars:
return s
return "" + s[-max_chars:]
def has_enforceable_drift(report: Dict[str, Any]) -> bool:
"""Return True if the diff report contains drift that is safe/meaningful to enforce.
Enforce mode is intended to restore *state* (files/users/services) and to
reinstall packages that were removed.
It is deliberately conservative about package drift:
- Package *version* changes alone are not enforced (no downgrades).
- Newly installed packages are not removed.
This helper lets the CLI decide whether `--enforce` should actually run.
"""
pk = report.get("packages", {}) or {}
if pk.get("removed"):
return True
sv = report.get("services", {}) or {}
if (sv.get("enabled_added") or []) or (sv.get("enabled_removed") or []):
return True
for ch in sv.get("changed", []) or []:
changes = ch.get("changes") or {}
# Ignore package set drift for enforceability decisions; package
# enforcement is handled via reinstalling removed packages, and we
# avoid trying to "undo" upgrades/renames.
for k in changes.keys():
if k != "packages":
return True
us = report.get("users", {}) or {}
if (
(us.get("added") or [])
or (us.get("removed") or [])
or (us.get("changed") or [])
):
return True
fl = report.get("files", {}) or {}
if (
(fl.get("added") or [])
or (fl.get("removed") or [])
or (fl.get("changed") or [])
):
return True
return False
def enforce_old_harvest(
old_path: str,
*,
sops_mode: bool = False,
) -> Dict[str, Any]:
"""Enforce the *old* (baseline) harvest state on the current machine.
When Ansible is available, this:
1) renders a temporary manifest from the old harvest, and
2) runs ansible-playbook locally to apply it.
Returns a dict suitable for attaching to the diff report under
report['enforcement'].
"""
ansible_playbook = shutil.which("ansible-playbook")
if not ansible_playbook:
raise RuntimeError(
"ansible-playbook not found on PATH (cannot enforce; install Ansible)"
)
# Import lazily to avoid heavy import cost and potential CLI cycles.
from .manifest import manifest
started_at = _utc_now_iso()
with ExitStack() as stack:
old_b = _bundle_from_input(old_path, sops_mode=sops_mode)
if old_b.tempdir:
stack.callback(old_b.tempdir.cleanup)
with tempfile.TemporaryDirectory(prefix="enroll-enforce-") as td:
td_path = Path(td)
try:
os.chmod(td_path, 0o700)
except OSError:
pass
# 1) Generate a manifest in a temp directory.
manifest(str(old_b.dir), str(td_path))
playbook = td_path / "playbook.yml"
if not playbook.exists():
raise RuntimeError(
f"manifest did not produce expected playbook.yml at {playbook}"
)
# 2) Apply it locally.
env = dict(os.environ)
cfg = td_path / "ansible.cfg"
if cfg.exists():
env["ANSIBLE_CONFIG"] = str(cfg)
cmd = [
ansible_playbook,
"-i",
"localhost,",
"-c",
"local",
str(playbook),
]
p = subprocess.run(
cmd,
cwd=str(td_path),
env=env,
capture_output=True,
text=True,
check=False,
) # nosec
finished_at = _utc_now_iso()
info: Dict[str, Any] = {
"status": "applied" if p.returncode == 0 else "failed",
"started_at": started_at,
"finished_at": finished_at,
"ansible_playbook": ansible_playbook,
"command": cmd,
"returncode": int(p.returncode),
}
# Include a small tail for observability in webhooks/emails.
if p.stdout:
info["stdout_tail"] = _tail_text(p.stdout)
if p.stderr:
info["stderr_tail"] = _tail_text(p.stderr)
if p.returncode != 0:
err = (p.stderr or p.stdout or "").strip()
if err:
err = _tail_text(err)
raise RuntimeError(
"ansible-playbook failed"
+ (f" (rc={p.returncode})" if p.returncode is not None else "")
+ (f": {err}" if err else "")
)
return info
def format_report(report: Dict[str, Any], *, fmt: str = "text") -> str:
fmt = (fmt or "text").lower()
if fmt == "json":
@ -553,6 +709,30 @@ def _report_text(report: Dict[str, Any]) -> str:
if ex_paths:
lines.append(f"file exclude patterns: {', '.join(str(p) for p in ex_paths)}")
enf = report.get("enforcement") or {}
if enf:
lines.append("\nEnforcement")
status = str(enf.get("status") or "").strip().lower()
if status == "applied":
lines.append(
f" applied old harvest via ansible-playbook (rc={enf.get('returncode')})"
+ (
f" (finished {enf.get('finished_at')})"
if enf.get("finished_at")
else ""
)
)
elif status == "failed":
lines.append(
f" attempted enforcement but ansible-playbook failed (rc={enf.get('returncode')})"
)
elif status == "skipped":
r = enf.get("reason")
lines.append(" skipped" + (f": {r}" if r else ""))
else:
# Best-effort formatting for future fields.
lines.append(" " + json.dumps(enf, sort_keys=True))
pk = report.get("packages", {})
lines.append("\nPackages")
lines.append(f" added: {len(pk.get('added', []) or [])}")
@ -668,6 +848,41 @@ def _report_markdown(report: Dict[str, Any]) -> str:
+ "\n"
)
enf = report.get("enforcement") or {}
if enf:
out.append("\n## Enforcement\n")
status = str(enf.get("status") or "").strip().lower()
if status == "applied":
out.append(
"- ✅ Applied old harvest via ansible-playbook"
+ (
f" (rc={enf.get('returncode')})"
if enf.get("returncode") is not None
else ""
)
+ (
f" (finished `{enf.get('finished_at')}`)"
if enf.get("finished_at")
else ""
)
+ "\n"
)
elif status == "failed":
out.append(
"- ⚠️ Attempted enforcement but ansible-playbook failed"
+ (
f" (rc={enf.get('returncode')})"
if enf.get("returncode") is not None
else ""
)
+ "\n"
)
elif status == "skipped":
r = enf.get("reason")
out.append("- Skipped" + (f": {r}" if r else "") + "\n")
else:
out.append(f"- {json.dumps(enf, sort_keys=True)}\n")
pk = report.get("packages", {})
out.append("## Packages\n")
out.append(f"- Added: {len(pk.get('added', []) or [])}\n")