cgcardona / muse public
muse_transpose.py python
590 lines 19.3 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Transpose Service — apply MIDI pitch transposition as a Muse commit.
2
3 Provides:
4
5 - ``parse_interval`` — convert "+3", "up-minor3rd", "down-perfect5th" to signed semitones.
6 - ``update_key_metadata`` — transpose a key string (e.g. "Eb major" → "F major").
7 - ``transpose_midi_bytes`` — pure function: raw MIDI bytes → transposed bytes.
8 - ``apply_transpose_to_workdir`` — apply transposition to all MIDI files in muse-work/.
9 - ``TransposeResult`` — named result type for ``muse transpose`` output.
10
11 MIDI transposition rules:
12 - Note-On (0x9n) and Note-Off (0x8n) events on non-drum channels are shifted.
13 - Channel 9 (MIDI drum channel) is always excluded — drums are unpitched.
14 - Notes are clamped to [0, 127] to stay within MIDI range.
15 - All non-note events (tempo, program change, CC, sysex) are preserved verbatim.
16
17 Boundary rules:
18 - Must NOT import StateStore, EntityRegistry, or app.core.*.
19 - Must NOT import LLM handlers or maestro_* modules.
20 - Pure data — no FastAPI, no DB access, no side effects beyond file I/O in
21 ``apply_transpose_to_workdir``.
22 """
23 from __future__ import annotations
24
25 import logging
26 import pathlib
27 import struct
28 from dataclasses import dataclass, field
29
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Named interval → absolute semitone count. The direction is not part of the
# name; it comes from the "up-"/"down-" prefix parsed by ``parse_interval``.
# Several aliases (e.g. "minor3rd" / "min3rd") map to the same size.
_NAMED_INTERVALS: dict[str, int] = {
    "unison": 0,
    "minor2nd": 1,
    "min2nd": 1,
    "major2nd": 2,
    "maj2nd": 2,
    "minor3rd": 3,
    "min3rd": 3,
    "major3rd": 4,
    "maj3rd": 4,
    "perfect4th": 5,
    "perf4th": 5,
    "p4th": 5,
    "augmented4th": 6,
    "aug4th": 6,
    "tritone": 6,
    "diminished5th": 6,
    "dim5th": 6,
    "perfect5th": 7,
    "perf5th": 7,
    "p5th": 7,
    "minor6th": 8,
    "min6th": 8,
    "major6th": 9,
    "maj6th": 9,
    "minor7th": 10,
    "min7th": 10,
    "major7th": 11,
    "maj7th": 11,
    "octave": 12,
}

# MIDI channel 9 when counted from 0 (channel 10 when counted from 1) is the
# conventional drum channel; notes on it are unpitched and never transposed.
_DRUM_CHANNEL = 9

# Lower-case note name → pitch class (C=0, ascending chromatically).
# Enharmonic spellings that cross a letter boundary (Fb/E#, Cb/B#) are all
# included so keys written with any of them are still recognized.
_NOTE_TO_SEMITONE: dict[str, int] = {
    "c": 0,
    "b#": 0,
    "c#": 1,
    "db": 1,
    "d": 2,
    "d#": 3,
    "eb": 3,
    "e": 4,
    "fb": 4,
    "f": 5,
    "e#": 5,
    "f#": 6,
    "gb": 6,
    "g": 7,
    "g#": 8,
    "ab": 8,
    "a": 9,
    "a#": 10,
    "bb": 10,
    "b": 11,
    "cb": 11,
}

# Preferred spelling for each pitch class (index 0=C … 11=B). Accidentals are
# spelled as flats, except pitch class 6, which is spelled F# rather than Gb.
_SEMITONE_TO_NOTE: list[str] = [
    "C", "Db", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B",
]
101
102
103 # ---------------------------------------------------------------------------
104 # Named result types
105 # ---------------------------------------------------------------------------
106
107
@dataclass(frozen=True)
class TransposeResult:
    """Outcome of a single ``muse transpose <interval> [<commit>]`` run.

    Captures which commit was transposed and by how much, which MIDI files
    changed or were skipped, the commit that was produced (absent for a dry
    run), and the key annotation before and after the shift.

    Agent use case: read ``new_commit_id`` to confirm a commit was written,
    ``new_key`` to refresh the current key center, and ``files_modified`` to
    re-render only the tracks that actually changed.
    """

    # Commit whose snapshot was the source of the transposition.
    source_commit_id: str
    # Signed semitone offset applied (positive = up, negative = down).
    semitones: int
    # Relative paths of MIDI files that had notes transposed.
    files_modified: list[str] = field(default_factory=list)
    # Relative paths of non-MIDI or excluded files.
    files_skipped: list[str] = field(default_factory=list)
    # ID of the commit holding the transposed snapshot; None in dry-run mode.
    new_commit_id: str | None = None
    # Key annotation before transposition; None when not annotated.
    original_key: str | None = None
    # Key annotation after transposition; None when the original was absent.
    new_key: str | None = None
    # True for a dry run: no files written and no commit created.
    dry_run: bool = False
145
146
147 # ---------------------------------------------------------------------------
148 # Interval parsing
149 # ---------------------------------------------------------------------------
150
151
def parse_interval(interval_str: str) -> int:
    """Convert an interval descriptor into a signed number of semitones.

    Two spellings are understood:

    - A plain integer, optionally signed: ``"+3"``, ``"-5"``, ``"12"``
      (unsigned means upward).
    - ``"<direction>-<name>"`` where direction is ``"up"`` (positive) or
      ``"down"`` (negative) and name is a key of ``_NAMED_INTERVALS``,
      matched case-insensitively: ``"up-minor3rd"``, ``"down-perfect5th"``.

    Surrounding whitespace is ignored.

    Args:
        interval_str: Interval descriptor from the CLI argument.

    Returns:
        Signed semitone count. Positive = up, negative = down.

    Raises:
        ValueError: If the string is neither an integer nor a recognized
            ``up-``/``down-`` named interval.
    """
    text = interval_str.strip()
    try:
        return int(text)
    except ValueError:
        pass  # not a bare integer; fall through to named-interval parsing

    lowered = text.lower()
    for prefix, sign in (("up-", 1), ("down-", -1)):
        if lowered.startswith(prefix):
            name = lowered[len(prefix):]
            break
    else:
        raise ValueError(
            f"Cannot parse interval {text!r}. "
            "Use a signed integer (+3, -5) or a named interval "
            "(up-minor3rd, down-perfect5th, up-octave)."
        )

    if name not in _NAMED_INTERVALS:
        valid = ", ".join(sorted(_NAMED_INTERVALS))
        raise ValueError(f"Unknown interval name {name!r}. Valid names: {valid}")
    return sign * _NAMED_INTERVALS[name]
200
201
202 # ---------------------------------------------------------------------------
203 # Key metadata update
204 # ---------------------------------------------------------------------------
205
206
def update_key_metadata(key_str: str, semitones: int) -> str:
    """Transpose a key string by *semitones* and return the updated key name.

    Parses strings in the format ``"<note> <mode>"`` (e.g. ``"Eb major"``,
    ``"F# minor"``). The root note is moved by *semitones* and re-spelled
    using ``_SEMITONE_TO_NOTE``. The words after the root are kept as-is and
    re-joined with single spaces.

    *key_str* is returned unchanged (not re-spelled or re-spaced) when:

    - it is empty or whitespace-only;
    - *semitones* is a whole number of octaves (including 0). The key center
      does not move, so rewriting e.g. ``"Gb major"`` as ``"F# major"`` would
      be a spurious metadata change;
    - the root note is not recognized. This lets callers safely pass
      arbitrary metadata strings without crashing.

    Args:
        key_str: Key string to transpose (e.g. ``"Eb major"``).
        semitones: Signed semitone offset.

    Returns:
        Updated key string (e.g. ``"F major"`` after +2 from ``"Eb major"``).
    """
    parts = key_str.split()
    # Python's % is non-negative for a positive modulus, so -12, 0, +24 all match.
    if not parts or semitones % 12 == 0:
        return key_str

    root_name, *mode_parts = parts
    pitch_class = _NOTE_TO_SEMITONE.get(root_name.lower())
    if pitch_class is None:
        logger.debug("⚠️ Unknown key root %r — returning key string unchanged", root_name)
        return key_str

    new_root = _SEMITONE_TO_NOTE[(pitch_class + semitones) % 12]
    return " ".join([new_root, *mode_parts])
237
238
239 # ---------------------------------------------------------------------------
240 # Low-level MIDI parsing helpers
241 # ---------------------------------------------------------------------------
242
243
def _read_vlq(data: bytes, pos: int) -> tuple[int, int]:
    """Decode a MIDI variable-length quantity that begins at *pos*.

    Each byte contributes its low 7 bits, most significant group first. A set
    high bit (0x80) means another byte follows. Returns ``(value, next_pos)``
    where *next_pos* is the index just past the final byte consumed.

    Raises ``IndexError`` if the data ends while a continuation bit is still
    set (truncated quantity).
    """
    value = 0
    more = True
    while more:
        byte = data[pos]
        pos += 1
        value = (value << 7) | (byte & 0x7F)
        more = byte >= 0x80  # continuation flag lives in the top bit
    return value, pos
261
262
def _get_track_name(track_data: bytes) -> str | None:
    """Return the text of the first Track Name meta-event in an MTrk body.

    Walks the raw event stream of one track chunk (the bytes after its 8-byte
    chunk header), looking for a ``0xFF 0x03`` meta-event, and decodes its
    payload as latin-1. Returns ``None`` if the stream ends, becomes
    unparseable, or contains no such event before that point.

    This backs the ``--track`` filter in ``muse transpose``: only tracks whose
    name contains the filter substring (case-insensitive) are transposed.
    """
    end = len(track_data)
    cursor = 0
    status = 0  # most recent channel status byte, reused under running status

    while cursor < end:
        # Delta time: skip continuation bytes (high bit set), then the final byte.
        while cursor < end and track_data[cursor] & 0x80:
            cursor += 1
        if cursor >= end:
            break
        cursor += 1
        if cursor >= end:
            break

        lead = track_data[cursor]

        if lead == 0xFF:
            # Meta event: 0xFF, type byte, VLQ length, payload.
            cursor += 1
            if cursor >= end:
                break
            meta_kind = track_data[cursor]
            try:
                size, cursor = _read_vlq(track_data, cursor + 1)
            except IndexError:
                break
            if meta_kind == 0x03 and cursor + size <= end:
                return track_data[cursor : cursor + size].decode("latin-1")
            cursor += size
            continue

        if lead in (0xF0, 0xF7):
            # Sysex: status byte, VLQ length, payload.
            try:
                size, cursor = _read_vlq(track_data, cursor + 1)
            except IndexError:
                break
            cursor += size
            continue

        # Channel message. A byte without the high bit is a data byte, meaning
        # the previous status byte is reused (running status).
        if lead & 0x80:
            status = lead
            cursor += 1
        message = status >> 4
        if message in (0x8, 0x9, 0xA, 0xB, 0xE):  # two data bytes
            cursor += 2
        elif message in (0xC, 0xD):  # one data byte
            cursor += 1
        else:
            break  # unrecognised (or no status seen yet) — stop scanning

    return None
327
328
def _transpose_track_data(track_data: bytes, semitones: int) -> bytes:
    """Shift the note numbers inside one MTrk chunk's event data.

    Only the key-number byte of Note-On (0x9n) and Note-Off (0x8n) messages
    is touched, and only when the channel is not ``_DRUM_CHANNEL``. Shifted
    notes are clamped to [0, 127]. Meta events, sysex, and all other channel
    messages (poly pressure, CC, program change, channel pressure, pitch bend)
    are skipped over unchanged.

    Edits are made on a same-length copy. Because no bytes are inserted or
    removed, the enclosing MTrk length header remains valid.

    Scanning stops early, leaving the remainder untouched, if a VLQ is
    truncated or an unsupported status byte is seen (the latter is logged as
    a warning). A note message missing its velocity byte at the end of the
    data is not modified.

    Args:
        track_data: Raw event bytes from an MTrk chunk (after the 8-byte header).
        semitones: Signed semitone offset to apply.

    Returns:
        Event data of the same length as *track_data*.
    """
    out = bytearray(track_data)
    end = len(track_data)
    cursor = 0
    status = 0  # most recent channel status byte, reused under running status

    while cursor < end:
        # Delta time: skip continuation bytes (high bit set), then the final byte.
        while cursor < end and track_data[cursor] & 0x80:
            cursor += 1
        if cursor >= end:
            break
        cursor += 1
        if cursor >= end:
            break

        lead = track_data[cursor]

        if lead == 0xFF:
            # Meta event: 0xFF, type byte, VLQ length, payload — copied untouched.
            cursor += 1
            if cursor >= end:
                break
            try:
                size, cursor = _read_vlq(track_data, cursor + 1)
            except IndexError:
                break
            cursor += size
            continue

        if lead in (0xF0, 0xF7):
            # Sysex: status byte, VLQ length, payload — copied untouched.
            try:
                size, cursor = _read_vlq(track_data, cursor + 1)
            except IndexError:
                break
            cursor += size
            continue

        # Channel message. A byte without the high bit is a data byte, meaning
        # the previous status byte is reused (running status).
        if lead & 0x80:
            status = lead
            cursor += 1
        kind = status >> 4

        if kind in (0x8, 0x9):  # note-off, note-on: key byte, velocity byte
            if cursor + 1 < end and (status & 0x0F) != _DRUM_CHANNEL:
                out[cursor] = min(127, max(0, track_data[cursor] + semitones))
            cursor += 2
        elif kind in (0xA, 0xB, 0xE):  # poly pressure, CC, pitch bend
            cursor += 2
        elif kind in (0xC, 0xD):  # program change, channel pressure
            cursor += 1
        else:
            logger.warning(
                "⚠️ Unknown MIDI event type 0x%X at byte %d — stopping track parse",
                kind,
                cursor,
            )
            break

    return bytes(out)
412
413
414 # ---------------------------------------------------------------------------
415 # Public MIDI transposition API
416 # ---------------------------------------------------------------------------
417
418
def transpose_midi_bytes(
    data: bytes,
    semitones: int,
    track_filter: str | None = None,
) -> tuple[bytes, int]:
    """Apply pitch transposition to a MIDI file's raw bytes.

    Parses the Standard MIDI File layout (an ``MThd`` header chunk followed by
    further chunks) and transposes Note-On/Note-Off events on non-drum
    channels inside ``MTrk`` chunks. What is preserved:

    - the header chunk;
    - non-``MTrk`` chunks;
    - any trailing bytes after the last complete chunk header (e.g. padding);
    - every non-note event.

    All of these are copied verbatim, so a well-formed file that needs no
    changes round-trips byte-for-byte.

    When *track_filter* is provided, only MTrk chunks whose Track Name
    meta-event (0xFF 0x03) contains the filter substring (case-insensitive)
    are transposed; other tracks are copied verbatim.

    Args:
        data: Raw MIDI file bytes.
        semitones: Signed semitone offset to apply.
        track_filter: If set, only tracks whose name matches this substring
            (case-insensitive) are transposed.

    Returns:
        ``(modified_bytes, notes_changed_count)`` where *notes_changed_count*
        is the number of note bytes that were actually modified. If *data* is
        not a valid MIDI file (no ``MThd`` header), returns ``(data, 0)``.
    """
    if len(data) < 14 or data[:4] != b"MThd":
        return data, 0

    # MThd: tag(4) + length(4) + header body. The length field says how many
    # header-body bytes follow it (format, ntracks, division = 6 normally).
    header_end = 8 + struct.unpack_from(">I", data, 4)[0]
    result = bytearray(data[:header_end])
    pos = header_end
    notes_changed = 0

    while pos + 8 <= len(data):
        chunk_header_start = pos
        chunk_tag = data[pos : pos + 4]
        chunk_len = struct.unpack_from(">I", data, pos + 4)[0]
        chunk_start = pos + 8
        chunk_end = chunk_start + chunk_len
        chunk_data = data[chunk_start:chunk_end]
        pos = chunk_end

        if chunk_tag != b"MTrk":
            # Non-track chunk (e.g. instrument-specific) — copy verbatim.
            result.extend(data[chunk_header_start:chunk_end])
            continue

        # Decide whether this track is in scope for transposition.
        should_transpose = True
        if track_filter is not None:
            track_name = _get_track_name(chunk_data)
            if track_name is None or track_filter.lower() not in track_name.lower():
                should_transpose = False
                logger.debug(
                    "⚠️ Track %r does not match filter %r — copying verbatim",
                    track_name,
                    track_filter,
                )

        if should_transpose and semitones != 0:
            modified_track = _transpose_track_data(chunk_data, semitones)
            # Only note bytes are ever rewritten, so differing bytes = changed notes.
            notes_changed += sum(
                1 for before, after in zip(chunk_data, modified_track) if before != after
            )
        else:
            modified_track = chunk_data

        result.extend(b"MTrk")
        result.extend(struct.pack(">I", len(modified_track)))
        result.extend(modified_track)

    # Fewer than 8 bytes can remain after the last chunk (e.g. padding). Keep
    # them instead of silently truncating the file.
    if pos < len(data):
        result.extend(data[pos:])

    return bytes(result), notes_changed
499
500
501 # ---------------------------------------------------------------------------
502 # Workdir-level transposition
503 # ---------------------------------------------------------------------------
504
505
def _write_atomically(path: pathlib.Path, payload: bytes) -> None:
    """Replace the contents of *path* with *payload* without a partial write.

    The bytes are first written to a hidden sibling temp file, which then
    replaces the target via ``Path.replace`` (an ``os.replace`` rename). The
    target is only swapped after the new contents have been written
    completely, so a failed write cannot leave a truncated file behind.

    Symlinks are resolved first, so the link target is rewritten as
    ``Path.write_bytes`` would do. The target's permission bits are copied
    onto the replacement.

    Raises:
        OSError: If any step fails. The temp file is removed (best effort)
            before the error propagates, and the original file is untouched.
    """
    target = path.resolve()
    tmp_path = target.with_name(f".{target.name}.muse-transpose.tmp")
    try:
        tmp_path.write_bytes(payload)
        tmp_path.chmod(target.stat().st_mode & 0o7777)
        tmp_path.replace(target)
    except OSError:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the original error is the useful one
        raise


def apply_transpose_to_workdir(
    workdir: pathlib.Path,
    semitones: int,
    track_filter: str | None = None,
    section_filter: str | None = None,
    dry_run: bool = False,
) -> tuple[list[str], list[str]]:
    """Apply MIDI transposition to all MIDI files under *workdir*.

    Recursively finds every ``.mid`` / ``.midi`` file (suffix matched
    case-insensitively) and transposes it with ``transpose_midi_bytes``
    (drum channel excluded). Changed files are rewritten in place unless
    *dry_run* is set. Each rewrite goes through a temp file and an atomic
    rename, so a failed write leaves the original file intact.

    Section filtering is a stub: if *section_filter* is provided, a warning is
    logged and the filter is ignored. Full section-scoped transposition needs
    section boundary markers embedded in the committed MIDI metadata — a
    future enhancement tracked separately.

    Args:
        workdir: Path to the ``muse-work/`` directory.
        semitones: Signed semitone offset (positive = up, negative = down).
        track_filter: Case-insensitive track name substring filter, or None.
        section_filter: Section name filter (stub — ignored with a warning).
        dry_run: When True, compute what would change but write nothing.

    Returns:
        ``(files_modified, files_skipped)`` — POSIX paths relative to
        *workdir*.

        - *files_modified*: files that had at least one note byte changed
          (or would have, in dry-run mode).
        - *files_skipped*: MIDI-suffixed files that could not be read, were
          not valid Standard MIDI Files, had no transposable notes, or could
          not be written.

        Files with other suffixes are ignored and appear in neither list.
    """
    if section_filter is not None:
        logger.warning(
            "⚠️ --section filter is not yet implemented for muse transpose; "
            "transposing all sections. (section=%r)",
            section_filter,
        )

    files_modified: list[str] = []
    files_skipped: list[str] = []

    if not workdir.exists():
        logger.warning("⚠️ muse-work/ directory not found at %s", workdir)
        return files_modified, files_skipped

    # sorted() materializes the listing before any writes, giving a
    # deterministic order and ensuring temp files created below are never visited.
    for file_path in sorted(workdir.rglob("*")):
        if not file_path.is_file():
            continue
        if file_path.suffix.lower() not in (".mid", ".midi"):
            continue

        rel = file_path.relative_to(workdir).as_posix()
        try:
            original = file_path.read_bytes()
        except OSError as exc:
            logger.warning("⚠️ Cannot read %s: %s", rel, exc)
            files_skipped.append(rel)
            continue

        transposed, notes_changed = transpose_midi_bytes(original, semitones, track_filter)

        if transposed == original or notes_changed == 0:
            logger.debug(
                "⚠️ %s unchanged after transposition (no valid pitched notes found)", rel
            )
            files_skipped.append(rel)
            continue

        if not dry_run:
            try:
                _write_atomically(file_path, transposed)
            except OSError as exc:
                logger.error("❌ Cannot write transposed %s: %s", rel, exc)
                files_skipped.append(rel)
                continue

        files_modified.append(rel)
        logger.info(
            "✅ %s %s (%+d semitones, %d note byte(s) changed)",
            "Would transpose" if dry_run else "Transposed",
            rel,
            semitones,
            notes_changed,
        )

    return files_modified, files_skipped