cgcardona / muse public
muse_tempo.py python
234 lines 8.1 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Tempo Service — read and annotate tempo (BPM) on Muse CLI commits.
2
3 Provides:
4
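- ``extract_bpm_from_midi`` — pure function: bytes → BPM or None.
  Scans for the MIDI Set Tempo meta-event (FF 51 03 tt tt tt) and
  returns the BPM from the *first* tempo event found, or ``None`` if the
  file contains no tempo events (the MIDI default of 120 BPM is implied
  in that case, but it is deliberately not returned).

- ``detect_all_tempos_from_midi`` — BPM of every Set Tempo event in a
  MIDI file, in order; used for rubato / tempo-change detection.

- ``detect_tempo_from_snapshot`` — BPM of the first tempo event found
  among the MIDI files of a snapshot manifest (files visited in sorted
  path order); ``None`` if no MIDI files or no tempo events found.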
12
13 - ``MuseTempoResult`` — named result type for a single commit tempo query.
14
15 - ``MuseTempoHistoryEntry`` — one row in a ``--history`` traversal.
16
17 - ``build_tempo_history`` — ordered list of history entries, newest-first.
18
19 Boundary rules:
20 - Must NOT import StateStore, EntityRegistry, or app.core.*.
21 - Must NOT import LLM handlers or maestro_* modules.
22 - Pure data — no FastAPI, no DB access, no side effects beyond logging.
23 """
from __future__ import annotations

import logging
import pathlib
from collections.abc import Iterator
from dataclasses import dataclass

from maestro.muse_cli.models import MuseCliCommit
31
logger = logging.getLogger(__name__)

# Byte-level markers used when scanning a Standard MIDI File.
_MIDI_HEADER = b"MThd"  # header chunk ID a MIDI file must start with
_META_EVENT_MARKER = 0xFF  # byte that introduces a meta-event
_META_TEMPO_TYPE = 0x51  # meta-event type byte for Set Tempo
# MIDI's implicit tempo when no Set Tempo event is present:
# 500 000 microseconds per beat, i.e. 120 BPM.
_DEFAULT_MIDI_USPB = 500_000
_MICROSECONDS_PER_MINUTE = 60 * 1_000_000
41
42
43 # ---------------------------------------------------------------------------
44 # Named result types
45 # ---------------------------------------------------------------------------
46
47
@dataclass(frozen=True)
class MuseTempoResult:
    """Outcome of a single ``muse tempo [<commit>]`` lookup.

    Two independent BPM sources are carried side by side:

    * ``tempo_bpm`` — the value annotated by the user via ``--set``.
    * ``detected_bpm`` — the value extracted from MIDI files in the
      commit's snapshot.

    Either one may be ``None`` when that source has nothing to report.
    """

    commit_id: str
    branch: str
    message: str
    #: Explicitly annotated BPM (from ``muse tempo --set``).
    tempo_bpm: float | None
    #: Auto-detected BPM from MIDI tempo map events in the snapshot.
    detected_bpm: float | None

    @property
    def effective_bpm(self) -> float | None:
        """BPM to report: the annotation if present, else the detected value."""
        if self.tempo_bpm is None:
            return self.detected_bpm
        # An explicit annotation always wins, even when it is 0.0.
        return self.tempo_bpm
69
70
@dataclass(frozen=True)
class MuseTempoHistoryEntry:
    """A single row of ``muse tempo --history`` output.

    ``effective_bpm`` is the BPM attributed to the commit (``None`` when
    unavailable). ``delta_bpm`` is the signed BPM change relative to the
    preceding (older) commit, or ``None`` when there is nothing to compare
    against — for example the oldest commit, which has no ancestor.
    """

    #: Identifier of the commit this row describes.
    commit_id: str
    #: The commit's message.
    message: str
    #: BPM attributed to this commit, if any.
    effective_bpm: float | None
    #: Signed change vs. the previous (older) commit's effective BPM.
    delta_bpm: float | None
83
84
85 # ---------------------------------------------------------------------------
86 # Pure MIDI parsing
87 # ---------------------------------------------------------------------------
88
89
def _iter_tempo_uspb(data: bytes) -> Iterator[int]:
    """Yield microseconds-per-beat for each Set Tempo candidate in *data*, in order.

    This is a raw byte scan, not a structural Standard MIDI File parse:
    every offset where ``0xFF 0x51`` appears is treated as a Set Tempo
    meta-event, even if those bytes actually belong to other data. A
    candidate is accepted when its length byte is at least 3, the declared
    payload fits inside *data*, and the decoded value is non-zero. Only the
    first three payload bytes (the big-endian tempo) are decoded.

    Yields nothing when *data* does not start with the ``MThd`` header.
    """
    if data[:4] != _MIDI_HEADER:
        return
    length = len(data)
    # A candidate needs 6 bytes (FF 51 len tt tt tt), so the last possible
    # start offset is length - 6.
    for i in range(length - 5):
        if data[i] != _META_EVENT_MARKER or data[i + 1] != _META_TEMPO_TYPE:
            continue
        meta_len = data[i + 2]
        # Reject too-short or truncated events (canonical length is 3).
        if meta_len < 3 or i + 2 + meta_len >= length:
            continue
        uspb = int.from_bytes(data[i + 3 : i + 6], "big")
        if uspb > 0:
            yield uspb


def extract_bpm_from_midi(data: bytes) -> float | None:
    """Return the BPM from the first Set Tempo meta-event in *data*.

    Scans a raw MIDI byte string for the Set Tempo meta-event
    (``0xFF 0x51 0x03 <3-byte big-endian microseconds/beat>``) and converts
    the first valid one to BPM, rounded to two decimals. Returns ``None``
    when:

    - *data* does not start with the ``MThd`` header.
    - No Set Tempo event is present (implicit 120 BPM per MIDI spec, but
      we return ``None`` rather than assume, so callers can distinguish
      "no event found" from "120 BPM was set explicitly").

    The scan is byte-oriented rather than a structural parse, so in rare
    cases ``0xFF 0x51`` bytes embedded in other data may be read as a
    tempo event.

    Only the *first* tempo event is returned. For rubato detection
    (multiple tempo events) use ``detect_all_tempos_from_midi``.
    """
    raw_uspb = next(_iter_tempo_uspb(data), None)
    if raw_uspb is None:
        return None
    bpm = _MICROSECONDS_PER_MINUTE / raw_uspb
    logger.debug("✅ MIDI tempo event: %d µs/beat → %.2f BPM", raw_uspb, bpm)
    return round(bpm, 2)


def detect_all_tempos_from_midi(data: bytes) -> list[float]:
    """Return BPM for every Set Tempo meta-event in *data*, in order.

    Used for drift (rubato) detection — a file with a single entry has
    a constant tempo; multiple entries indicate rubato or tempo changes.
    Each value is rounded to two decimals. Uses the same byte scan as
    ``extract_bpm_from_midi``, so the two functions always agree on what
    counts as a tempo event.

    Returns an empty list if *data* does not start with the ``MThd``
    header or has no tempo events.
    """
    return [
        round(_MICROSECONDS_PER_MINUTE / raw_uspb, 2)
        for raw_uspb in _iter_tempo_uspb(data)
    ]
147
148
def detect_tempo_from_snapshot(
    manifest: dict[str, str],
    workdir: pathlib.Path,
) -> float | None:
    """Detect tempo from MIDI files listed in a snapshot manifest.

    Visits manifest paths in sorted order (for determinism) and considers
    only those ending in ``.mid`` or ``.midi`` (case-insensitive). Each path
    is resolved against *workdir*. The function returns the BPM of the first
    Set Tempo event found, as reported by ``extract_bpm_from_midi``. Only
    the manifest keys are used; the values are ignored.

    Paths that do not resolve to a regular file are skipped silently. Any
    ``OSError`` while checking or reading a file is logged as a warning and
    that file is skipped, so one unreadable entry never aborts detection.

    Returns ``None`` when no MIDI files are present or none contain a
    Set Tempo meta-event.
    """
    for rel_path in sorted(manifest):
        if not rel_path.lower().endswith((".mid", ".midi")):
            continue
        abs_path = workdir / rel_path
        try:
            # On many Python versions ``Path.is_file`` only swallows
            # "not found"-style errors and re-raises others (e.g.
            # PermissionError on a parent directory), so the check must be
            # inside the try along with the read.
            if not abs_path.is_file():
                continue
            data = abs_path.read_bytes()
        except OSError as exc:
            logger.warning("⚠️ Could not read MIDI file %s: %s", rel_path, exc)
            continue
        bpm = extract_bpm_from_midi(data)
        if bpm is not None:
            logger.debug("✅ Detected %.2f BPM from %s", bpm, rel_path)
            return bpm
    return None
179
180
181 # ---------------------------------------------------------------------------
182 # History helpers
183 # ---------------------------------------------------------------------------
184
185
def build_tempo_history(
    commits: list[MuseCliCommit],
) -> list[MuseTempoHistoryEntry]:
    """Build a tempo history list from a newest-first commit chain.

    Each entry records the commit ID, message, effective BPM, and the
    signed delta vs. the previous (older) commit. ``delta_bpm`` is ``None``
    in two cases:

    - the oldest commit, which has no ancestor;
    - any commit where either it or its predecessor lacks a BPM annotation.

    *commits* must be ordered newest-first (as returned by ``_load_commits``
    in ``commands/log.py``). The delta is newer minus older, rounded to two
    decimals, so a BPM increase shows as a positive delta. The returned
    list is newest-first as well.

    The effective BPM for each commit is the ``tempo_bpm`` value stored in
    ``commit_metadata``, if it is an ``int`` or ``float``. ``bool`` values
    are rejected. Auto-detected values are not stored in the DB, so history
    only reflects explicitly set annotations.
    """
    entries: list[MuseTempoHistoryEntry] = []
    previous_bpm: float | None = None
    # Walk oldest→newest so each commit is compared with its predecessor.
    for commit in reversed(commits):
        meta: dict[str, object] = commit.commit_metadata or {}
        raw = meta.get("tempo_bpm")
        # ``bool`` subclasses ``int``; without this guard a stray True/False
        # in the metadata would be reported as 1.0/0.0 BPM.
        bpm: float | None = (
            float(raw)
            if isinstance(raw, (int, float)) and not isinstance(raw, bool)
            else None
        )
        delta: float | None = (
            round(bpm - previous_bpm, 2)
            if bpm is not None and previous_bpm is not None
            else None
        )
        entries.append(
            MuseTempoHistoryEntry(
                commit_id=commit.commit_id,
                message=commit.message,
                effective_bpm=bpm,
                delta_bpm=delta,
            )
        )
        previous_bpm = bpm

    # Return newest-first (matches log convention).
    entries.reverse()
    return entries