cgcardona / muse public
midi_parser.py python
393 lines 13.1 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """MIDI and MusicXML parsing for ``muse import``.
2
3 Converts standard music file formats into Muse's internal note representation:
4 a list of :class:`NoteEvent` objects and a :class:`MuseImportData` container.
5
6 Supported formats
7 -----------------
8 - ``.mid`` / ``.midi`` — Standard MIDI File via ``mido``
9 - ``.xml`` / ``.musicxml`` — MusicXML via Python's built-in ``xml.etree.ElementTree``
10
11 Named result types registered in ``docs/reference/type_contracts.md``:
12 - ``MuseImportData``
13 - ``NoteEvent``
14 """
15 from __future__ import annotations
16
17 import dataclasses
18 import logging
19 import pathlib
20 import xml.etree.ElementTree as ET
21 from typing import Any
22
logger = logging.getLogger(__name__)

#: Standard MIDI File extensions, dispatched to :func:`parse_midi_file`.
SUPPORTED_MIDI_EXTENSIONS = {".midi", ".mid"}
#: MusicXML extensions, dispatched to :func:`parse_musicxml_file`.
SUPPORTED_XML_EXTENSIONS = {".musicxml", ".xml"}
#: Every extension accepted by :func:`parse_file` (lowercase, leading dot).
SUPPORTED_EXTENSIONS = SUPPORTED_MIDI_EXTENSIONS.union(SUPPORTED_XML_EXTENSIONS)
29
30
@dataclasses.dataclass()
class NoteEvent:
    """One sounding note produced by this module's MIDI or MusicXML parser.

    Times are expressed in ticks at the ``ticks_per_beat`` resolution of the
    :class:`MuseImportData` the note belongs to.
    """

    # Note number in the 0-127 range (spelled so that 60 is "C4" by analyze_import).
    pitch: int
    # Note-on velocity; MusicXML imports use a fixed value of 80.
    velocity: int
    # Absolute onset, in ticks from the start of the track / part.
    start_tick: int
    # Length in ticks; the parsers never emit a value below 1.
    duration_ticks: int
    # MIDI channel number, or the part's index for MusicXML.
    channel: int
    # Display name: "ch<N>" for MIDI, the part name for MusicXML; apply_track_map rewrites it.
    channel_name: str
41
42
@dataclasses.dataclass()
class MuseImportData:
    """Everything extracted from one imported music file.

    Built by :func:`parse_midi_file` or :func:`parse_musicxml_file`.
    """

    # The file that was parsed.
    source_path: pathlib.Path
    # Which parser produced this object: "midi" or "musicxml".
    format: str
    # Tick resolution of every NoteEvent in ``notes`` (480 for MusicXML).
    ticks_per_beat: int
    # Tempo in beats per minute; 120.0 when the file specifies none.
    tempo_bpm: float
    # Extracted notes.
    notes: list[NoteEvent]
    # Distinct ``channel_name`` values of ``notes``, in first-seen order.
    tracks: list[str]
    # Format-specific extras, e.g. "num_tracks" (MIDI) or "num_parts"/"part_names" (MusicXML).
    raw_meta: dict[str, Any]
54
55
def parse_file(path: pathlib.Path) -> MuseImportData:
    """Dispatch to the correct parser based on file extension.

    The extension is compared case-insensitively against
    :data:`SUPPORTED_MIDI_EXTENSIONS` and :data:`SUPPORTED_XML_EXTENSIONS`.

    Args:
        path: File to import. A ``str`` or other path-like object is also
            accepted and converted to :class:`pathlib.Path`.

    Raises:
        FileNotFoundError: When the file does not exist (checked first).
        ValueError: When the extension is not in :data:`SUPPORTED_EXTENSIONS`.
        RuntimeError: When the file is malformed.
    """
    # Accept plain strings too; Path(Path) is a no-op for existing callers.
    path = pathlib.Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    ext = path.suffix.lower()
    if ext in SUPPORTED_MIDI_EXTENSIONS:
        return parse_midi_file(path)
    if ext in SUPPORTED_XML_EXTENSIONS:
        return parse_musicxml_file(path)
    supported = ", ".join(sorted(SUPPORTED_EXTENSIONS))
    raise ValueError(
        f"Unsupported file extension '{path.suffix}'. Supported: {supported}"
    )
75
76
def parse_midi_file(path: pathlib.Path) -> MuseImportData:
    """Parse a Standard MIDI File into a :class:`MuseImportData`.

    Uses ``mido``. Note-on with velocity=0 is treated as note-off.

    Note pairing is done per track, because every track has its own tick
    timeline: a note-on is only ever closed by a note-off in the same track.
    Repeated note-ons for the same channel/pitch are queued and closed
    first-in, first-out, so overlapping re-attacks are not lost. Notes still
    open when their track ends are kept with a duration of 1 tick.

    ``tempo_bpm`` comes from the earliest (lowest tick) positive ``set_tempo``
    event; 120 BPM is used when the file has none.

    Raises:
        RuntimeError: When ``mido`` is unavailable or cannot read the file.
    """
    try:
        import mido
    except ImportError as exc:
        raise RuntimeError(
            "mido is required for MIDI import. "
            "It is pre-installed in the Maestro Docker image."
        ) from exc

    try:
        mid = mido.MidiFile(str(path))
    except Exception as exc:
        raise RuntimeError(f"Cannot parse MIDI file '{path}': {exc}") from exc

    ticks_per_beat = int(mid.ticks_per_beat)
    notes: list[NoteEvent] = []
    # (abs_tick, microseconds per beat) of the earliest usable set_tempo seen so far.
    first_tempo: tuple[int, int] | None = None

    def _emit(pitch: int, velocity: int, start: int, duration: int, channel: int) -> None:
        notes.append(
            NoteEvent(
                pitch=pitch,
                velocity=velocity,
                start_tick=start,
                duration_ticks=duration,
                channel=channel,
                channel_name=f"ch{channel}",
            )
        )

    for track in mid.tracks:
        abs_tick = 0
        # (channel, pitch) -> open notes as (start_tick, velocity), oldest first.
        active: dict[tuple[int, int], list[tuple[int, int]]] = {}
        for msg in track:
            abs_tick += msg.time
            if msg.type == "set_tempo":
                # Keep the earliest tempo; a zero tempo would divide by zero below.
                if msg.tempo > 0 and (first_tempo is None or abs_tick < first_tempo[0]):
                    first_tempo = (abs_tick, msg.tempo)
            elif msg.type == "note_on" and msg.velocity > 0:
                active.setdefault((msg.channel, msg.note), []).append(
                    (abs_tick, msg.velocity)
                )
            elif msg.type == "note_off" or (
                msg.type == "note_on" and msg.velocity == 0
            ):
                pending = active.get((msg.channel, msg.note))
                if pending:
                    start, vel = pending.pop(0)
                    _emit(msg.note, vel, start, max(abs_tick - start, 1), msg.channel)

        # Notes never closed within this track — truncate to duration 1
        for (ch, pitch), pending in active.items():
            for start, vel in pending:
                _emit(pitch, vel, start, 1, ch)

    tempo_us = first_tempo[1] if first_tempo is not None else 500_000  # 120 BPM default
    tempo_bpm = 60_000_000 / tempo_us
    tracks = _unique_ordered([n.channel_name for n in notes])

    logger.debug(
        "✅ Parsed MIDI %s: %d notes, %d tracks, %.1f BPM",
        path.name, len(notes), len(tracks), tempo_bpm,
    )
    return MuseImportData(
        source_path=path,
        format="midi",
        ticks_per_beat=ticks_per_beat,
        tempo_bpm=tempo_bpm,
        notes=notes,
        tracks=tracks,
        raw_meta={"num_tracks": len(mid.tracks)},
    )
158
159
#: Semitone offset above C of each MusicXML ``<step>`` letter.
_XML_STEP_SEMITONE: dict[str, int] = {
    "C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11,
}


def _parse_float(text: str | None, default: float) -> float:
    """Return *text* as a finite float, or *default* when it is missing or malformed."""
    import math

    if not text:
        return default
    try:
        value = float(text.strip())
    except ValueError:
        return default
    return value if math.isfinite(value) else default


def _xml_tempo(root: ET.Element, ns: str) -> float:
    """Return the first positive ``<sound tempo="…">`` in document order, else 120.0.

    ``<sound>`` may be a child of ``<direction>`` or of ``<measure>``; both count.
    """
    for sound in root.iter(f"{ns}sound"):
        bpm = _parse_float(sound.get("tempo"), 0.0)
        if bpm > 0:
            return bpm
    return 120.0


def _xml_part_names(root: ET.Element, ns: str) -> tuple[list[str], dict[str, str]]:
    """Collect part display names from the ``<score-part>`` entries.

    Returns the names in declaration order (blank names become ``"Part N"``)
    and a mapping from each ``<score-part id="…">`` to its name.
    """
    names: list[str] = []
    by_id: dict[str, str] = {}
    for score_part in root.iter(f"{ns}score-part"):
        name = (score_part.findtext(f"{ns}part-name") or "").strip()
        name = name or f"Part {len(names) + 1}"
        names.append(name)
        part_id = score_part.get("id")
        if part_id:
            by_id[part_id] = name
    return names, by_id


def _xml_part_notes(
    part_el: ET.Element,
    ns: str,
    ticks_per_beat: int,
    channel: int,
    channel_name: str,
) -> list[NoteEvent]:
    """Convert one ``<part>`` into note events on a single channel.

    Each measure's children are walked in document order:

    - ``<note>`` advances the time cursor, unless it carries ``<chord/>``, in
      which case it shares the onset of the preceding note.
    - ``<backup>`` rewinds the cursor (never before the measure start).
    - ``<forward>`` advances the cursor.
    - ``<attributes><divisions>`` updates the divisions-per-quarter scale.

    The next measure starts at the furthest position any voice reached.
    Positions are kept as unrounded ticks and rounded only when a note is
    emitted, so odd ``<divisions>`` values do not accumulate drift. Rests and
    notes without ``<pitch>`` occupy time but emit nothing.
    """

    def t(name: str) -> str:
        return f"{ns}{name}"

    notes: list[NoteEvent] = []
    divisions = 1.0
    measure_start = 0.0

    for measure_el in part_el.findall(t("measure")):
        cursor = measure_start
        measure_end = measure_start
        chord_start = measure_start  # onset of the most recent non-chord note

        for el in measure_el:
            if el.tag == t("attributes"):
                new_divisions = _parse_float(el.findtext(t("divisions")), 0.0)
                if new_divisions > 0:
                    divisions = new_divisions
                continue
            if el.tag not in (t("note"), t("backup"), t("forward")):
                continue

            dur_xml = max(_parse_float(el.findtext(t("duration")), 0.0), 0.0)
            dur = dur_xml * ticks_per_beat / divisions

            if el.tag == t("backup"):
                cursor = max(cursor - dur, measure_start)
                continue
            if el.tag == t("forward"):
                cursor += dur
                measure_end = max(measure_end, cursor)
                continue

            # <note>
            if el.find(t("chord")) is not None:
                start = chord_start
            else:
                start = chord_start = cursor
                cursor += dur
                measure_end = max(measure_end, cursor)

            pitch_el = el.find(t("pitch"))
            if el.find(t("rest")) is not None or pitch_el is None:
                continue

            step = (pitch_el.findtext(t("step")) or "C").strip().upper()
            octave = int(_parse_float(pitch_el.findtext(t("octave")), 4.0))
            alter = int(_parse_float(pitch_el.findtext(t("alter")), 0.0))
            pitch = max(
                0, min(127, (octave + 1) * 12 + _XML_STEP_SEMITONE.get(step, 0) + alter)
            )

            start_tick = round(start)
            notes.append(
                NoteEvent(
                    pitch=pitch,
                    velocity=80,  # MusicXML dynamics are not read; fixed velocity
                    start_tick=start_tick,
                    duration_ticks=max(round(start + dur) - start_tick, 1),
                    channel=channel,
                    channel_name=channel_name,
                )
            )

        measure_start = measure_end

    return notes


def parse_musicxml_file(path: pathlib.Path) -> MuseImportData:
    """Parse a MusicXML ``<score-partwise>`` file into a :class:`MuseImportData`.

    - Each ``<part>`` becomes one channel, numbered by its position.
    - The channel name is taken from the ``<score-part>`` with the same ``id``.
      When there is no such entry, it falls back to the name at the same
      index, then to ``"ch<N>"``.
    - Timing follows each measure's ``<note>``/``<backup>``/``<forward>``
      elements in order, at 480 ticks per quarter note.
    - ``<chord/>`` notes share the onset of the preceding note.
    - Tempo is the first positive ``<sound tempo>`` (120 BPM when absent).
    - Every note gets velocity 80.

    Raises:
        RuntimeError: When the XML is invalid or not a recognisable MusicXML document.
            I/O errors from reading the file are not converted.
    """
    try:
        tree = ET.parse(str(path))
    except ET.ParseError as exc:
        raise RuntimeError(f"Cannot parse MusicXML file '{path}': {exc}") from exc

    root = tree.getroot()

    # Strip XML namespace prefix, e.g. {http://www.musicxml.org/…}element → element
    ns = ""
    if root.tag.startswith("{"):
        ns = root.tag[: root.tag.index("}") + 1]

    def t(name: str) -> str:
        return f"{ns}{name}"

    if root.tag not in (t("score-partwise"), "score-partwise"):
        raise RuntimeError(
            f"Unrecognised MusicXML root element '{root.tag}'. "
            "Expected <score-partwise>."
        )

    tempo_bpm = _xml_tempo(root, ns)
    ticks_per_beat = 480  # internal default for MusicXML
    part_names, name_by_id = _xml_part_names(root, ns)

    notes: list[NoteEvent] = []
    parts = root.findall(t("part"))

    for ch_idx, part_el in enumerate(parts):
        channel_name = name_by_id.get(part_el.get("id", ""))
        if channel_name is None:
            channel_name = part_names[ch_idx] if ch_idx < len(part_names) else f"ch{ch_idx}"
        notes.extend(_xml_part_notes(part_el, ns, ticks_per_beat, ch_idx, channel_name))

    tracks = _unique_ordered([n.channel_name for n in notes])
    logger.debug(
        "✅ Parsed MusicXML %s: %d notes, %d parts, %.1f BPM",
        path.name, len(notes), len(parts), tempo_bpm,
    )
    return MuseImportData(
        source_path=path,
        format="musicxml",
        ticks_per_beat=ticks_per_beat,
        tempo_bpm=tempo_bpm,
        notes=notes,
        tracks=tracks,
        raw_meta={"num_parts": len(parts), "part_names": part_names},
    )
285
286
def apply_track_map(notes: list[NoteEvent], track_map: dict[str, str]) -> list[NoteEvent]:
    """Return notes with ``channel_name`` fields remapped per *track_map*.

    Key formats:

    - ``"ch<N>"``; the prefix is case-insensitive, so ``"CH1"`` also works.
    - A bare channel number string such as ``"1"``.

    Surrounding whitespace is ignored. Other keys are logged as warnings and
    skipped. If two keys name the same channel, the later one wins.

    A new list is returned and the input list is not modified:

    - Mapped notes are copies made with :func:`dataclasses.replace`.
    - Notes for unmapped channels are the original objects.
    """
    normalised: dict[int, str] = {}
    for key, name in track_map.items():
        k = key.strip()
        digits = k[2:] if k[:2].lower() == "ch" else k
        try:
            normalised[int(digits)] = name
        except ValueError:
            logger.warning("⚠️ Ignoring invalid track-map key %r", key)

    return [
        dataclasses.replace(note, channel_name=normalised[note.channel])
        if note.channel in normalised
        else note
        for note in notes
    ]
309
310
def parse_track_map_arg(raw: str) -> dict[str, str]:
    """Parse ``"ch0=bass,ch1=piano"`` into ``{"ch0": "bass", "ch1": "piano"}``.

    Whitespace around pairs, keys and values is stripped, and empty pairs
    (e.g. from a trailing comma) are skipped. Only the first ``=`` splits a
    pair, so values may themselves contain ``=``. A repeated key keeps its
    last value.

    Raises:
        ValueError: When any pair is not in ``KEY=VALUE`` format, including
            pairs whose key or value is empty (``"=bass"``, ``"ch0="``).
    """
    result: dict[str, str] = {}
    for chunk in raw.split(","):
        pair = chunk.strip()
        if not pair:
            continue
        key, sep, value = pair.partition("=")
        key, value = key.strip(), value.strip()
        # An empty key or value would silently produce an unusable mapping.
        if not sep or not key or not value:
            raise ValueError(
                f"Invalid track-map entry {pair!r}. Expected KEY=VALUE (e.g. ch0=bass)."
            )
        result[key] = value
    return result
329
330
def analyze_import(data: MuseImportData) -> str:
    """Render a human-readable, multi-line report on *data*.

    The report has a header (format, tempo, tracks) and three sections:

    - Harmonic: pitch range and the five most frequent pitches.
    - Rhythmic: note count, span in beats, and notes per beat.
    - Dynamic: velocity statistics and a loudness label for the mean velocity.

    Pitches are spelled with sharps, with MIDI 60 as ``C4``. The span runs
    from tick 0 to the latest note end. When *data* has no notes, a single
    placeholder line is returned instead.
    """
    events = data.notes
    if not events:
        return " (no notes found — file may be empty or contain only meta events)"

    names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")

    def spell(midi_pitch: int) -> str:
        octave = midi_pitch // 12 - 1
        return f"{names[midi_pitch % 12]}{octave}"

    # Harmonic: count occurrences of each pitch.
    tally: dict[int, int] = {}
    for event in events:
        tally[event.pitch] = tally.get(event.pitch, 0) + 1
    # Highest count first; equal counts keep first-seen order (sorted is stable).
    leaders = sorted(tally, key=tally.__getitem__, reverse=True)[:5]
    top_str = ", ".join(f"{spell(p)}({tally[p]}x)" for p in leaders)
    lowest, highest = min(tally), max(tally)

    # Rhythmic: span from tick 0 to the last note end.
    note_count = len(events)
    last_end = max(e.start_tick + e.duration_ticks for e in events)
    span_beats = last_end / max(data.ticks_per_beat, 1)
    density = note_count / max(span_beats, 0.001)

    # Dynamic: mean velocity mapped to the first band whose upper bound exceeds it.
    velocities = [e.velocity for e in events]
    mean_vel = sum(velocities) / len(velocities)
    character = "ff (very loud)"
    for upper, label in (
        (40, "pp (very soft)"),
        (65, "p (soft)"),
        (85, "mp/mf (medium)"),
        (105, "f (loud)"),
    ):
        if mean_vel < upper:
            character = label
            break

    if data.tracks:
        track_summary = ", ".join(data.tracks)
    else:
        track_summary = "(none)"

    report = [
        f" Format: {data.format}",
        f" Tempo: {data.tempo_bpm:.1f} BPM",
        f" Tracks: {track_summary}",
        "",
        " ── Harmonic ──────────────────────────────────",
        f" Pitch range: {spell(lowest)}–{spell(highest)}",
        f" Top pitches: {top_str}",
        "",
        " ── Rhythmic ──────────────────────────────────",
        f" Notes: {note_count}",
        f" Span: {span_beats:.1f} beats",
        f" Density: {density:.1f} notes/beat",
        "",
        " ── Dynamic ───────────────────────────────────",
        f" Velocity: avg={mean_vel:.0f}, min={min(velocities)}, max={max(velocities)}",
        f" Character: {character}",
    ]
    return "\n".join(report)
384
385
386 def _unique_ordered(items: list[str]) -> list[str]:
387 seen: set[str] = set()
388 result: list[str] = []
389 for item in items:
390 if item not in seen:
391 seen.add(item)
392 result.append(item)
393 return result