cgcardona / muse public
midi_merge.py python
427 lines 15.3 KB
5a7035e3 feat: Python 3.12 baseline, dep refresh, and docs navigation index Gabriel Cardona <gabriel@tellurstori.com> 2d ago
"""MIDI dimension-aware merge for the Muse music plugin.

This module implements the multidimensional merge that makes Muse meaningfully
different from git. Git treats every file as an opaque byte sequence: any
change made to the same file on both branches is a conflict. Muse understands
that a MIDI file has *independent, orthogonal axes*, so two collaborators can
touch different axes of the same file without conflicting.

Dimensions
----------

``melodic``
    ``note_on`` / ``note_off`` events (pitch + timing).
``rhythmic``
    Alias for ``melodic`` — timing is inseparable from pitch in the MIDI
    event model; provided as a user-facing label in ``.museattributes``
    rules. Internally both names map to a single ``notes`` bucket.
``harmonic``
    ``pitchwheel`` events.
``dynamic``
    ``control_change`` events.
``structural``
    ``set_tempo``, ``time_signature``, ``key_signature``,
    ``program_change``, ``sysex``, text and all other meta events.

Merge algorithm
---------------

1. Parse ``base``, ``left``, and ``right`` MIDI bytes into event streams.
2. Convert every track to absolute-tick representation and bucket events by
   dimension (events from all tracks are pooled into the same buckets).
3. Hash each bucket; compare ``base ↔ left`` and ``base ↔ right`` to detect
   per-dimension changes.
4. For each dimension, take the side that changed. When both sides changed
   the same dimension, the ``.museattributes`` strategy (``ours`` /
   ``theirs``) may pick a side; see :func:`merge_midi_dimensions` for the
   exact rules and for when the merge gives up (returns ``None``).
5. Reconstruct a valid MIDI file by merging the winning dimension slices,
   sorting by absolute tick, converting back to delta-time, and writing to
   bytes. The result is always a single-track (type 0) MIDI file.

Public API
----------

- :func:`extract_dimensions` — parse MIDI bytes → :class:`MidiDimensions`
  (one :class:`DimensionSlice` per internal dimension)
- :func:`merge_midi_dimensions` — three-way dimension merge →
  ``(bytes, report)`` or ``None``
- :func:`dimension_conflict_detail` — per-dimension change report for logging
"""
48 from __future__ import annotations
49
50 import hashlib
51 import io
52 import json
53 from dataclasses import dataclass, field
54
55 import mido
56
57 from muse.core.attributes import AttributeRule, resolve_strategy
58
59 # ---------------------------------------------------------------------------
60 # Dimension constants
61 # ---------------------------------------------------------------------------
62
63 #: Internal dimension names used as dict keys throughout this module.
64 INTERNAL_DIMS: list[str] = ["notes", "harmonic", "dynamic", "structural"]
65
66 #: User-facing dimension names from .museattributes mapped to internal buckets.
67 #: Both "melodic" and "rhythmic" map to the same "notes" bucket because MIDI
68 #: event timing and pitch are carried in the same event structure.
69 DIM_ALIAS: dict[str, str] = {
70 "melodic": "notes",
71 "rhythmic": "notes",
72 "harmonic": "harmonic",
73 "dynamic": "dynamic",
74 "structural": "structural",
75 }
76
77 #: Canonical alias → internal dim name, with internal names as pass-throughs.
78 _CANONICAL: dict[str, str] = {**DIM_ALIAS, "notes": "notes"}
79
80
81 # ---------------------------------------------------------------------------
82 # Data types
83 # ---------------------------------------------------------------------------
84
85
@dataclass
class DimensionSlice:
    """One dimension's share of the events of a MIDI file.

    When no ``content_hash`` (or an empty one) is supplied, it is derived from
    ``events`` at construction time with :func:`_hash_events`. It is computed
    once and is not refreshed if ``events`` is mutated afterwards.
    """

    # Internal dimension name this slice belongs to.
    name: str
    # ``(abs_tick, message)`` pairs, expected in ascending absolute-tick order.
    events: list[tuple[int, mido.Message]] = field(default_factory=list)
    # Hex SHA-256 of the canonical JSON serialisation of ``events``; used for
    # change detection without comparing file bytes.
    content_hash: str = ""

    def __post_init__(self) -> None:
        # Trust a caller-supplied non-empty hash; otherwise derive it.
        self.content_hash = self.content_hash or _hash_events(self.events)
103
104
@dataclass
class MidiDimensions:
    """Per-dimension view of one parsed MIDI file."""

    # Time resolution of the source file (its ``MidiFile.ticks_per_beat``).
    ticks_per_beat: int
    # MIDI file type of the source (its ``MidiFile.type``).
    file_type: int
    # Internal dimension name → slice holding that dimension's events.
    slices: dict[str, DimensionSlice]

    def get(self, user_dim: str) -> DimensionSlice:
        """Look up a slice by user-facing alias or internal dimension name.

        Names that are neither an alias nor an internal name are used as-is,
        so they raise ``KeyError`` unless present in ``slices``.
        """
        return self.slices[_CANONICAL.get(user_dim, user_dim)]
117
118
119 # ---------------------------------------------------------------------------
120 # Internal helpers
121 # ---------------------------------------------------------------------------
122
123
124 def _classify_event(msg: mido.Message) -> str | None:
125 """Map a mido Message to an internal dimension bucket, or ``None`` to skip."""
126 t = msg.type
127 if t in ("note_on", "note_off"):
128 return "notes"
129 if t == "pitchwheel":
130 return "harmonic"
131 if t == "control_change":
132 return "dynamic"
133 if t in (
134 "set_tempo",
135 "time_signature",
136 "key_signature",
137 "program_change",
138 "sysex",
139 "text",
140 "copyright",
141 "track_name",
142 "instrument_name",
143 "lyrics",
144 "marker",
145 "cue_marker",
146 "sequencer_specific",
147 "end_of_track",
148 ):
149 return "structural"
150 # Unrecognised meta events → structural bucket as a safe default.
151 if getattr(msg, "is_meta", False):
152 return "structural"
153 return None
154
155
156 type _MsgVal = int | str | list[int]
157
158
159 def _msg_to_dict(msg: mido.Message) -> dict[str, _MsgVal]:
160 """Serialise a mido Message to a JSON-compatible dict."""
161 d: dict[str, _MsgVal] = {"type": msg.type}
162 for attr in ("channel", "note", "velocity", "control", "value",
163 "pitch", "program", "numerator", "denominator",
164 "clocks_per_click", "notated_32nd_notes_per_beat",
165 "tempo", "key", "scale", "text", "data"):
166 if hasattr(msg, attr):
167 raw = getattr(msg, attr)
168 if isinstance(raw, (bytes, bytearray)):
169 d[attr] = list(raw)
170 elif isinstance(raw, str):
171 d[attr] = raw
172 elif isinstance(raw, int):
173 d[attr] = raw
174 # Other types (float, etc.) are skipped — not present in standard MIDI
175 return d
176
177
178 def _hash_events(events: list[tuple[int, mido.Message]]) -> str:
179 """SHA-256 of the canonical JSON representation of an event list."""
180 payload = json.dumps(
181 [(tick, _msg_to_dict(msg)) for tick, msg in events],
182 sort_keys=True,
183 separators=(",", ":"),
184 ).encode()
185 return hashlib.sha256(payload).hexdigest()
186
187
188 def _to_absolute(track: mido.MidiTrack) -> list[tuple[int, mido.Message]]:
189 """Convert a delta-time track to a list of ``(abs_tick, msg)`` pairs."""
190 result: list[tuple[int, mido.Message]] = []
191 abs_tick = 0
192 for msg in track:
193 abs_tick += msg.time
194 result.append((abs_tick, msg))
195 return result
196
197
198 # ---------------------------------------------------------------------------
199 # Public: extract_dimensions
200 # ---------------------------------------------------------------------------
201
202
def extract_dimensions(midi_bytes: bytes) -> MidiDimensions:
    """Split the MIDI file in *midi_bytes* into per-dimension event slices.

    Events from every track are pooled. Each message is converted to its
    absolute tick within its own track and placed in the bucket chosen by
    :func:`_classify_event`; messages classified as ``None`` are dropped.
    Within each bucket, events are ordered by absolute tick, with ties broken
    by message type name so the order is deterministic.

    Args:
        midi_bytes: Raw bytes of a ``.mid`` file.

    Returns:
        A :class:`MidiDimensions` holding one :class:`DimensionSlice` for each
        name in ``INTERNAL_DIMS``, plus the file's ``ticks_per_beat`` and type.

    Raises:
        ValueError: If *midi_bytes* cannot be parsed as a MIDI file; the
            parser's own exception is chained as the cause.
    """
    try:
        parsed = mido.MidiFile(file=io.BytesIO(midi_bytes))
    except Exception as exc:
        raise ValueError(f"Failed to parse MIDI data: {exc}") from exc

    grouped: dict[str, list[tuple[int, mido.Message]]] = {
        name: [] for name in INTERNAL_DIMS
    }
    for trk in parsed.tracks:
        for tick, message in _to_absolute(trk):
            target = _classify_event(message)
            if target is None:
                continue
            grouped[target].append((tick, message))

    def _order(pair: tuple[int, mido.Message]) -> tuple[int, str]:
        # Absolute tick first; message type name as a deterministic tie-break.
        return pair[0], pair[1].type

    slices = {
        name: DimensionSlice(name=name, events=sorted(grouped[name], key=_order))
        for name in INTERNAL_DIMS
    }
    return MidiDimensions(
        ticks_per_beat=parsed.ticks_per_beat,
        file_type=parsed.type,
        slices=slices,
    )
243
244
245 # ---------------------------------------------------------------------------
246 # Public: dimension_conflict_detail
247 # ---------------------------------------------------------------------------
248
249
def dimension_conflict_detail(
    base: MidiDimensions,
    left: MidiDimensions,
    right: MidiDimensions,
) -> dict[str, str]:
    """Classify how each dimension changed relative to the common ancestor.

    A side counts as changed when its slice ``content_hash`` differs from the
    base slice's hash. The returned dict maps every name in ``INTERNAL_DIMS``
    to one of:

    - ``"unchanged"`` — neither side changed this dimension.
    - ``"left_only"`` — only the left (ours) side changed.
    - ``"right_only"`` — only the right (theirs) side changed.
    - ``"both"`` — both sides changed (reported even when both made the
      identical change, since only base ↔ side hashes are compared).

    Used by :func:`merge_midi_dimensions`. It can also be surfaced in
    ``muse merge`` output for human-readable conflict diagnostics.
    """
    verdicts = {
        (False, False): "unchanged",
        (True, False): "left_only",
        (False, True): "right_only",
        (True, True): "both",
    }
    report: dict[str, str] = {}
    for dim in INTERNAL_DIMS:
        ancestor_hash = base.slices[dim].content_hash
        changed = (
            left.slices[dim].content_hash != ancestor_hash,
            right.slices[dim].content_hash != ancestor_hash,
        )
        report[dim] = verdicts[changed]
    return report
283
284
285 # ---------------------------------------------------------------------------
286 # Reconstruction helpers
287 # ---------------------------------------------------------------------------
288
289
#: Same-tick ordering priority per internal dimension when the winning slices
#: are interleaved into one track. Set-up events (tempo, program changes,
#: meta) come first, then controllers, then pitch bend, then notes. Sorting
#: same-tick events alphabetically by type would instead put a
#: ``program_change`` after a ``note_on`` at the same tick, so the note would
#: sound with the previous program.
_SAME_TICK_RANK: dict[str, int] = {
    "structural": 0,
    "dynamic": 1,
    "harmonic": 2,
    "notes": 3,
}


def _events_to_track(
    events: list[tuple[int, mido.Message]],
) -> mido.MidiTrack:
    """Convert absolute-tick events to a mido MidiTrack with delta times.

    Events are sorted by absolute tick only. The sort is stable, so events
    sharing a tick keep the order the caller gave them. Same-tick order is
    musically significant (a ``program_change`` must precede the ``note_on``
    it applies to), so it is the caller's responsibility. If the last event is
    not ``end_of_track``, one is appended with a zero delta.
    """
    track = mido.MidiTrack()
    prev_tick = 0
    for abs_tick, msg in sorted(events, key=lambda pair: pair[0]):
        # Copy instead of setting ``time`` in place: the input messages are
        # shared with the dimension slices they came from.
        track.append(msg.copy(time=abs_tick - prev_tick))
        prev_tick = abs_tick
    if not track or track[-1].type != "end_of_track":
        track.append(mido.MetaMessage("end_of_track", time=0))
    return track


def _reconstruct(
    ticks_per_beat: int,
    winning_slices: dict[str, list[tuple[int, mido.Message]]],
) -> bytes:
    """Build a type-0 MIDI file from winning dimension event lists.

    All dimension events are merged into a single track (type-0) for maximum
    compatibility. Events are ordered by absolute tick. Events sharing a tick
    are ordered by dimension (``_SAME_TICK_RANK``; unknown dimension names
    sort last) and otherwise keep the order they have in their slice.

    Every ``end_of_track`` in the inputs is dropped and exactly one is
    written. It goes at the latest tick of any input ``end_of_track`` or
    event, so trailing time carried by an input's end-of-track delta is kept.

    Args:
        ticks_per_beat: Resolution written to the output file header.
        winning_slices: Internal dimension name → ``(abs_tick, msg)`` events.

    Returns:
        The serialised MIDI file bytes.
    """
    fallback_rank = len(_SAME_TICK_RANK)
    ranked: list[tuple[int, int, mido.Message]] = []
    end_tick = 0
    for dim, events in winning_slices.items():
        rank = _SAME_TICK_RANK.get(dim, fallback_rank)
        for tick, msg in events:
            if msg.type == "end_of_track":
                end_tick = max(end_tick, tick)
            else:
                ranked.append((tick, rank, msg))

    # Stable sort: same tick + same dimension keeps the slice's own order
    # (e.g. note_off before note_on, as sorted by extract_dimensions).
    ranked.sort(key=lambda item: (item[0], item[1]))
    ordered = [(tick, msg) for tick, _rank, msg in ranked]
    if ordered:
        end_tick = max(end_tick, ordered[-1][0])
    # Appended last at a tick >= every other event, so the stable tick sort
    # in _events_to_track keeps it at the very end of the track.
    ordered.append((end_tick, mido.MetaMessage("end_of_track", time=0)))

    mid = mido.MidiFile(type=0, ticks_per_beat=ticks_per_beat)
    mid.tracks.append(_events_to_track(ordered))

    buf = io.BytesIO()
    mid.save(file=buf)
    return buf.getvalue()
335
336
337 # ---------------------------------------------------------------------------
338 # Public: merge_midi_dimensions
339 # ---------------------------------------------------------------------------
340
341
def _resolve_dimension_strategy(
    attrs_rules: list[AttributeRule],
    path: str,
    dim: str,
) -> str:
    """Return the ``.museattributes`` strategy for internal dimension *dim*.

    User-facing aliases of *dim* (e.g. ``melodic`` and ``rhythmic`` for
    ``notes``) are consulted first, then the internal name itself; the first
    answer other than ``"auto"`` wins. If every lookup yields ``"auto"``, the
    result of the ``"*"`` wildcard lookup is returned.
    """
    candidates = [alias for alias, internal in DIM_ALIAS.items() if internal == dim]
    candidates.append(dim)
    for name in candidates:
        strategy = resolve_strategy(attrs_rules, path, name)
        if strategy != "auto":
            return strategy
    return resolve_strategy(attrs_rules, path, "*")


def merge_midi_dimensions(
    base_bytes: bytes,
    left_bytes: bytes,
    right_bytes: bytes,
    attrs_rules: list[AttributeRule],
    path: str,
) -> tuple[bytes, dict[str, str]] | None:
    """Attempt a dimension-level three-way merge of a MIDI file.

    For each internal dimension:

    - If neither side changed → keep base.
    - If only one side changed → take that side (clean auto-merge).
    - If both sides made the *identical* change → take it (clean auto-merge,
      reported as ``"left"``).
    - If both sides changed it differently → consult ``.museattributes``
      strategy:

      * ``ours`` / ``theirs`` → take the specified side; record in report.
      * any other strategy (``manual`` / ``auto`` / ``union``) →
        unresolvable; return ``None``.

    Event ticks only have meaning relative to a file's ``ticks_per_beat``.
    If the three versions do not share the same resolution, their slices
    cannot be combined safely and ``None`` is returned.

    Args:
        base_bytes: MIDI bytes for the common ancestor.
        left_bytes: MIDI bytes for the ours (left) branch.
        right_bytes: MIDI bytes for the theirs (right) branch.
        attrs_rules: Rule list from :func:`muse.core.attributes.load_attributes`.
        path: Workspace-relative POSIX path (used for strategy lookup).

    Returns:
        A ``(merged_bytes, dimension_report)`` tuple when all dimension
        conflicts can be resolved. Returns ``None`` when at least one
        dimension conflict has no resolvable strategy, or when the three
        versions have different ``ticks_per_beat``.

        *dimension_report* maps each internal dimension name to the side
        chosen: ``"base"``, ``"left"``, ``"right"``, or the strategy string.

    Raises:
        ValueError: If any of the byte strings cannot be parsed as MIDI.
    """
    base_dims = extract_dimensions(base_bytes)
    left_dims = extract_dimensions(left_bytes)
    right_dims = extract_dimensions(right_bytes)

    # Combining slices recorded at different resolutions would silently
    # re-time them under base's ticks_per_beat, so refuse instead.
    if not (
        base_dims.ticks_per_beat
        == left_dims.ticks_per_beat
        == right_dims.ticks_per_beat
    ):
        return None

    detail = dimension_conflict_detail(base_dims, left_dims, right_dims)

    winning_slices: dict[str, list[tuple[int, mido.Message]]] = {}
    dimension_report: dict[str, str] = {}

    for dim in INTERNAL_DIMS:
        change = detail[dim]
        left_slice = left_dims.slices[dim]
        right_slice = right_dims.slices[dim]

        if change == "unchanged":
            winning_slices[dim] = base_dims.slices[dim].events
            dimension_report[dim] = "base"

        elif change == "left_only":
            winning_slices[dim] = left_slice.events
            dimension_report[dim] = "left"

        elif change == "right_only":
            winning_slices[dim] = right_slice.events
            dimension_report[dim] = "right"

        elif left_slice.content_hash == right_slice.content_hash:
            # Both sides made the identical change; a standard three-way
            # merge treats that as clean rather than as a conflict.
            winning_slices[dim] = left_slice.events
            dimension_report[dim] = "left"

        else:
            # Both sides changed this dimension differently.
            strategy = _resolve_dimension_strategy(attrs_rules, path, dim)
            if strategy == "ours":
                winning_slices[dim] = left_slice.events
                dimension_report[dim] = f"ours ({dim})"
            elif strategy == "theirs":
                winning_slices[dim] = right_slice.events
                dimension_report[dim] = f"theirs ({dim})"
            else:
                # "auto", "union", "manual" — cannot resolve this dimension.
                return None

    merged_bytes = _reconstruct(base_dims.ticks_per_beat, winning_slices)
    return merged_bytes, dimension_report