From d96ad3dc02bb6c5e7ab9b0e793a6beae19454bd1 Mon Sep 17 00:00:00 2001 From: Miguel Jacq Date: Mon, 22 Jun 2026 20:26:06 +1000 Subject: [PATCH] Some more hardening to not process raw jinja inside salt/ansible cmd. But, I think this is the end of the road --- enroll/ansible.py | 5 +- enroll/puppet.py | 5 +- enroll/render_safety.py | 232 +++++++++++++++++++++++++++++++++ enroll/salt.py | 12 +- enroll/yamlutil.py | 35 ++++- tests/test_manifest.py | 12 ++ tests/test_manifest_ansible.py | 71 ++++++++++ tests/test_manifest_puppet.py | 69 ++++++++++ tests/test_manifest_salt.py | 79 +++++++++++ 9 files changed, 508 insertions(+), 12 deletions(-) create mode 100644 enroll/render_safety.py diff --git a/enroll/ansible.py b/enroll/ansible.py index e0fcd0c..2eaec0a 100644 --- a/enroll/ansible.py +++ b/enroll/ansible.py @@ -18,6 +18,7 @@ from .manifest_safety import ( iter_safe_artifact_files, prepare_manifest_output_dir, ) +from .render_safety import ansible_unsafe_data from .role_names import avoid_reserved_role_name from .state import inventory_packages_from_state, roles_from_state from .yamlutil import yaml_dump_mapping, yaml_load_mapping @@ -688,7 +689,7 @@ def _write_hostvars(site_root: str, fqdn: str, role: str, data: Dict[str, Any]) except Exception: existing_map = {} - merged = _merge_mappings_overwrite(existing_map, data) + merged = _merge_mappings_overwrite(existing_map, ansible_unsafe_data(data)) out = "---\n" + yaml_dump_mapping(merged, sort_keys=True) with open(path, "w", encoding="utf-8") as f: @@ -699,7 +700,7 @@ def _write_role_defaults(role_dir: str, mapping: Dict[str, Any]) -> None: """Overwrite role defaults/main.yml with the provided mapping.""" defaults_path = os.path.join(role_dir, "defaults", "main.yml") os.makedirs(os.path.dirname(defaults_path), exist_ok=True) - out = "---\n" + yaml_dump_mapping(mapping, sort_keys=True) + out = "---\n" + yaml_dump_mapping(ansible_unsafe_data(mapping), sort_keys=True) with open(defaults_path, "w", encoding="utf-8") as f: f.write(out) diff --git a/enroll/puppet.py b/enroll/puppet.py index e034cea..baf7596 100644 --- a/enroll/puppet.py +++ b/enroll/puppet.py @@ -20,6 +20,7 @@ from .manifest_safety import ( prepare_manifest_output_dir, safe_artifact_file, ) +from .render_safety import puppet_hiera_safe_data from .state import inventory_packages_from_state, roles_from_state from .jinjaturtle import ( can_jinjify_path, @@ -1586,7 +1587,9 @@ def _render_hiera_yaml() -> str: def _write_yaml(path: Path, data: Dict[str, Any]) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text( - yaml.safe_dump(data, sort_keys=True, explicit_start=True), + yaml.safe_dump( + puppet_hiera_safe_data(data), sort_keys=True, explicit_start=True + ), encoding="utf-8", ) diff --git a/enroll/render_safety.py b/enroll/render_safety.py new file mode 100644 index 0000000..e8fc54a --- /dev/null +++ b/enroll/render_safety.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import json +import re +from collections.abc import Mapping, Set as AbstractSet +from typing import Any + + +ANSIBLE_JINJA_STARTS = ("{{", "{%", "{#") + + +class AnsibleUnsafeText(str): + """String subclass dumped as Ansible's ``!unsafe`` YAML scalar. + + Ansible templating can recursively evaluate Jinja delimiters that arrive + through variables/defaults. Harvested data is not authored playbook code; + values containing Jinja starts must be tagged as unsafe data before they are + written to Ansible variable files. + """ + + +def is_ansible_template_like(value: str) -> bool: + """Return true if *value* contains a Jinja start delimiter.""" + + return any(marker in value for marker in ANSIBLE_JINJA_STARTS) + + +def ansible_unsafe_data(value: Any) -> Any: + """Recursively mark template-looking harvested strings as Ansible data. + + Keep ordinary strings untouched so generated output remains readable and so + existing tests/tools that use ``yaml.safe_load`` continue to work for normal + data. Mapping keys are also strings in Ansible data structures, so protect + keys as well as values. + """ + + if isinstance(value, AnsibleUnsafeText): + return value + if isinstance(value, str): + return AnsibleUnsafeText(value) if is_ansible_template_like(value) else value + if isinstance(value, Mapping): + return { + ansible_unsafe_data(str(key)): ansible_unsafe_data(inner) + for key, inner in value.items() + } + if isinstance(value, list): + return [ansible_unsafe_data(item) for item in value] + if isinstance(value, tuple): + return [ansible_unsafe_data(item) for item in value] + if isinstance(value, AbstractSet): + return sorted(ansible_unsafe_data(item) for item in value) + return value + + +def escape_puppet_hiera_interpolation(value: str) -> str: + """Preserve literal ``%{`` text in Puppet Hiera data sources. + + Hiera treats ``%{...}`` in data values as interpolation. Enroll's Hiera + data is generated from harvested values, not authored Hiera expressions, so + any literal interpolation opener is escaped with Hiera's documented + ``literal('%')`` helper. + """ + + return str(value).replace("%{", "%{literal('%')}{") + + +def puppet_hiera_safe_data(value: Any) -> Any: + """Recursively escape Hiera interpolation openers in harvested data.""" + + if isinstance(value, Mapping): + return { + escape_puppet_hiera_interpolation(str(key)): puppet_hiera_safe_data(inner) + for key, inner in value.items() + } + if isinstance(value, list): + return [puppet_hiera_safe_data(item) for item in value] + if isinstance(value, tuple): + return [puppet_hiera_safe_data(item) for item in value] + if isinstance(value, AbstractSet): + return sorted(puppet_hiera_safe_data(item) for item in value) + if isinstance(value, str): + return escape_puppet_hiera_interpolation(value) + return value + + +def _plain_json_data(value: Any) -> Any: + if isinstance(value, Mapping): + return {str(key): _plain_json_data(inner) for key, inner in value.items()} + if isinstance(value, list): + return [_plain_json_data(item) for item in value] + if isinstance(value, tuple): + return [_plain_json_data(item) for item in value] + if isinstance(value, AbstractSet): + return sorted(_plain_json_data(item) for item in value) + return value + + +def _escape_braces_inside_json_strings(text: str) -> str: + """Replace literal braces only while scanning JSON string tokens.""" + + out: list[str] = [] + in_string = False + escaped = False + for ch in text: + if not in_string: + out.append(ch) + if ch == '"': + in_string = True + continue + + if escaped: + out.append(ch) + escaped = False + elif ch == "\\": + out.append(ch) + escaped = True + elif ch == '"': + out.append(ch) + in_string = False + elif ch == "{": + out.append("\\u007b") + elif ch == "}": + out.append("\\u007d") + else: + out.append(ch) + return "".join(out) + + +def salt_sls_json_quote(value: Any) -> str: + """Return a double-quoted YAML/JSON scalar safe for Salt's Jinja pass. + + Salt state and pillar SLS files normally use the ``jinja|yaml`` renderer + pipeline. YAML/JSON quoting alone does not stop ``{{ ... }}``, ``{% ... %}`` + or ``{# ... #}`` inside harvested values from being evaluated before YAML is + parsed. JSON/YAML double-quoted scalars decode ``\u007b`` and ``\u007d`` + after Jinja has run, so encode braces inside string tokens as Unicode escapes. + """ + + dumped = json.dumps(str(value), ensure_ascii=False) + return _escape_braces_inside_json_strings(dumped) + + +_PLAIN_YAML_KEY_RE = re.compile(r"^[A-Za-z0-9_./:-]+$") + + +def _salt_yaml_key(value: Any) -> str: + text = str(value) + if text and _PLAIN_YAML_KEY_RE.match(text) and not text.startswith(("-", "?", ":")): + return text + return salt_sls_json_quote(text) + + +def _salt_yaml_scalar(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, int) and not isinstance(value, bool): + return str(value) + if isinstance(value, float): + return json.dumps(value, allow_nan=False) + return salt_sls_json_quote(value) + + +def _salt_yaml_lines( + value: Any, indent: int = 0, *, sort_keys: bool = True +) -> list[str]: + prefix = " " * indent + if isinstance(value, Mapping): + if not value: + return [prefix + "{}"] + keys = sorted(value, key=lambda item: str(item)) if sort_keys else list(value) + lines: list[str] = [] + for key in keys: + inner = value[key] + key_text = _salt_yaml_key(key) + if isinstance(inner, Mapping): + if not inner: + lines.append(f"{prefix}{key_text}: {{}}") + else: + lines.append(f"{prefix}{key_text}:") + lines.extend( + _salt_yaml_lines(inner, indent + 2, sort_keys=sort_keys) + ) + elif isinstance(inner, (list, tuple, set)): + seq = list(inner) if not isinstance(inner, set) else sorted(inner) + if not seq: + lines.append(f"{prefix}{key_text}: []") + else: + lines.append(f"{prefix}{key_text}:") + lines.extend(_salt_yaml_lines(seq, indent + 2, sort_keys=sort_keys)) + else: + lines.append(f"{prefix}{key_text}: {_salt_yaml_scalar(inner)}") + return lines + + if isinstance(value, (list, tuple, set)): + seq = list(value) if not isinstance(value, set) else sorted(value) + if not seq: + return [prefix + "[]"] + lines = [] + for item in seq: + if isinstance(item, Mapping): + if not item: + lines.append(prefix + "- {}") + else: + lines.append(prefix + "-") + lines.extend( + _salt_yaml_lines(item, indent + 2, sort_keys=sort_keys) + ) + elif isinstance(item, (list, tuple, set)): + lines.append(prefix + "-") + lines.extend(_salt_yaml_lines(item, indent + 2, sort_keys=sort_keys)) + else: + lines.append(f"{prefix}- {_salt_yaml_scalar(item)}") + return lines + + return [prefix + _salt_yaml_scalar(value)] + + +def salt_sls_yaml_dump( + value: Any, + *, + sort_keys: bool = True, + explicit_start: bool = False, +) -> str: + """Dump block YAML whose string braces cannot form Salt Jinja delimiters.""" + + lines = _salt_yaml_lines(_plain_json_data(value), sort_keys=sort_keys) + rendered = "\n".join(lines).rstrip() + "\n" + if explicit_start: + rendered = "---\n" + rendered + return rendered diff --git a/enroll/salt.py b/enroll/salt.py index 27ec915..2a9fc69 100644 --- a/enroll/salt.py +++ b/enroll/salt.py @@ -8,7 +8,6 @@ import shutil from pathlib import Path from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple -import yaml from .cm import ( CMModule, @@ -22,8 +21,9 @@ from .manifest_safety import ( prepare_manifest_output_dir, safe_artifact_file, ) +from .render_safety import salt_sls_json_quote, salt_sls_yaml_dump from .state import inventory_packages_from_state, roles_from_state -from .yamlutil import yaml_dump_mapping, yaml_load_mapping_file +from .yamlutil import yaml_load_mapping_file class SaltRole(CMModule): @@ -381,7 +381,7 @@ def _active_service_state_ids_by_unit( def _yaml_quote(value: Any) -> str: - return json.dumps(str(value), ensure_ascii=False) + return salt_sls_json_quote(value) def _yaml_bool(value: Any) -> str: @@ -870,9 +870,7 @@ def _collect_salt_roles( def _append_yaml_value(lines: List[str], key: str, value: Any, *, indent: int) -> None: prefix = " " * indent if isinstance(value, dict): - dumped = yaml.safe_dump( - _plain_salt_data(value), sort_keys=True, default_flow_style=False - ).rstrip() + dumped = salt_sls_yaml_dump(_plain_salt_data(value), sort_keys=True).rstrip() if not dumped: lines.append(f"{prefix}- {key}: {{}}") return @@ -1501,7 +1499,7 @@ def _render_pillar_role(srole: SaltRole) -> str: def _write_yaml(path: Path, data: Dict[str, Any]) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text( - yaml_dump_mapping(data, sort_keys=True, explicit_start=True), + salt_sls_yaml_dump(data, sort_keys=True, explicit_start=True), encoding="utf-8", ) diff --git a/enroll/yamlutil.py b/enroll/yamlutil.py index 00e8496..b3bf10d 100644 --- a/enroll/yamlutil.py +++ b/enroll/yamlutil.py @@ -5,6 +5,21 @@ from typing import Any, Dict, Mapping import yaml +from .render_safety import AnsibleUnsafeText + + +class IndentedSafeLoader(yaml.SafeLoader): # type: ignore[misc] + """PyYAML loader that understands Ansible's ``!unsafe`` tag.""" + + +def _construct_ansible_unsafe( + loader: yaml.Loader, node: yaml.Node +) -> AnsibleUnsafeText: + return AnsibleUnsafeText(loader.construct_scalar(node)) + + +IndentedSafeLoader.add_constructor("!unsafe", _construct_ansible_unsafe) + class IndentedSafeDumper(yaml.SafeDumper): # type: ignore[misc] """PyYAML dumper that indents sequences under mapping keys.""" @@ -17,10 +32,17 @@ class IndentedSafeDumper(yaml.SafeDumper): # type: ignore[misc] def yaml_load_mapping(text: str) -> Dict[str, Any]: - """Load YAML text and return a mapping, or an empty mapping on failure.""" + """Load YAML text and return a mapping, or an empty mapping on failure. + + Enroll may re-read Ansible host_vars that contain ``!unsafe`` scalars + written during the same manifest operation, so the loader accepts that tag + while remaining otherwise based on PyYAML's SafeLoader. + """ try: - obj = yaml.safe_load(text) + obj = yaml.load( + text, Loader=IndentedSafeLoader + ) # nosec B506 - subclasses yaml.SafeLoader; only adds !unsafe scalar support. except Exception: return {} return obj if isinstance(obj, dict) else {} @@ -34,6 +56,15 @@ def yaml_load_mapping_file(path: Path) -> Dict[str, Any]: return yaml_load_mapping(path.read_text(encoding="utf-8")) +def _represent_ansible_unsafe( + dumper: yaml.Dumper, data: AnsibleUnsafeText +) -> yaml.Node: + return dumper.represent_scalar("!unsafe", str(data)) + + +IndentedSafeDumper.add_representer(AnsibleUnsafeText, _represent_ansible_unsafe) + + def yaml_dump_mapping( obj: Mapping[str, Any], *, diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 94148ad..27721ea 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -2378,3 +2378,15 @@ def test_manifest_non_fqdn_refuses_existing_output(tmp_path: Path): with pytest.raises(RuntimeError, match="already exists"): manifest.manifest(str(bundle), str(out), no_common_roles=True) + + +def test_yaml_dump_mapping_emits_ansible_unsafe_tag_for_marked_values(): + from enroll.render_safety import ansible_unsafe_data + + data = ansible_unsafe_data({"value": "{{ lookup('pipe','id') }}"}) + dumped = yaml_helpers.yaml_dump_mapping(data) + + assert "value: !unsafe" in dumped + assert "{{ lookup(''pipe'',''id'') }}" in dumped + loaded = yaml_helpers.yaml_load_mapping(dumped) + assert loaded["value"] == "{{ lookup('pipe','id') }}" diff --git a/tests/test_manifest_ansible.py b/tests/test_manifest_ansible.py index 88e4995..48a2721 100644 --- a/tests/test_manifest_ansible.py +++ b/tests/test_manifest_ansible.py @@ -89,3 +89,74 @@ def test_ansible_role_normalises_package_snapshot(): assert role.files["/etc/curlrc"]["dest"] == "/etc/curlrc" assert role.services == {} assert role.origin_lines == ["package `curl` from role `curl`"] + + +from pathlib import Path + +from state_helpers import write_schema_state + +from enroll import manifest, yamlutil as yaml_helpers + + +def _ansible_jinja_payload_state(payload: str) -> dict: + return { + "schema_version": 3, + "host": {"hostname": "test", "os": "debian", "pkg_backend": "dpkg"}, + "inventory": {"packages": {}}, + "roles": { + "users": { + "role_name": "users", + "users": [ + { + "name": "alice", + "uid": 1000, + "gid": 1000, + "gecos": payload, + "home": "/home/alice", + "shell": "/bin/bash", + "primary_group": "alice", + "supplementary_groups": [], + } + ], + "managed_dirs": [], + "managed_files": [], + "managed_links": [], + "excluded": [], + "notes": [], + }, + "services": [], + "packages": [], + }, + } + + +def test_ansible_static_marks_harvested_jinja_values_unsafe(tmp_path: Path): + bundle = tmp_path / "bundle" + out = tmp_path / "out" + payload = "{{ lookup('pipe','touch /tmp/PWNED_BY_ENROLL_ANSIBLE') }}" + write_schema_state(bundle, _ansible_jinja_payload_state(payload)) + + manifest.manifest(str(bundle), str(out), target="ansible") + + defaults = out / "roles" / "users" / "defaults" / "main.yml" + text = defaults.read_text(encoding="utf-8") + assert "gecos: !unsafe" in text + assert "lookup(''pipe'',''touch /tmp/PWNED_BY_ENROLL_ANSIBLE'')" in text + loaded = yaml_helpers.yaml_load_mapping(text) + assert loaded["users_users"][0]["gecos"] == payload + + +def test_ansible_fqdn_marks_harvested_jinja_values_unsafe(tmp_path: Path): + bundle = tmp_path / "bundle" + out = tmp_path / "out" + payload = "{{ lookup('pipe','touch /tmp/PWNED_BY_ENROLL_ANSIBLE') }}" + write_schema_state(bundle, _ansible_jinja_payload_state(payload)) + + manifest.manifest(str(bundle), str(out), target="ansible", fqdn="host.example.test") + + hostvars = out / "inventory" / "host_vars" / "host.example.test" / "users.yml" + text = hostvars.read_text(encoding="utf-8") + assert "gecos: !unsafe" in text + assert "lookup(''pipe'',''touch /tmp/PWNED_BY_ENROLL_ANSIBLE'')" in text + loaded = yaml_helpers.yaml_load_mapping(text) + assert loaded["users_users"][0]["gecos"] == payload diff --git a/tests/test_manifest_puppet.py b/tests/test_manifest_puppet.py index c54f6c7..3aac35f 100644 --- a/tests/test_manifest_puppet.py +++ b/tests/test_manifest_puppet.py @@ -1408,3 +1408,72 @@ def test_manifest_puppet_user_gecos_with_newline_is_single_line(tmp_path: Path): assert 'comment => "Real Name\\ntouch /tmp/pwned"' in init_pp # And there must be no line that is just the injected command. assert "\ntouch /tmp/pwned\n" not in init_pp + + +def _puppet_hiera_payload_state(payload: str) -> dict: + return { + "schema_version": 3, + "host": {"hostname": "test", "os": "debian", "pkg_backend": "dpkg"}, + "inventory": {"packages": {}}, + "roles": { + "users": { + "role_name": "users", + "users": [ + { + "name": "alice", + "uid": 1000, + "gid": 1000, + "gecos": payload, + "home": "/home/alice", + "shell": "/bin/bash", + "primary_group": "alice", + "supplementary_groups": [], + } + ], + "managed_dirs": [], + "managed_files": [], + "managed_links": [], + "excluded": [], + "notes": [], + }, + "services": [], + "packages": [], + }, + } + + +def test_manifest_puppet_static_quotes_template_like_harvested_values( + tmp_path: Path, +): + bundle = tmp_path / "bundle" + out = tmp_path / "puppet" + payload = "%{lookup('enroll::classes')}" + _write_state(bundle, _puppet_hiera_payload_state(payload)) + + manifest.manifest(str(bundle), str(out), target="puppet") + + init_pp = (out / "modules" / "users" / "manifests" / "init.pp").read_text( + encoding="utf-8" + ) + assert "comment => '%{lookup(\\'enroll::classes\\')}'" in init_pp + + +def test_manifest_puppet_hiera_escapes_harvested_interpolation_tokens( + tmp_path: Path, +): + bundle = tmp_path / "bundle" + out = tmp_path / "puppet" + payload = "%{lookup('enroll::classes')}" + _write_state(bundle, _puppet_hiera_payload_state(payload)) + + manifest.manifest(str(bundle), str(out), target="puppet", fqdn="node.example") + + node_yaml = out / "data" / "nodes" / "node.example.yaml" + text = node_yaml.read_text(encoding="utf-8") + assert payload not in text + assert "%{literal(''%'')}{lookup(''enroll::classes'')}" in text + data = yaml.safe_load(text) + assert ( + data["users::users"]["alice"]["comment"] + == "%{literal('%')}{lookup('enroll::classes')}" + ) diff --git a/tests/test_manifest_salt.py b/tests/test_manifest_salt.py index 3c8aec4..aa3e5ab 100644 --- a/tests/test_manifest_salt.py +++ b/tests/test_manifest_salt.py @@ -964,3 +964,82 @@ def test_salt_names_are_sanitised_for_target_reserved_words() -> None: assert _salt_name("123") == "role_123" assert _salt_name("top") == "role_top" assert _salt_name("web-app") == "web_app" + + +def test_manifest_salt_static_escapes_harvested_jinja_delimiters(tmp_path: Path): + bundle = tmp_path / "bundle" + out = tmp_path / "salt" + state = _sample_state() + payload = "{{ salt['cmd.run']('touch /tmp/PWNED_BY_ENROLL_SALT') }}" + state["roles"]["users"]["users"][0]["gecos"] = payload + _write_sample_artifacts(bundle) + _write_state(bundle, state) + + manifest.manifest(str(bundle), str(out), target="salt") + + users_sls = (out / "states" / "roles" / "users" / "init.sls").read_text( + encoding="utf-8" + ) + assert payload not in users_sls + assert "\\u007b\\u007b salt['cmd.run']" in users_sls + + calls = [] + + class FakeCmd: + def run(self, command): + calls.append(command) + return "EXECUTED" + + from jinja2 import Template + + rendered = Template(users_sls).render(salt={"cmd.run": FakeCmd().run}) + rendered_data = yaml.safe_load(rendered) + assert calls == [] + user_state = next( + state + for state in rendered_data.values() + if isinstance(state, dict) and "user.present" in state + ) + attrs = user_state["user.present"] + fullname = next(item["fullname"] for item in attrs if "fullname" in item) + assert fullname == payload + + +def test_manifest_salt_fqdn_escapes_harvested_jinja_delimiters_in_pillar( + tmp_path: Path, +): + bundle = tmp_path / "bundle" + out = tmp_path / "salt" + state = _sample_state() + payload = "{{ salt['cmd.run']('touch /tmp/PWNED_BY_ENROLL_SALT') }}" + state["roles"]["users"]["users"][0]["gecos"] = payload + _write_sample_artifacts(bundle) + _write_state(bundle, state) + + manifest.manifest(str(bundle), str(out), target="salt", fqdn="node.example") + + pillar_top = yaml.safe_load( + (out / "pillar" / "top.sls").read_text(encoding="utf-8") + ) + node_sls = pillar_top["base"]["node.example"][0] + pillar_path = out / "pillar" / Path(*node_sls.split(".")) + text = pillar_path.with_suffix(".sls").read_text(encoding="utf-8") + assert payload not in text + assert "\\u007b\\u007b salt['cmd.run']" in text + + calls = [] + + class FakeCmd: + def run(self, command): + calls.append(command) + return "EXECUTED" + + from jinja2 import Template + + rendered = Template(text).render(salt={"cmd.run": FakeCmd().run}) + rendered_data = yaml.safe_load(rendered) + assert calls == [] + assert ( + rendered_data["enroll"]["roles"]["users"]["users"]["alice"]["fullname"] + == payload + )