diff --git a/enroll/state.py b/enroll/state.py
index 8d469b6..633a22f 100644
--- a/enroll/state.py
+++ b/enroll/state.py
@@ -2,13 +2,25 @@ from __future__ import annotations
 
 import json
 import os
+import stat
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, Mapping, Union
+from typing import Any, Dict, Mapping, TextIO, Union
+
+from .fsutil import open_no_follow_path
 
 BundlePath = Union[str, Path]
 State = Dict[str, Any]
 
+# state.json should contain structured metadata, not harvested file content. Keep
+# this generous so large package inventories still work while rejecting obvious
+# accidental/malicious memory-exhaustion inputs.
+MAX_STATE_JSON_BYTES = 16 * 1024 * 1024
+
+
+class StateSafetyError(RuntimeError):
+    """Raised when a harvest bundle's state.json is unsafe to parse."""
+
 
 def state_path(bundle_dir: BundlePath) -> Path:
     """Return the canonical state.json path for a harvest bundle."""
@@ -16,10 +28,94 @@ def state_path(bundle_dir: BundlePath) -> Path:
     return Path(bundle_dir) / "state.json"
 
 
+def _check_state_stat(path: Path, st: os.stat_result, *, max_bytes: int) -> None:
+    """Raise StateSafetyError unless *st* describes a safe state.json.
+
+    Rejects symlinks, non-regular files, files with more than one hard link,
+    and files whose reported size exceeds *max_bytes*.  *path* is only used
+    to build the error messages.
+    """
+    if stat.S_ISLNK(st.st_mode):
+        raise StateSafetyError(f"state.json is a symlink; refusing to read: {path}")
+    if not stat.S_ISREG(st.st_mode):
+        raise StateSafetyError(f"state.json is not a regular file: {path}")
+    if st.st_nlink > 1:
+        raise StateSafetyError(f"state.json is hardlinked; refusing to read: {path}")
+    if st.st_size > max_bytes:
+        raise StateSafetyError(
+            f"state.json is too large to parse safely "
+            f"({st.st_size} bytes > {max_bytes} bytes): {path}"
+        )
+
+
+def open_state_file(bundle_dir: BundlePath, *, max_bytes: int | None = None) -> TextIO:
+    """Open state.json only after verifying it is safe to parse.
+
+    Direct directory bundles are more mutable than SOPS/tar/remote bundles, so do
+    not follow a symlinked state.json and do not parse special files, hardlinks, or
+    unexpectedly huge inputs. The final open also uses no-follow semantics and the
+    inode is compared with the pre-open lstat result to catch swaps between the
+    check and open.
+
+    *max_bytes* defaults to MAX_STATE_JSON_BYTES when None.  The caller owns the
+    returned UTF-8 text stream and must close it.
+
+    Raises FileNotFoundError if state.json does not exist, and StateSafetyError
+    if it fails a safety check or cannot be opened.
+    """
+
+    if max_bytes is None:
+        max_bytes = MAX_STATE_JSON_BYTES
+
+    path = state_path(bundle_dir)
+    try:
+        pre = path.lstat()
+    except FileNotFoundError:
+        # Re-raise with a clearer message; the chained original adds nothing.
+        raise FileNotFoundError(f"missing state.json: {path}") from None
+
+    _check_state_stat(path, pre, max_bytes=max_bytes)
+
+    fd = -1
+    try:
+        fd = open_no_follow_path(str(path), write=False)
+        opened = os.fstat(fd)
+        if (opened.st_dev, opened.st_ino) != (pre.st_dev, pre.st_ino):
+            raise StateSafetyError(
+                f"state.json changed while it was being opened; refusing to read: {path}"
+            )
+        # Re-check type, link count and size on the descriptor that will be read.
+        _check_state_stat(path, opened, max_bytes=max_bytes)
+        f = os.fdopen(fd, "r", encoding="utf-8")
+        # The stream owns the descriptor now, so the finally block must skip it.
+        fd = -1
+        return f
+    except OSError as e:
+        raise StateSafetyError(f"unable to safely open state.json: {path}: {e}") from e
+    finally:
+        if fd >= 0:
+            try:
+                os.close(fd)
+            except OSError:
+                pass
+
+
 def load_state(bundle_dir: BundlePath) -> State:
-    """Load state.json from a harvest bundle directory."""
+    """Load state.json from a harvest bundle directory.
+
+    The file is opened through open_state_file(), and the size cap is enforced
+    again while reading because the file can still grow after it was stat'ed.
+    """
 
-    with open(state_path(bundle_dir), "r", encoding="utf-8") as f:
-        return json.load(f)
+    with open_state_file(bundle_dir) as f:
+        # One character is at least one byte of UTF-8, so reading the cap plus
+        # one characters is enough to detect a file that grew past the limit.
+        text = f.read(MAX_STATE_JSON_BYTES + 1)
+    if len(text) > MAX_STATE_JSON_BYTES:
+        raise StateSafetyError(
+            f"state.json is too large to parse safely "
+            f"(more than {MAX_STATE_JSON_BYTES} bytes): {state_path(bundle_dir)}"
+        )
+    return json.loads(text)
 
 
diff --git a/tests/test_state_safety.py b/tests/test_state_safety.py
new file mode 100644
index 0000000..5af9ff1
--- /dev/null
+++ b/tests/test_state_safety.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+
+import pytest
+
+from enroll.state import StateSafetyError, load_state, open_state_file
+
+
+def test_load_state_reads_regular_state_json(tmp_path: Path):
+    (tmp_path / "state.json").write_text(
+        json.dumps({"host": {"hostname": "test-host"}}), encoding="utf-8"
+    )
+
+    assert load_state(tmp_path)["host"]["hostname"] == "test-host"
+
+
+def test_load_state_rejects_state_json_symlink(tmp_path: Path):
+    target = tmp_path / "target.json"
+    target.write_text("{}", encoding="utf-8")
+    (tmp_path / "state.json").symlink_to(target)
+
+    with pytest.raises(StateSafetyError, match="state.json is a symlink"):
+        load_state(tmp_path)
+
+
+def test_load_state_rejects_non_regular_state_json(tmp_path: Path):
+    (tmp_path / "state.json").mkdir()
+
+    with pytest.raises(StateSafetyError, match="state.json is not a regular file"):
+        load_state(tmp_path)
+
+
+def test_load_state_rejects_hardlinked_state_json(tmp_path: Path):
+    state_file = tmp_path / "state.json"
+    state_file.write_text("{}", encoding="utf-8")
+    os.link(state_file, tmp_path / "state-copy.json")
+
+    with pytest.raises(StateSafetyError, match="state.json is hardlinked"):
+        load_state(tmp_path)
+
+
+def test_open_state_file_rejects_oversized_state_json(tmp_path: Path):
+    (tmp_path / "state.json").write_text("{}", encoding="utf-8")
+
+    with pytest.raises(StateSafetyError, match="state.json is too large"):
+        open_state_file(tmp_path, max_bytes=1)