cgcardona / muse public
midi_diff.py python
338 lines 11.0 KB
d7054e63 feat(phase-1): typed delta algebra — replace DeltaManifest with Structu… Gabriel Cardona <gabriel@tellurstori.com> 2d ago
"""MIDI note-level diff for the Muse music plugin.

Implements a longest-common-subsequence (LCS) shortest-edit-script
algorithm on MIDI note sequences, producing a ``StructuredDelta`` whose ops
are note-level ``InsertOp`` and ``DeleteOp`` entries. The delta can be
embedded as the child ops of a ``PatchOp`` or stored on its own.

This is what lets ``muse show`` display "C4 added at beat 3.5" rather than
"tracks/drums.mid modified".

Algorithm
---------
1. Parse MIDI bytes and extract paired note events (note_on + note_off)
   sorted by start tick.
2. Represent each note as a ``NoteKey`` TypedDict with five fields.
3. Run the O(n·m) LCS dynamic-programming algorithm on the two note sequences.
4. Trace back through the table to produce a shortest edit script of
   keep / insert / delete steps.
5. Map insert and delete steps to typed ``DomainOp`` instances (keep steps
   produce no op).
6. Wrap the ops in a ``StructuredDelta`` with a human-readable summary.

A changed note (different pitch, velocity, timing, or channel) appears as a
delete of the old note plus an insert of the new one; no ``ReplaceOp`` is
emitted.

Public API
----------
- :class:`NoteKey` — the five-field note record compared by the LCS (a plain
  ``dict`` at runtime, so it is not hashable).
- :func:`extract_notes` — MIDI bytes → ``(sorted list[NoteKey], ticks_per_beat)``.
- :func:`lcs_edit_script` — LCS shortest edit script on two note lists.
- :func:`diff_midi_notes` — top-level: MIDI bytes × 2 → ``StructuredDelta``.
"""
from __future__ import annotations

import hashlib
import io
import logging
from dataclasses import dataclass, field
from typing import Literal, TypedDict

import mido

from muse.domain import (
    DeleteOp,
    DomainOp,
    InsertOp,
    StructuredDelta,
)
43
logger = logging.getLogger(__name__)

#: Domain tag stamped on every note-level ``StructuredDelta`` built here.
_CHILD_DOMAIN = "midi_notes"

#: Pitch-class spellings indexed by ``midi_pitch % 12`` (sharps only).
_PITCH_NAMES = "C C# D D# E F F# G G# A A# B".split()
50
51
52 # ---------------------------------------------------------------------------
53 # NoteKey — the unit of LCS comparison
54 # ---------------------------------------------------------------------------
55
56
class NoteKey(TypedDict):
    """A single MIDI note, as compared by the LCS diff.

    Notes are compared with plain dict equality. Two notes therefore match
    only when all five fields below are identical. A change to any one of
    them (pitch, velocity, start time, length, or channel) shows up in a
    diff as the old note being removed and a new note being added.

    This is deliberately strict. Every match the LCS keeps is a genuine
    structural match, and every real musical edit is reported.

    Being a ``TypedDict``, a ``NoteKey`` is an ordinary ``dict`` at runtime
    and is not hashable.
    """

    pitch: int  # MIDI note number
    velocity: int  # velocity of the note_on that started the note
    start_tick: int  # absolute start time, in ticks
    duration_ticks: int  # note length, in ticks
    channel: int  # MIDI channel the note was played on
72
73
74 # ---------------------------------------------------------------------------
75 # Edit step — output of the LCS traceback
76 # ---------------------------------------------------------------------------
77
#: The three kinds of step an LCS edit script can contain.
EditKind = Literal["keep", "insert", "delete"]


@dataclass(frozen=True)
class EditStep:
    """One step in the shortest edit script.

    ``base_index`` and ``target_index`` record where the base and target
    cursors stood when the step was emitted.

    - For ``"keep"`` and ``"delete"`` steps, ``note`` is the note at
      ``base_index`` in the base sequence.
    - For ``"keep"`` and ``"insert"`` steps, ``note`` is the note at
      ``target_index`` in the target sequence.

    ``note`` is a ``NoteKey``, which is an unhashable plain dict. It is
    therefore excluded from the generated ``__hash__`` but still takes part
    in ``==``. Without this, the frozen dataclass's auto-generated
    ``__hash__`` would raise ``TypeError`` whenever a step was hashed (for
    example when put in a set or used as a dict key).
    """

    kind: EditKind
    base_index: int  # index in the base note sequence
    target_index: int  # index in the target note sequence
    note: NoteKey = field(hash=False)
89
90
91 # ---------------------------------------------------------------------------
92 # Helpers
93 # ---------------------------------------------------------------------------
94
95
def _pitch_name(midi_pitch: int) -> str:
    """Spell a MIDI note number as pitch class plus octave.

    Middle C (MIDI 60) comes out as ``"C4"``, so MIDI 0 is ``"C-1"``.
    Black keys are spelled with sharps, e.g. ``"F#5"``.
    """
    octave_index, pitch_class = divmod(midi_pitch, 12)
    return _PITCH_NAMES[pitch_class] + str(octave_index - 1)
101
102
def _note_content_id(note: NoteKey) -> str:
    """Hash the five identity fields of *note* into a stable SHA-256 hex id.

    The fields are joined with ``":"`` in a fixed order (pitch, velocity,
    start tick, duration, channel) before hashing. Equal notes therefore
    always get the same ``content_id`` for ``InsertOp`` / ``DeleteOp``,
    without the note having to be stored as its own blob.
    """
    identity = (
        note["pitch"],
        note["velocity"],
        note["start_tick"],
        note["duration_ticks"],
        note["channel"],
    )
    digest = hashlib.sha256(":".join(map(str, identity)).encode())
    return digest.hexdigest()
115
116
def _note_summary(note: NoteKey, ticks_per_beat: int) -> str:
    """Return a one-line human-readable description of *note*.

    Example output: ``"C4 vel=80 @beat=1.00 dur=0.50"``. Start time and
    duration are expressed in beats (ticks divided by *ticks_per_beat*) and
    formatted to two decimal places.

    Args:
        note: The note to describe.
        ticks_per_beat: MIDI resolution used to convert ticks to beats.
            Values below 1 are clamped to 1, so a bad resolution cannot
            cause a division by zero.

    Returns:
        The summary string.
    """
    tpb = max(ticks_per_beat, 1)  # guard against zero/negative resolution
    beat = note["start_tick"] / tpb
    dur = note["duration_ticks"] / tpb
    return (
        f"{_pitch_name(note['pitch'])} "
        f"vel={note['velocity']} "
        f"@beat={beat:.2f} "
        f"dur={dur:.2f}"
    )
127
128
129 # ---------------------------------------------------------------------------
130 # Note extraction
131 # ---------------------------------------------------------------------------
132
133
def extract_notes(midi_bytes: bytes) -> tuple[list[NoteKey], int]:
    """Parse *midi_bytes* and return ``(notes, ticks_per_beat)``.

    Notes are built by pairing each note_on with a later note_off on the
    same (channel, pitch). A note_on with velocity 0 is treated as a
    note_off.

    If the same (channel, pitch) is struck again before it is released,
    the open notes are closed first-in, first-out. Every re-triggered note
    is kept, rather than the earlier one being silently overwritten. Notes
    still open when the file ends get a duration of 1 tick. Stray
    note_offs with no matching note_on are ignored.

    The result is sorted on all five ``NoteKey`` fields (start tick, pitch,
    channel, duration, velocity). The same set of notes therefore always
    produces the same sequence, which keeps the LCS diff free of spurious
    reorderings.

    Args:
        midi_bytes: Raw bytes of a ``.mid`` file.

    Returns:
        A tuple of (sorted NoteKey list, ticks_per_beat integer).

    Raises:
        ValueError: When *midi_bytes* cannot be parsed as a MIDI file.
    """
    try:
        mid = mido.MidiFile(file=io.BytesIO(midi_bytes))
    except Exception as exc:  # deliberately broad: any parse failure -> ValueError
        raise ValueError(f"Cannot parse MIDI bytes: {exc}") from exc

    ticks_per_beat: int = int(mid.ticks_per_beat)
    # (channel, pitch) -> open (start_tick, velocity) pairs, oldest first.
    active: dict[tuple[int, int], list[tuple[int, int]]] = {}
    notes: list[NoteKey] = []

    for track in mid.tracks:
        abs_tick = 0  # msg.time is a delta; accumulate to get absolute ticks
        for msg in track:
            abs_tick += msg.time
            if msg.type == "note_on" and msg.velocity > 0:
                active.setdefault((msg.channel, msg.note), []).append(
                    (abs_tick, msg.velocity)
                )
            elif msg.type == "note_off" or (
                msg.type == "note_on" and msg.velocity == 0
            ):
                pending = active.get((msg.channel, msg.note))
                if pending:
                    start, vel = pending.pop(0)  # FIFO: close the oldest
                    notes.append(
                        NoteKey(
                            pitch=msg.note,
                            velocity=vel,
                            start_tick=start,
                            duration_ticks=max(abs_tick - start, 1),
                            channel=msg.channel,
                        )
                    )

    # Close any notes still open at end of file with duration 1.
    for (ch, pitch), pending in active.items():
        for start, vel in pending:
            notes.append(
                NoteKey(
                    pitch=pitch,
                    velocity=vel,
                    start_tick=start,
                    duration_ticks=1,
                    channel=ch,
                )
            )

    notes.sort(
        key=lambda n: (
            n["start_tick"],
            n["pitch"],
            n["channel"],
            n["duration_ticks"],
            n["velocity"],
        )
    )
    return notes, ticks_per_beat
196
197
198 # ---------------------------------------------------------------------------
# LCS shortest edit script (dynamic programming)
200 # ---------------------------------------------------------------------------
201
202
def lcs_edit_script(
    base: list[NoteKey],
    target: list[NoteKey],
) -> list[EditStep]:
    """Compute the shortest edit script transforming *base* into *target*.

    Builds the classic O(n·m) longest-common-subsequence table bottom-up,
    then walks it from the front to emit steps in order. Two notes match
    only when they compare equal, i.e. all five ``NoteKey`` fields agree.

    Sometimes skipping a target note (insert) and skipping a base note
    (delete) would leave equally long common subsequences. In that case the
    insert is emitted first.

    Args:
        base: The base (ancestor) note sequence.
        target: The target (newer) note sequence.

    Returns:
        ``EditStep`` objects of kind ``"keep"``, ``"insert"`` or
        ``"delete"`` that turn *base* into *target* when applied in order.
        The script is minimal: the number of ``"keep"`` steps equals the
        LCS length.
    """
    n_base = len(base)
    n_target = len(target)

    # table[i][j] is the LCS length of base[i:] against target[j:]; the extra
    # row and column of zeros stand for the empty suffixes.
    table: list[list[int]] = [[0] * (n_target + 1) for _ in range(n_base + 1)]
    for i in reversed(range(n_base)):
        row = table[i]
        below = table[i + 1]
        left_note = base[i]
        for j in reversed(range(n_target)):
            if left_note == target[j]:
                row[j] = below[j + 1] + 1
            else:
                row[j] = max(below[j], row[j + 1])

    # Walk the table from (0, 0), always taking a move that preserves the
    # optimal LCS length.
    script: list[EditStep] = []
    bi = 0
    ti = 0
    while bi < n_base or ti < n_target:
        base_left = bi < n_base
        target_left = ti < n_target
        if base_left and target_left and base[bi] == target[ti]:
            script.append(EditStep("keep", bi, ti, base[bi]))
            bi += 1
            ti += 1
            continue
        if target_left and (
            not base_left or table[bi][ti + 1] >= table[bi + 1][ti]
        ):
            script.append(EditStep("insert", bi, ti, target[ti]))
            ti += 1
        else:
            script.append(EditStep("delete", bi, ti, base[bi]))
            bi += 1

    return script
249
250
251 # ---------------------------------------------------------------------------
252 # Public diff entry point
253 # ---------------------------------------------------------------------------
254
255
def diff_midi_notes(
    base_bytes: bytes,
    target_bytes: bytes,
    *,
    file_path: str = "",
) -> StructuredDelta:
    """Compute a note-level ``StructuredDelta`` between two MIDI files.

    Parses both files, runs LCS on their note sequences, and returns a
    ``StructuredDelta`` suitable for embedding in a ``PatchOp.child_ops``
    list or storing directly as a commit's ``structured_delta``.

    Notes are compared in raw ticks. Two files that encode the same music
    at different ``ticks_per_beat`` resolutions will therefore report
    changed notes.

    The human-readable summaries convert ticks to beats using the
    resolution of the file each note came from: the base file for deleted
    notes and the target file for inserted notes.

    Args:
        base_bytes: Raw bytes of the base (ancestor) MIDI file.
        target_bytes: Raw bytes of the target (newer) MIDI file.
        file_path: Workspace-relative path of the file being diffed.
            Used only in debug log messages.

    Returns:
        A ``StructuredDelta`` with domain ``"midi_notes"``. It holds one
        ``InsertOp`` per note added and one ``DeleteOp`` per note removed,
        in edit-script order. Insert ops are addressed by their index in
        the target sequence, and delete ops by their index in the base
        sequence. The ``summary`` field is human-readable, e.g.
        ``"3 notes added, 1 note removed"``, or ``"no note changes"``.

    Raises:
        ValueError: When either byte string cannot be parsed as MIDI.
    """
    base_notes, base_tpb = extract_notes(base_bytes)
    target_notes, target_tpb = extract_notes(target_bytes)
    if base_tpb != target_tpb:
        logger.debug(
            "MIDI diff %r: ticks_per_beat differs (base=%d, target=%d); "
            "note timings are compared in raw ticks",
            file_path,
            base_tpb,
            target_tpb,
        )

    steps = lcs_edit_script(base_notes, target_notes)

    child_ops: list[DomainOp] = []
    inserts = 0
    deletes = 0

    for step in steps:
        if step.kind == "insert":
            # Inserted notes come from the target file, so their beat
            # positions must use the target file's resolution.
            child_ops.append(
                InsertOp(
                    op="insert",
                    address=f"note:{step.target_index}",
                    position=step.target_index,
                    content_id=_note_content_id(step.note),
                    content_summary=_note_summary(step.note, target_tpb),
                )
            )
            inserts += 1
        elif step.kind == "delete":
            # Deleted notes come from the base file.
            child_ops.append(
                DeleteOp(
                    op="delete",
                    address=f"note:{step.base_index}",
                    position=step.base_index,
                    content_id=_note_content_id(step.note),
                    content_summary=_note_summary(step.note, base_tpb),
                )
            )
            deletes += 1
        # "keep" steps produce no ops — the note is unchanged.

    parts: list[str] = []
    if inserts:
        parts.append(f"{inserts} note{'s' if inserts != 1 else ''} added")
    if deletes:
        parts.append(f"{deletes} note{'s' if deletes != 1 else ''} removed")
    child_summary = ", ".join(parts) if parts else "no note changes"

    logger.debug(
        "✅ MIDI diff %r: +%d -%d notes (%d LCS steps)",
        file_path,
        inserts,
        deletes,
        len(steps),
    )

    return StructuredDelta(
        domain=_CHILD_DOMAIN,
        ops=child_ops,
        summary=child_summary,
    )
337
338