validation of artifact dir

This commit is contained in:
Miguel Jacq 2026-06-22 17:23:25 +10:00
parent efb6d7cc15
commit 992b8060a5
Signed by: mig5
GPG key ID: 03906B4110AAD3B8
3 changed files with 122 additions and 56 deletions

View file

@ -69,6 +69,8 @@ def _assert_existing_output_dir_component(path: Path, *, label: str) -> None:
raise OutputSafetyError(
f"{label} parent path contains a symlink; refusing: {path}"
)
if not stat.S_ISDIR(st.st_mode):
raise OutputSafetyError(f"{label} parent is not a directory: {path}")
_assert_trusted_root_parent(path, st, label=label)
@ -214,6 +216,24 @@ def write_text_output_file(
return out
def ensure_private_dir(path: str | Path, *, label: str = "output") -> Path:
    """Ensure *path* exists as a private directory, keeping existing contents.

    Intended for long-lived internal directories (for example Enroll's cache
    root) that legitimately hold data from earlier runs. The same
    component-by-component symlink and root-parent trust checks used for
    user-facing plaintext output directories are applied, except that an
    already-existing final directory is accepted.

    Args:
        path: Directory to create or validate; a leading ``~`` is expanded.
        label: Name passed through to the helper checks.

    Returns:
        The directory path as returned by ``_mkdir_private_dir_tree``.
    """
    requested = Path(path).expanduser()
    # The check is run on a path *inside* the directory rather than on the
    # directory itself.
    # NOTE(review): presumably this makes the target directory count as a
    # path component for the symlink check -- confirm against the helper.
    probe = requested / ".enroll-private-dir-check"

    _assert_no_existing_symlink_components(probe, label=label)
    created = _mkdir_private_dir_tree(
        requested,
        label=label,
        final_must_be_new=False,
    )
    # Repeat the same check now that the directory tree has been created.
    _assert_no_existing_symlink_components(probe, label=label)

    _chmod_private(created)
    return created
def prepare_new_private_dir(path: str | Path, *, label: str = "output") -> Path:
"""Create a brand-new private output directory.

View file

@ -232,7 +232,17 @@ def validate_harvest(
# Validate the whole artifact tree too, so unreferenced symlinks,
# hardlinks, special files, and path-shaping tricks do not survive
# validation simply because no managed_file currently references them.
if artifacts_dir.exists() and artifacts_dir.is_dir():
if artifacts_dir.exists():
try:
artifacts_st = artifacts_dir.lstat()
except OSError as e:
errors.append(f"unable to inspect artifacts directory: {e}")
else:
if stat.S_ISLNK(artifacts_st.st_mode):
errors.append(f"artifacts directory is a symlink: {artifacts_dir}")
elif not stat.S_ISDIR(artifacts_st.st_mode):
errors.append(f"artifacts path is not a directory: {artifacts_dir}")
else:
for root, dirs, files in os.walk(artifacts_dir, followlinks=False):
root_p = Path(root)
for name in list(dirs):
@ -244,7 +254,9 @@ def validate_harvest(
if stat.S_ISLNK(st.st_mode):
errors.append(f"artifact directory is a symlink: {fp}")
elif not stat.S_ISDIR(st.st_mode):
errors.append(f"artifact directory is not a directory: {fp}")
errors.append(
f"artifact directory is not a directory: {fp}"
)
for name in files:
fp = root_p / name
@ -259,7 +271,9 @@ def validate_harvest(
continue
parts = rel.parts
if len(parts) < 2:
errors.append(f"artifact is not under a role directory: {fp}")
errors.append(
f"artifact is not under a role directory: {fp}"
)
continue
role_name = parts[0]
src_rel = "/".join(parts[1:])

View file

@ -453,3 +453,35 @@ def test_validate_harvest_rejects_unreferenced_artifact_symlink(tmp_path: Path):
assert result.ok is False
assert any("symlink" in e for e in result.errors)
def test_validate_harvest_rejects_top_level_artifacts_symlink(tmp_path: Path):
    """Validation must fail when the bundle's ``artifacts`` entry is a symlink."""
    bundle_root = tmp_path / "bundle"
    bundle_root.mkdir()

    # Minimal state document: one role with no managed files.
    state_doc = {"roles": {"users": {"managed_files": []}}}
    (bundle_root / "state.json").write_text(
        json.dumps(state_doc),
        encoding="utf-8",
    )

    # Point ``artifacts`` at a real directory that lives outside the bundle.
    link_target = tmp_path / "artifact-target"
    link_target.mkdir()
    (bundle_root / "artifacts").symlink_to(link_target, target_is_directory=True)

    outcome = validate_harvest(str(bundle_root), no_schema=True)

    assert outcome.ok is False
    assert any(
        "artifacts directory is a symlink" in message for message in outcome.errors
    )
def test_validate_harvest_rejects_top_level_artifacts_file(tmp_path: Path):
    """Validation must fail when ``artifacts`` is a regular file, not a directory."""
    bundle_root = tmp_path / "bundle"
    bundle_root.mkdir()

    # Minimal state document: one role with no managed files.
    state_doc = {"roles": {"users": {"managed_files": []}}}
    (bundle_root / "state.json").write_text(
        json.dumps(state_doc),
        encoding="utf-8",
    )

    # Occupy the ``artifacts`` name with a plain file.
    bogus_artifacts = bundle_root / "artifacts"
    bogus_artifacts.write_text("not a directory", encoding="utf-8")

    outcome = validate_harvest(str(bundle_root), no_schema=True)

    assert outcome.ok is False
    assert any(
        "artifacts path is not a directory" in message for message in outcome.errors
    )