cgcardona / muse public
_invariants.py python
590 lines 19.6 KB
766ee24d feat: code domain leverages core invariants, query engine, manifests, p… Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """Musical invariants engine for the Muse music plugin.
2
3 Invariants are semantic rules that a MIDI track must satisfy. They are
4 evaluated at commit time, merge time, or on-demand via ``muse music-check``.
5 Violations are reported with human-readable descriptions, severity levels,
6 and structured addresses for programmatic consumers.
7
8 Rule file format (TOML)
9 -----------------------
10 Rules are declared in ``.muse/midi_invariants.toml`` (default path).
11 Example::
12
13 [[rule]]
14 name = "max_polyphony"
15 severity = "error"
16 scope = "track"
17 rule_type = "max_polyphony"
18
19 [rule.params]
20 max_simultaneous = 6
21
22 [[rule]]
23 name = "keep_in_range"
24 severity = "warning"
25 scope = "track"
26 rule_type = "pitch_range"
27
28 [rule.params]
29 min_pitch = 24
30 max_pitch = 108
31
32 [[rule]]
33 name = "no_fifths"
34 severity = "warning"
35 scope = "voice_pair"
36 rule_type = "no_parallel_fifths"
37
38 [[rule]]
39 name = "consistent_key"
40 severity = "info"
41 scope = "track"
42 rule_type = "key_consistency"
43
44 [rule.params]
45 threshold = 0.15
46
47 Built-in rule types
48 -------------------
49
50 ``max_polyphony``
51 Detects bars where more than *max_simultaneous* notes overlap at any
52 tick position. Uses a sweep-line algorithm over start/end tick events.
53
54 ``pitch_range``
55 Detects any note with ``pitch < min_pitch`` or ``pitch > max_pitch``.
56
57 ``key_consistency``
58 Detects notes whose pitch class is highly inconsistent with the key
59 estimated by the Krumhansl-Schmuckler algorithm. Fires when the ratio
60 of "foreign" pitch classes exceeds *threshold*.
61
62 ``no_parallel_fifths``
63 Detects consecutive bars in which the two lowest notes are a perfect fifth
64 apart in both bars and both move in the same direction (a classical
65 counterpoint violation). Best-effort heuristic — voice assignment is implicit.
66
67 Severity levels
68 ---------------
69 - ``"error"`` — must be resolved before committing (when ``--strict`` is set).
70 - ``"warning"`` — reported but does not block commits.
71 - ``"info"`` — informational; surfaced in ``muse music-check`` output only.
72
73 Public API
74 ----------
75 - :class:`InvariantRule` — rule declaration TypedDict.
76 - :class:`InvariantViolation` — single violation record TypedDict.
77 - :class:`InvariantReport` — full report for one commit / track.
78 - :func:`load_invariant_rules` — load from TOML file with defaults fallback.
79 - :func:`run_invariants` — evaluate all rules against a commit.
80 """
81 from __future__ import annotations
82
83 import logging
84 import pathlib
85 from typing import Literal, TypedDict
86
87 from muse.core.invariants import BaseReport, BaseViolation, make_report
88 from muse.core.object_store import read_object
89 from muse.core.store import get_commit_snapshot_manifest
90 from muse.plugins.midi._query import NoteInfo, key_signature_guess, notes_by_bar
91 from muse.plugins.midi.midi_diff import extract_notes
92
93 logger = logging.getLogger(__name__)
94
95 _DEFAULT_RULES_FILE = ".muse/midi_invariants.toml"
96
97
98 # ---------------------------------------------------------------------------
99 # Types
100 # ---------------------------------------------------------------------------
101
102
103 class _InvariantRuleRequired(TypedDict):
104 name: str
105 severity: Literal["info", "warning", "error"]
106 scope: Literal["track", "bar", "voice_pair", "global"]
107 rule_type: str
108
109
class InvariantRule(_InvariantRuleRequired, total=False):
    """One MIDI invariant rule as it appears in a rule set.

    Required keys (inherited from the base TypedDict):

    ``name``       Human-readable rule identifier (unique within a rule set).
    ``severity``   One of ``"info"``, ``"warning"`` or ``"error"``.
    ``scope``      One of ``"track"``, ``"bar"``, ``"voice_pair"``, ``"global"``.
    ``rule_type``  Built-in type string: ``"max_polyphony"``, ``"pitch_range"``,
                   ``"key_consistency"`` or ``"no_parallel_fifths"``.

    Optional key:

    ``params``     Rule-specific parameter mapping.
    """

    # Optional because this class is declared with ``total=False``.
    params: dict[str, str | int | float]
122
123
class InvariantViolation(TypedDict):
    """Record describing one place where a rule was broken.

    ``rule_name``    Name of the rule that produced the record.
    ``severity``     Severity copied from the rule declaration.
    ``track``        Workspace-relative MIDI file path.
    ``bar``          1-indexed bar number; ``0`` marks a track-level violation.
    ``description``  Human-readable explanation of the violation.
    ``addresses``    Note addresses or other domain addresses involved.
    """

    rule_name: str
    severity: Literal["info", "warning", "error"]
    track: str
    bar: int
    description: str
    addresses: list[str]
141
142
class InvariantReport(TypedDict):
    """Result of checking one commit against a rule set.

    ``commit_id``      The commit that was checked.
    ``violations``     Every violation found, sorted by track and then bar.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     ``True`` if at least one violation has severity ``"error"``.
    ``has_warnings``   ``True`` if at least one violation has severity ``"warning"``.
    """

    commit_id: str
    violations: list[InvariantViolation]
    rules_checked: int
    has_errors: bool
    has_warnings: bool
158
159
160 # ---------------------------------------------------------------------------
161 # Built-in rule implementations
162 # ---------------------------------------------------------------------------
163
164
def check_max_polyphony(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    max_simultaneous: int = 6,
) -> list[InvariantViolation]:
    """Report bars whose peak polyphony is above *max_simultaneous*.

    Notes are grouped with ``notes_by_bar`` and each group is swept
    independently: every note adds a ``+1`` event at its start tick and a
    ``-1`` event at ``start_tick + duration_ticks``. Only notes that land in
    the same group are compared with each other.

    Args:
        notes: All notes in the track.
        track: Track file path recorded on each violation.
        rule_name: Rule identifier recorded on each violation.
        severity: Severity recorded on each violation.
        max_simultaneous: Largest number of overlapping notes that is allowed.

    Returns:
        At most one :class:`InvariantViolation` per bar, in ascending bar order.
    """
    found: list[InvariantViolation] = []
    grouped = notes_by_bar(notes)

    for bar_num in sorted(grouped):
        # Natural tuple ordering puts a -1 (note off) ahead of a +1 (note on)
        # at the same tick, so back-to-back notes are not counted as overlapping.
        edges = sorted(
            edge
            for note in grouped[bar_num]
            for edge in (
                (note.start_tick, 1),
                (note.start_tick + note.duration_ticks, -1),
            )
        )

        sounding = 0
        loudest = 0
        loudest_tick = 0
        for tick, step in edges:
            sounding += step
            if sounding > loudest:
                loudest, loudest_tick = sounding, tick

        if loudest > max_simultaneous:
            found.append(
                InvariantViolation(
                    rule_name=rule_name,
                    severity=severity,
                    track=track,
                    bar=bar_num,
                    description=(
                        f"Polyphony reached {loudest} simultaneous notes at tick {loudest_tick} "
                        f"(max allowed: {max_simultaneous})"
                    ),
                    addresses=[f"bar:{bar_num}:tick:{loudest_tick}"],
                )
            )

    return found
224
225
def check_pitch_range(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    min_pitch: int = 0,
    max_pitch: int = 127,
) -> list[InvariantViolation]:
    """Report every note whose pitch lies outside ``[min_pitch, max_pitch]``.

    Args:
        notes: All notes in the track.
        track: Track file path recorded on each violation.
        rule_name: Rule identifier recorded on each violation.
        severity: Severity recorded on each violation.
        min_pitch: Lowest allowed MIDI pitch (inclusive).
        max_pitch: Highest allowed MIDI pitch (inclusive).

    Returns:
        One :class:`InvariantViolation` per out-of-range note, in input order.
    """
    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=offender.bar,
            description=(
                f"Note {offender.pitch_name} (MIDI {offender.pitch}) is outside "
                f"allowed range [{min_pitch}, {max_pitch}]"
            ),
            addresses=[f"bar:{offender.bar}:pitch:{offender.pitch}"],
        )
        for offender in notes
        if offender.pitch < min_pitch or offender.pitch > max_pitch
    ]
265
266
def check_key_consistency(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    threshold: float = 0.15,
) -> list[InvariantViolation]:
    """Detect tracks with too many notes outside the guessed key.

    The key is obtained from ``key_signature_guess``; the result is expected
    to look like ``"G major"`` or ``"D minor"``. The fraction of notes whose
    pitch class is not in that key's scale (major, or natural minor for any
    other mode) is compared with *threshold*.

    The root may be spelled with sharps or flats (``"A#"`` and ``"Bb"`` are
    equivalent) and the mode is matched case-insensitively. If the guess
    cannot be parsed, or the root name is not recognised, no violation is
    reported.

    Args:
        notes: All notes in the track.
        track: Track file path recorded on the violation.
        rule_name: Rule identifier recorded on the violation.
        severity: Severity recorded on the violation.
        threshold: Maximum allowed ratio of foreign pitch classes (0.0–1.0).

    Returns:
        Zero or one :class:`InvariantViolation` for the track (``bar`` is 0).
    """
    if not notes:
        return []

    key_guess = key_signature_guess(notes)
    # Expected shape: "<root> <mode>", e.g. "G major" or "D minor".
    parts = key_guess.split()
    if len(parts) < 2:
        return []
    root_name, mode = parts[0], parts[1]

    # Sharp and flat spellings of the same pitch class map to the same index,
    # so a flat-spelled guess does not silently disable the check.
    pitch_class_by_name = {
        "C": 0, "B#": 0,
        "C#": 1, "Db": 1,
        "D": 2,
        "D#": 3, "Eb": 3,
        "E": 4, "Fb": 4,
        "F": 5, "E#": 5,
        "F#": 6, "Gb": 6,
        "G": 7,
        "G#": 8, "Ab": 8,
        "A": 9,
        "A#": 10, "Bb": 10,
        "B": 11, "Cb": 11,
    }
    root_idx = pitch_class_by_name.get(root_name)
    if root_idx is None:
        return []

    # Semitone offsets of the major and natural minor scales from the root.
    major_steps = (0, 2, 4, 5, 7, 9, 11)
    minor_steps = (0, 2, 3, 5, 7, 8, 10)
    steps = major_steps if mode.lower() == "major" else minor_steps
    diatonic_pcs = frozenset((root_idx + s) % 12 for s in steps)

    foreign = sum(1 for n in notes if n.pitch_class not in diatonic_pcs)
    ratio = foreign / len(notes)

    if ratio > threshold:
        return [
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=0,
                description=(
                    f"{foreign}/{len(notes)} notes ({ratio:.0%}) use pitch classes "
                    f"foreign to estimated key {key_guess} "
                    f"(threshold: {threshold:.0%})"
                ),
                addresses=[track],
            )
        ]
    return []
333
334
def check_no_parallel_fifths(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
) -> list[InvariantViolation]:
    """Flag neighbouring bars that look like parallel perfect fifths.

    Heuristic: notes are grouped with ``notes_by_bar`` and each bar that has
    notes is compared with the next bar that has notes. In each bar the two
    lowest-pitched notes stand in for the bass and tenor voices. A violation
    is reported when those two notes are a perfect fifth apart (7 semitones
    modulo the octave) in both bars and both voices move in the same
    direction (both up or both down).

    This is a best-effort approximation; no real voice separation is done.

    Args:
        notes: All notes in the track.
        track: Track file path recorded on each violation.
        rule_name: Rule identifier recorded on each violation.
        severity: Severity recorded on each violation.

    Returns:
        One :class:`InvariantViolation` per offending bar pair, attributed to
        the later bar of the pair.
    """
    found: list[InvariantViolation] = []
    by_bar = notes_by_bar(notes)
    ordered = sorted(by_bar.keys())

    for earlier, later in zip(ordered, ordered[1:]):
        earlier_notes = sorted(by_bar[earlier], key=lambda n: n.pitch)
        later_notes = sorted(by_bar[later], key=lambda n: n.pitch)

        # Two voices are needed in both bars.
        if len(earlier_notes) < 2 or len(later_notes) < 2:
            continue

        bass_from, tenor_from = earlier_notes[0], earlier_notes[1]
        bass_to, tenor_to = later_notes[0], later_notes[1]

        # Distance between the two voices inside each bar, folded to one octave.
        gap_from = abs(tenor_from.pitch - bass_from.pitch) % 12
        gap_to = abs(tenor_to.pitch - bass_to.pitch) % 12
        if gap_from != 7 or gap_to != 7:
            continue

        bass_motion = bass_to.pitch - bass_from.pitch
        tenor_motion = tenor_to.pitch - tenor_from.pitch
        both_up = bass_motion > 0 and tenor_motion > 0
        both_down = bass_motion < 0 and tenor_motion < 0
        if not (both_up or both_down):
            continue

        found.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=later,
                description=(
                    f"Parallel fifths between bars {earlier} and {later}: "
                    f"lower voice {bass_from.pitch_name}→{bass_to.pitch_name}, "
                    f"upper voice {tenor_from.pitch_name}→{tenor_to.pitch_name}"
                ),
                addresses=[f"bar:{earlier}", f"bar:{later}"],
            )
        )

    return found
402
403
404 # ---------------------------------------------------------------------------
405 # Rule loading
406 # ---------------------------------------------------------------------------
407
# Rules applied when no rule file is given or the file cannot be used.
_DEFAULT_RULE_SET: list[InvariantRule] = [
    {
        "name": "max_polyphony",
        "severity": "warning",
        "scope": "track",
        "rule_type": "max_polyphony",
        "params": {"max_simultaneous": 8},
    },
    {
        "name": "pitch_range",
        "severity": "warning",
        "scope": "track",
        "rule_type": "pitch_range",
        "params": {"min_pitch": 0, "max_pitch": 127},
    },
]
424
425
def load_invariant_rules(rules_file: pathlib.Path | None = None) -> list[InvariantRule]:
    """Load invariant rules from a TOML file, falling back to defaults.

    TOML parsing uses ``tomllib`` (Python 3.11+), imported lazily. Any
    failure while importing, reading or interpreting the file is logged as a
    warning and the default rule set is returned instead.

    For each ``[[rule]]`` entry an unrecognised ``severity`` becomes
    ``"warning"``, an unrecognised ``scope`` becomes ``"track"``, a missing
    ``name`` becomes ``"unnamed"`` and a missing ``rule_type`` becomes ``""``.
    ``params`` is kept only when it is a table; any other value is ignored
    with a warning so that later ``params.get(...)`` lookups cannot fail.

    Args:
        rules_file: Path to the TOML rule file. ``None`` means use defaults.

    Returns:
        List of :class:`InvariantRule` dicts. A copy of the default rule set
        is returned when the file is missing, unusable, or declares no rules.
    """
    if rules_file is None or not rules_file.exists():
        return list(_DEFAULT_RULE_SET)

    # Identity maps: they validate the raw string and narrow it to the
    # Literal type in one step. Built once rather than once per rule.
    valid_severities: dict[str, Literal["info", "warning", "error"]] = {
        "info": "info", "warning": "warning", "error": "error",
    }
    valid_scopes: dict[str, Literal["track", "bar", "voice_pair", "global"]] = {
        "track": "track", "bar": "bar", "voice_pair": "voice_pair", "global": "global",
    }

    try:
        import tomllib

        with rules_file.open("rb") as fh:
            data = tomllib.load(fh)

        rules: list[InvariantRule] = []
        for raw in data.get("rule", []):
            rule = InvariantRule(
                name=str(raw.get("name", "unnamed")),
                severity=valid_severities.get(str(raw.get("severity", "")), "warning"),
                scope=valid_scopes.get(str(raw.get("scope", "")), "track"),
                rule_type=str(raw.get("rule_type", "")),
            )
            params = raw.get("params")
            if isinstance(params, dict):
                rule["params"] = params
            elif params is not None:
                logger.warning(
                    "Ignoring non-table params for invariant rule %r in %s",
                    rule["name"],
                    rules_file,
                )
            rules.append(rule)
        return rules if rules else list(_DEFAULT_RULE_SET)

    except Exception as exc:
        # Deliberate best effort: a broken rule file must not break the caller.
        logger.warning("⚠️ Could not load invariant rules from %s: %s", rules_file, exc)
        return list(_DEFAULT_RULE_SET)
471
472
473 # ---------------------------------------------------------------------------
474 # Main runner
475 # ---------------------------------------------------------------------------
476
477
def run_invariants(
    root: "pathlib.Path",
    commit_id: str,
    rules: list[InvariantRule],
    *,
    track_filter: str | None = None,
) -> InvariantReport:
    """Evaluate all *rules* against every MIDI track in *commit_id*.

    Tracks are the manifest paths ending in ``.mid`` (case-insensitive),
    processed in sorted order. A track is skipped when its object cannot be
    read or ``extract_notes`` raises ``ValueError``. Rules with an unknown
    ``rule_type`` are skipped with a debug log entry.

    Args:
        root: Repository root.
        commit_id: Commit to check.
        rules: List of :class:`InvariantRule` declarations.
        track_filter: Restrict check to a single MIDI file path (exact match).

    Returns:
        An :class:`InvariantReport` with all violations sorted by track and
        bar. ``rules_checked`` counts the rule evaluations that actually ran,
        so skipped tracks and unknown rule types are not included.
    """
    all_violations: list[InvariantViolation] = []
    rules_checked = 0
    manifest = get_commit_snapshot_manifest(root, commit_id) or {}

    midi_paths = [
        p for p in manifest
        if p.lower().endswith(".mid")
        and (track_filter is None or p == track_filter)
    ]

    for track_path in sorted(midi_paths):
        obj_hash = manifest.get(track_path)
        if obj_hash is None:
            continue
        raw = read_object(root, obj_hash)
        if raw is None:
            continue
        try:
            keys, tpb = extract_notes(raw)
        except ValueError as exc:
            logger.debug("Cannot parse MIDI %r: %s", track_path, exc)
            continue

        notes = [NoteInfo.from_note_key(k, tpb) for k in keys]

        for rule in rules:
            rt = rule["rule_type"]
            sev = rule["severity"]
            params = rule.get("params", {})
            name = rule["name"]

            if rt == "max_polyphony":
                max_sim = int(params.get("max_simultaneous", 8))
                found = check_max_polyphony(
                    notes, track_path, name, sev, max_simultaneous=max_sim
                )
            elif rt == "pitch_range":
                min_p = int(params.get("min_pitch", 0))
                max_p = int(params.get("max_pitch", 127))
                found = check_pitch_range(
                    notes, track_path, name, sev, min_pitch=min_p, max_pitch=max_p
                )
            elif rt == "key_consistency":
                thresh = float(params.get("threshold", 0.15))
                found = check_key_consistency(
                    notes, track_path, name, sev, threshold=thresh
                )
            elif rt == "no_parallel_fifths":
                found = check_no_parallel_fifths(notes, track_path, name, sev)
            else:
                logger.debug("Unknown rule_type %r in rule %r — skipped", rt, name)
                continue

            all_violations.extend(found)
            # Count only rules that were really run against a parsed track.
            rules_checked += 1

    all_violations.sort(key=lambda v: (v["track"], v["bar"]))
    has_errors = any(v["severity"] == "error" for v in all_violations)
    has_warnings = any(v["severity"] == "warning" for v in all_violations)

    return InvariantReport(
        commit_id=commit_id,
        violations=all_violations,
        rules_checked=rules_checked,
        has_errors=has_errors,
        has_warnings=has_warnings,
    )
562
563
class MidiChecker:
    """MIDI-domain implementation of :class:`~muse.core.invariants.InvariantChecker`.

    A thin adapter over :func:`run_invariants`, so the generic ``muse check``
    command can run MIDI checks without knowing MIDI internals.
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Check *commit_id* and return a :class:`~muse.core.invariants.BaseReport`.

        Rules come from :func:`load_invariant_rules`. Each MIDI violation is
        converted to a ``BaseViolation`` whose ``address`` is the track path;
        the ``bar`` and ``addresses`` fields are not carried over.

        Args:
            repo_root: Repository root.
            commit_id: Commit to check.
            rules_file: Optional TOML rule file passed to the rule loader.

        Returns:
            The report built by ``make_report`` for the ``"midi"`` domain.
        """
        loaded_rules = load_invariant_rules(rules_file)
        report = run_invariants(repo_root, commit_id, loaded_rules)

        converted: list[BaseViolation] = []
        for violation in report["violations"]:
            converted.append(
                BaseViolation(
                    rule_name=violation["rule_name"],
                    severity=violation["severity"],
                    address=violation["track"],
                    description=violation["description"],
                )
            )

        return make_report(commit_id, "midi", converted, report["rules_checked"])