diff --git a/enroll/harvest_collectors/paths.py b/enroll/harvest_collectors/paths.py index f11896a..eda2d41 100644 --- a/enroll/harvest_collectors/paths.py +++ b/enroll/harvest_collectors/paths.py @@ -5,12 +5,13 @@ import os from typing import Dict, List, Optional, Set from .. import harvest as h -from ..capture import capture_file +from ..capture import capture_file, capture_link from ..harvest_types import ( ExcludedFile, ExtraPathsSnapshot, ManagedDir, ManagedFile, + ManagedLink, UsrLocalCustomSnapshot, ) from ..system_paths import MAX_FILES_CAP @@ -132,6 +133,7 @@ class ExtraPathsCollector(HarvestCollector): self.notes: List[str] = [] self.excluded: List[ExcludedFile] = [] self.managed: List[ManagedFile] = [] + self.managed_links: List[ManagedLink] = [] self.managed_dirs: List[ManagedDir] = [] self.dir_seen: Set[str] = set() @@ -178,28 +180,53 @@ class ExtraPathsCollector(HarvestCollector): exclude_patterns=self.exclude_specs, managed_dirs=self.managed_dirs, managed_files=self.managed, + managed_links=self.managed_links, excluded=self.excluded, notes=self.notes, ) def _collect_included_dirs(self) -> None: + role_seen = self.seen_by_role.setdefault(self.role_name, set()) for pat in self.context.path_filter.iter_include_patterns(): if pat.kind == "prefix": path = pat.value - if os.path.isdir(path) and not os.path.islink(path): - self._walk_and_capture_dirs(path) + if os.path.islink(path): + self._capture_included_link(path, role_seen) + elif os.path.isdir(path): + self._walk_and_capture_dirs(path, role_seen) elif pat.kind == "glob": for hit in glob.glob(pat.value, recursive=True): - if os.path.isdir(hit) and not os.path.islink(hit): - self._walk_and_capture_dirs(hit) + if os.path.islink(hit): + self._capture_included_link(hit, role_seen) + elif os.path.isdir(hit): + self._walk_and_capture_dirs(hit, role_seen) - def _walk_and_capture_dirs(self, root: str) -> None: + def _capture_included_link(self, path: str, role_seen: Set[str]) -> None: + path = os.path.normpath(path) + if not path.startswith("/"): + path = "/" + path + if path in self.already_all: + return + if capture_link( + role_name=self.role_name, + abs_path=path, + reason="user_include_link", + policy=self.context.policy, + path_filter=self.context.path_filter, + managed_out=self.managed_links, + excluded_out=self.excluded, + seen_role=role_seen, + seen_global=self.context.captured_global, + ): + self.already_all.add(path) + + def _walk_and_capture_dirs(self, root: str, role_seen: Set[str]) -> None: root = os.path.normpath(root) if not root.startswith("/"): root = "/" + root if not os.path.isdir(root) or os.path.islink(root): return - for dirpath, dirnames, _ in os.walk(root, followlinks=False): + for dirpath, dirnames, filenames in os.walk(root, followlinks=False): if len(self.managed_dirs) >= MAX_FILES_CAP: self.notes.append( f"Reached directory cap ({MAX_FILES_CAP}) while scanning {root}." @@ -243,7 +270,17 @@ class ExtraPathsCollector(HarvestCollector): pruned: List[str] = [] for dirname in dirnames: path = os.path.join(dirpath, dirname) - if os.path.islink(path) or self.context.path_filter.is_excluded(path): + if self.context.path_filter.is_excluded(path): + continue + if os.path.islink(path): + self._capture_included_link(path, role_seen) continue pruned.append(dirname) dirnames[:] = pruned + + for filename in filenames: + path = os.path.join(dirpath, filename) + if self.context.path_filter.is_excluded(path): + continue + if os.path.islink(path): + self._capture_included_link(path, role_seen) diff --git a/tests/test_harvest_collectors.py b/tests/test_harvest_collectors.py index 94b5259..f4de696 100644 --- a/tests/test_harvest_collectors.py +++ b/tests/test_harvest_collectors.py @@ -394,3 +394,53 @@ def test_usr_local_custom_collector_scans_executable_bin_and_notes_cap( "usr_local_etc_custom", "usr_local_bin_script", ] + + +def test_extra_paths_collector_records_symlinks_without_following(tmp_path): + root = tmp_path / "include" + root.mkdir() + real_file = root / "real.conf" + real_file.write_text("ok", encoding="utf-8") + (root / "link.conf").symlink_to("real.conf") + + outside = tmp_path / "outside" + outside.mkdir() + (outside / "outside.conf").write_text("do-not-follow", encoding="utf-8") + (root / "shared").symlink_to(outside, target_is_directory=True) + + ctx = _context(tmp_path, include=[str(root)]) + result = ExtraPathsCollector( + ctx, + seen_by_role={}, + already_all=set(), + include_paths=[str(root)], + ).collect() + + links = {(link.path, link.target, link.reason) for link in result.managed_links} + assert (str(root / "link.conf"), "real.conf", "user_include_link") in links + assert (str(root / "shared"), str(outside), "user_include_link") in links + + managed_files = {mf.path for mf in result.managed_files} + assert str(real_file) in managed_files + assert str(outside / "outside.conf") not in managed_files + + +def test_extra_paths_collector_records_include_path_that_is_symlink(tmp_path): + real_root = tmp_path / "real" + real_root.mkdir() + (real_root / "inside.conf").write_text("do-not-follow", encoding="utf-8") + link_root = tmp_path / "linked-root" + link_root.symlink_to(real_root, target_is_directory=True) + + ctx = _context(tmp_path, include=[str(link_root)]) + result = ExtraPathsCollector( + ctx, + seen_by_role={}, + already_all=set(), + include_paths=[str(link_root)], + ).collect() + + assert [(link.path, link.target, link.reason) for link in result.managed_links] == [ + (str(link_root), str(real_root), "user_include_link") + ] + assert result.managed_files == []