cgcardona / muse public
invariants.py python
249 lines 8.8 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
"""Domain-agnostic invariants engine for Muse.

An *invariant* is a semantic rule that a domain's state must satisfy. Rules
are declared in TOML, evaluated against commit snapshots, and reported with
structured violations. Any domain plugin can implement invariant checking
by satisfying the :class:`InvariantChecker` protocol and wiring a CLI command.

This module defines the **shared vocabulary** — TypedDicts and protocols that
are domain-agnostic. Domain-specific implementations (MIDI, code, genomics…)
import these types and add their own rule types and evaluators.

Architecture
------------
::

    muse/core/invariants.py            ← this file: shared protocol
    muse/plugins/midi/_invariants.py   ← MIDI-specific rules + evaluator
    muse/plugins/code/_invariants.py   ← code-specific rules + evaluator
    muse/cli/commands/midi_check.py    ← CLI wiring for MIDI
    muse/cli/commands/code_check.py    ← CLI wiring for code

TOML rule file format (shared across all domains)::

    [[rule]]
    name      = "my_rule"         # unique human-readable identifier
    severity  = "error"           # "info" | "warning" | "error"
    scope     = "file"            # domain-specific scope tag
    rule_type = "max_complexity"  # domain-specific rule type string

    [rule.params]
    threshold = 10                # rule-specific numeric / string params

:func:`load_rules_toml` returns one raw dict per ``[[rule]]`` table. The
``[rule.params]`` sub-table that follows a rule appears under that dict's
``"params"`` key. Interpreting the values is left to each domain.

Severity levels
---------------
- ``"error"``   — must be resolved before committing (when ``--strict`` is set).
- ``"warning"`` — reported but does not block commits.
- ``"info"``    — informational; surfaced in ``muse check`` output only.

Public API
----------
- :data:`InvariantSeverity`  — severity literal type alias.
- :class:`BaseViolation`     — domain-agnostic violation record.
- :class:`BaseReport`        — full check report for one commit.
- :class:`InvariantChecker`  — Protocol every domain checker must satisfy.
- :func:`make_report`        — build a ``BaseReport`` from a violation list.
- :func:`load_rules_toml`    — parse any ``[[rule]]`` TOML file.
- :func:`format_report`      — human-readable report text.
"""
49
import logging
import pathlib
from typing import Literal, Protocol, TypedDict, runtime_checkable

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Shared severity literal
# ---------------------------------------------------------------------------

# Allowed values for a rule's ``severity`` and a violation's ``severity``.
# Only "error" and "warning" feed a report's ``has_errors`` / ``has_warnings``
# flags (see make_report); "info" is informational only.
InvariantSeverity = Literal["info", "warning", "error"]
61
62
# ---------------------------------------------------------------------------
# Domain-agnostic violation + report TypedDicts
# ---------------------------------------------------------------------------


class BaseViolation(TypedDict):
    """A single invariant violation, domain-agnostic.

    Domain implementations extend this with additional fields (e.g. ``track``
    for MIDI, ``file`` and ``symbol`` for code).

    ``rule_name``    The name of the rule that fired.
    ``severity``     Violation severity inherited from the rule declaration.
    ``address``      Locator of the violating element. Its syntax is
                     domain-defined (e.g. ``"src/utils.py::my_fn"`` or
                     ``"piano.mid/bar:4"``). It is the primary sort key in
                     :func:`make_report`, so equal strings group together.
    ``description``  Human-readable explanation of the violation.
    """

    rule_name: str
    severity: InvariantSeverity
    address: str
    description: str
85
86
class BaseReport(TypedDict):
    """Full invariant check report for one commit, domain-agnostic.

    :func:`make_report` builds one of these from a flat violation list and
    derives the sort order and both boolean flags for you.

    ``commit_id``      The commit that was checked.
    ``domain``         Domain tag (e.g. ``"midi"``, ``"code"``).
    ``violations``     All violations found. :func:`make_report` sorts them
                       by address, then by rule name.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     ``True`` when any violation has severity ``"error"``.
    ``has_warnings``   ``True`` when any violation has severity ``"warning"``.
                       ``"info"`` violations set neither flag.
    """

    commit_id: str
    domain: str
    violations: list[BaseViolation]
    rules_checked: int
    has_errors: bool
    has_warnings: bool
104
105
# ---------------------------------------------------------------------------
# InvariantChecker protocol
# ---------------------------------------------------------------------------


# @runtime_checkable enables ``isinstance(obj, InvariantChecker)``. That check
# only confirms ``obj`` has a ``check`` attribute; it does not validate the
# method's parameters or return type.
@runtime_checkable
class InvariantChecker(Protocol):
    """Protocol every domain invariant checker must satisfy.

    Domain plugins implement this by providing :meth:`check` — a function that
    loads and evaluates the domain's invariant rules against a commit, returning
    a :class:`BaseReport`. The CLI ``muse check`` command dispatches to the
    domain's registered checker via this protocol.

    Checkers conform structurally: they do not need to inherit from this
    class, only to provide a compatible ``check`` method.

    Example implementation (``default_path`` and ``_evaluate`` are
    placeholders for domain-specific pieces)::

        class MyDomainChecker:
            def check(
                self,
                repo_root: pathlib.Path,
                commit_id: str,
                *,
                rules_file: pathlib.Path | None = None,
            ) -> BaseReport:
                rules = load_rules_toml(rules_file or default_path)
                violations = _evaluate(repo_root, commit_id, rules)
                return make_report(commit_id, "mydomain", violations, len(rules))
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Evaluate invariant rules and return a structured report.

        Args:
            repo_root: Repository root (contains ``.muse/``).
            commit_id: Commit to check.
            rules_file: Path to a TOML rule file. ``None`` → use the
                domain's default location.

        Returns:
            A :class:`BaseReport` with all violations and summary flags.
        """
        ...
154
155
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_report(
    commit_id: str,
    domain: str,
    violations: list[BaseViolation],
    rules_checked: int,
) -> BaseReport:
    """Assemble a :class:`BaseReport` for *commit_id* from unordered violations.

    The report's ``violations`` entry is a new list sorted by ``address`` and
    then ``rule_name``, so output is deterministic regardless of the order the
    domain evaluator produced them in. The caller's list is left untouched.

    Args:
        commit_id: Commit that was checked.
        domain: Domain tag, stored verbatim.
        violations: Every violation found, in any order.
        rules_checked: How many rules were evaluated.

    Returns:
        A fully populated :class:`BaseReport`. ``has_errors`` / ``has_warnings``
        are ``True`` when at least one violation carries that severity.
    """

    def by_location(violation: BaseViolation) -> tuple[str, str]:
        # Where it fired first; which rule fired breaks ties.
        return violation["address"], violation["rule_name"]

    def any_with_severity(level: str) -> bool:
        return any(violation["severity"] == level for violation in violations)

    report: BaseReport = {
        "commit_id": commit_id,
        "domain": domain,
        "violations": sorted(violations, key=by_location),
        "rules_checked": rules_checked,
        "has_errors": any_with_severity("error"),
        "has_warnings": any_with_severity("warning"),
    }
    return report
189
190
191 def load_rules_toml(path: pathlib.Path) -> list[dict[str, str | int | float | dict[str, str | int | float]]]:
192 """Parse a ``[[rule]]`` TOML file and return the raw rule dicts.
193
194 Returns an empty list when the file does not exist (domain then uses
195 built-in defaults).
196
197 Args:
198 path: Path to the TOML file.
199
200 Returns:
201 List of raw rule dicts (``{"name": ..., "severity": ..., ...}``).
202 """
203 if not path.exists():
204 logger.debug("Invariants rules file not found at %s — using defaults", path)
205 return []
206 import tomllib # stdlib on Python ≥ 3.11; Muse requires 3.12
207 try:
208 data = tomllib.loads(path.read_text())
209 rules: list[dict[str, str | int | float | dict[str, str | int | float]]] = data.get("rule", [])
210 return rules
211 except Exception as exc:
212 logger.warning("Failed to parse invariants file %s: %s", path, exc)
213 return []
214
215
216 def format_report(report: BaseReport, *, color: bool = True) -> str:
217 """Return a human-readable multi-line report string.
218
219 Args:
220 report: The report to format.
221 color: If ``True``, prefix error/warning/info lines with emoji.
222
223 Returns:
224 Formatted string ready for ``typer.echo()``.
225 """
226 lines: list[str] = []
227 prefix = {
228 "error": "❌" if color else "[error]",
229 "warning": "⚠️ " if color else "[warn] ",
230 "info": "ℹ️ " if color else "[info] ",
231 }
232 for v in report["violations"]:
233 p = prefix.get(v["severity"], " ")
234 lines.append(f" {p} [{v['rule_name']}] {v['address']}: {v['description']}")
235
236 checked = report["rules_checked"]
237 total = len(report["violations"])
238 errors = sum(1 for v in report["violations"] if v["severity"] == "error")
239 warnings = sum(1 for v in report["violations"] if v["severity"] == "warning")
240
241 summary = f"\n{checked} rules checked — {total} violation(s)"
242 if errors:
243 summary += f", {errors} error(s)"
244 if warnings:
245 summary += f", {warnings} warning(s)"
246 if not total:
247 summary = f"\n✅ {checked} rules checked — no violations"
248
249 return "\n".join(lines) + summary