cgcardona / muse public
swing.py python
408 lines 12.0 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """muse swing — analyze or annotate the swing factor of a composition.
2
3 Swing factor encodes the rhythmic feel of a MIDI performance on a
4 normalized 0.5–0.67 scale, where 0.5 is mathematically straight (no
5 swing) and 0.67 approximates a full triplet feel. Human-readable
6 labels map ranges to familiar production vocabulary:
7
8 Straight factor < 0.53
9 Light 0.53 ≤ factor < 0.58
10 Medium 0.58 ≤ factor < 0.63
11 Hard factor ≥ 0.63
12
13 Command forms
14 -------------
15
16 Detect swing on HEAD (default)::
17
18 muse swing
19
20 Detect swing at a specific commit::
21
22 muse swing a1b2c3d4
23
24 Annotate the current working tree with an explicit factor::
25
26 muse swing --set 0.6
27
28 Restrict analysis to a specific MIDI track::
29
30 muse swing --track bass
31
32 Compare swing between two commits::
33
34 muse swing --compare HEAD~1
35
36 Show full swing history::
37
38 muse swing --history
39
40 Machine-readable JSON output::
41
42 muse swing --json
43 """
44 from __future__ import annotations
45
46 import asyncio
47 import json
48 import logging
49 import pathlib
50 from typing import Optional
51
52 import typer
53 from sqlalchemy.ext.asyncio import AsyncSession
54 from typing_extensions import Annotated, TypedDict
55
56 from maestro.muse_cli._repo import require_repo
57 from maestro.muse_cli.db import open_session
58 from maestro.muse_cli.errors import ExitCode
59
# Module logger; the command's error path records full tracebacks here.
logger: logging.Logger = logging.getLogger(__name__)

# Typer application that hosts the ``muse swing`` callback defined below.
app: typer.Typer = typer.Typer()
63
64 # ---------------------------------------------------------------------------
65 # Swing factor label thresholds
66 # ---------------------------------------------------------------------------
67
# Exclusive upper bounds for the human-readable swing labels. A factor at or
# above MEDIUM_MAX is labelled "Hard".
STRAIGHT_MAX: float = 0.53
LIGHT_MAX: float = 0.58
MEDIUM_MAX: float = 0.63

# Inclusive range accepted by ``muse swing --set``: 0.5 is perfectly straight,
# 0.67 approximates a full triplet feel.
FACTOR_MIN: float = 0.5
FACTOR_MAX: float = 0.67
74
75
76 # ---------------------------------------------------------------------------
77 # Named result types (stable CLI contract)
78 # ---------------------------------------------------------------------------
79
80
class SwingDetectResult(TypedDict):
    """Swing detection result for a single commit or working tree.

    ``--json`` output serializes this dict as-is (``json.dumps(dict(...))``),
    so its keys are part of the stable CLI contract.
    """

    factor: float  # normalized swing factor; --set only accepts [FACTOR_MIN, FACTOR_MAX]
    label: str  # value of swing_label(factor)
    commit: str  # analysed commit id/ref; "" for a --set annotation
    branch: str  # branch derived from .muse/HEAD; "" for a --set annotation
    track: str  # MIDI track name, or "all" when no --track was given
    source: str  # origin of the factor; this module writes "stub" or "annotation"
90
91
class SwingCompareResult(TypedDict):
    """Swing comparison between HEAD and a reference commit."""

    head: SwingDetectResult  # detection result for HEAD (commit=None)
    compare: SwingDetectResult  # detection result for the --compare commit
    delta: float  # head factor minus compare factor, rounded to 4 decimals
98
99
def swing_label(factor: float) -> str:
    """Map a normalized swing factor to its production-vocabulary label.

    Each band is closed below and open above. A factor below
    ``STRAIGHT_MAX`` is "Straight", below ``LIGHT_MAX`` is "Light", and
    below ``MEDIUM_MAX`` is "Medium". Anything else is "Hard".

    The thresholds live in module constants that mirror the Muse VCS
    convention, so stored annotations are labelled the same way no matter
    which CLI version wrote them.

    Args:
        factor: Normalized swing factor, nominally in [0.5, 0.67].

    Returns:
        ``"Straight"``, ``"Light"``, ``"Medium"``, or ``"Hard"``.
    """
    bands = (
        (STRAIGHT_MAX, "Straight"),
        (LIGHT_MAX, "Light"),
        (MEDIUM_MAX, "Medium"),
    )
    for ceiling, name in bands:
        if factor < ceiling:
            return name
    return "Hard"
120
121
122 # ---------------------------------------------------------------------------
123 # Testable async core
124 # ---------------------------------------------------------------------------
125
126
def _branch_from_head_ref(head_ref: str) -> str:
    """Extract a display branch name from the contents of ``.muse/HEAD``.

    A ref under ``refs/heads/`` keeps everything after that prefix, so a
    branch such as ``feature/groove`` is not truncated to ``groove``. Any
    other value falls back to its last ``/``-separated component, or to the
    value itself when it contains no ``/``.
    """
    heads_prefix = "refs/heads/"
    if head_ref.startswith(heads_prefix):
        return head_ref[len(heads_prefix):]
    return head_ref.rsplit("/", 1)[-1]


async def _swing_detect_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    commit: Optional[str],
    track: Optional[str],
) -> SwingDetectResult:
    """Detect the swing factor for a commit (or the working tree).

    This is a stub that returns a realistic placeholder result in the
    correct schema. Full MIDI-based analysis will be wired in once
    the Storpheus inference endpoint exposes a swing detection route.
    Branch and commit metadata are still read from the repository, so the
    result identifies what would have been analysed.

    Args:
        root: Repository root (the directory containing ``.muse``).
        session: Open async DB session (not used by the stub yet).
        commit: Commit SHA or ref to analyse, or ``None`` to use the tip
            commit of the ref HEAD points at.
        track: Restrict analysis to a named MIDI track, or ``None`` for all.

    Returns:
        A :class:`SwingDetectResult` with ``factor``, ``label``, ``commit``,
        ``branch``, ``track``, and ``source``. ``commit`` takes the first
        value available from:

        * the caller's value, when given;
        * the first 8 characters of the HEAD SHA;
        * ``"0000000"``, when the ref has no recorded commit yet.

    Raises:
        OSError: If ``.muse/HEAD`` is missing or unreadable.
    """
    muse_dir = root / ".muse"
    head_ref = (muse_dir / "HEAD").read_text(encoding="utf-8").strip()
    branch = _branch_from_head_ref(head_ref)

    # HEAD holds a path relative to .muse. The file it names, if present,
    # holds the tip commit SHA. is_file() rather than exists() avoids trying
    # to read a directory when HEAD is empty or names a ref namespace.
    ref_path = muse_dir / pathlib.Path(head_ref)
    head_sha = (
        ref_path.read_text(encoding="utf-8").strip() if ref_path.is_file() else ""
    )
    # Fall back to a fixed placeholder for a missing *or* empty ref file.
    resolved_commit = commit or head_sha[:8] or "0000000"

    # Stub: placeholder factor — full analysis pending Storpheus route.
    stub_factor = 0.55
    return SwingDetectResult(
        factor=stub_factor,
        label=swing_label(stub_factor),
        commit=resolved_commit,
        branch=branch,
        track=track or "all",
        source="stub",
    )
169
170
async def _swing_history_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    track: Optional[str],
) -> list[SwingDetectResult]:
    """Collect the swing history of the current branch, newest first.

    Placeholder behaviour: the history contains exactly one entry, the
    detection result for HEAD. The eventual implementation is expected to
    walk the commit chain and aggregate the swing annotations stored on
    each commit.

    Args:
        root: Repository root.
        session: Open async DB session.
        track: Named MIDI track to restrict to, or ``None`` for all tracks.

    Returns:
        A list of :class:`SwingDetectResult` entries.
    """
    history: list[SwingDetectResult] = [
        await _swing_detect_async(
            root=root, session=session, commit=None, track=track
        )
    ]
    return history
195
196
async def _swing_compare_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    compare_commit: str,
    track: Optional[str],
) -> SwingCompareResult:
    """Compare the swing of HEAD against *compare_commit*.

    Placeholder behaviour: both sides come from the stub detector, HEAD
    first and then the comparison commit. The eventual implementation is
    expected to load each commit's stored annotation, or detect it on the
    fly.

    Args:
        root: Repository root.
        session: Open async DB session.
        compare_commit: SHA or ref to compare HEAD against.
        track: Named MIDI track to restrict to, or ``None``.

    Returns:
        A :class:`SwingCompareResult` whose ``delta`` is HEAD's factor minus
        the comparison factor, rounded to 4 decimal places.
    """
    head_side = await _swing_detect_async(
        root=root, session=session, commit=None, track=track
    )
    ref_side = await _swing_detect_async(
        root=root, session=session, commit=compare_commit, track=track
    )
    return SwingCompareResult(
        head=head_side,
        compare=ref_side,
        delta=round(head_side["factor"] - ref_side["factor"], 4),
    )
231
232
233 # ---------------------------------------------------------------------------
234 # Output formatters
235 # ---------------------------------------------------------------------------
236
237
def _format_detect(result: SwingDetectResult, *, as_json: bool) -> str:
    """Render one detection result for the terminal.

    JSON mode dumps the result dict with 2-space indentation. Text mode
    prints the factor and label, the commit and branch, and the track. When
    the result came from the stub detector, it appends a trailing note.
    """
    if not as_json:
        text = (
            f"Swing factor: {result['factor']} ({result['label']})\n"
            f"Commit: {result['commit']} Branch: {result['branch']}\n"
            f"Track: {result['track']}"
        )
        if result.get("source") == "stub":
            text += "\n(stub — full MIDI analysis pending)"
        return text
    return json.dumps(dict(result), indent=2)
250
251
def _format_history(
    entries: list[SwingDetectResult], *, as_json: bool
) -> str:
    """Render a swing history for the terminal.

    JSON mode dumps the entries as a list of dicts. Text mode prints one
    ``<commit> <factor> (<label>)`` row per entry. The track is appended in
    brackets unless it is ``"all"``. An empty history renders as a
    placeholder message.
    """
    if as_json:
        return json.dumps([dict(item) for item in entries], indent=2)

    rows: list[str] = []
    for item in entries:
        row = f"{item['commit']} {item['factor']} ({item['label']})"
        if item.get("track") != "all":
            row += f" [{item['track']}]"
        rows.append(row)

    if not rows:
        return "(no swing history found)"
    return "\n".join(rows)
265
266
def _format_compare(result: SwingCompareResult, *, as_json: bool) -> str:
    """Render a compare result as human-readable text or JSON.

    Text mode prints three lines: the HEAD factor, the comparison factor,
    and the delta. A non-negative delta gets an explicit ``+`` sign.

    Args:
        result: The comparison to render.
        as_json: Emit indented JSON instead of text.

    Returns:
        The rendered string.
    """
    if as_json:
        return json.dumps(dict(result), indent=2)
    head = result["head"]
    compare = result["compare"]
    delta = result["delta"]
    # Rounding a tiny negative difference can yield -0.0. Since -0.0 >= 0 is
    # True, it would render as "+-0.0". abs() turns it into 0.0 (and leaves
    # an int 0 as 0).
    if delta == 0:
        delta = abs(delta)
    sign = "+" if delta >= 0 else ""
    return (
        f"HEAD {head['factor']} ({head['label']})\n"
        f"Compare {compare['factor']} ({compare['label']})\n"
        f"Delta {sign}{delta}"
    )
280
281
282 # ---------------------------------------------------------------------------
283 # Typer command
284 # ---------------------------------------------------------------------------
285
286
@app.callback(invoke_without_command=True)
def swing(
    ctx: typer.Context,
    commit: Annotated[
        Optional[str],
        typer.Argument(
            help="Commit SHA to analyse. Defaults to the working tree.",
            show_default=False,
        ),
    ] = None,
    set_factor: Annotated[
        Optional[float],
        typer.Option(
            "--set",
            help=(
                "Annotate the working tree with an explicit swing factor "
                f"({FACTOR_MIN}=straight, {FACTOR_MAX}=triplet)."
            ),
            show_default=False,
        ),
    ] = None,
    detect: Annotated[
        bool,
        typer.Option(
            "--detect",
            help="Detect and display the swing factor (default when no other flag given).",
        ),
    ] = True,
    track: Annotated[
        Optional[str],
        typer.Option(
            "--track",
            help="Restrict analysis to a named MIDI track (e.g. 'bass', 'drums').",
            show_default=False,
        ),
    ] = None,
    compare: Annotated[
        Optional[str],
        typer.Option(
            "--compare",
            metavar="COMMIT",
            help="Compare HEAD swing against another commit.",
            show_default=False,
        ),
    ] = None,
    history: Annotated[
        bool,
        typer.Option("--history", help="Display full swing history for the branch."),
    ] = False,
    as_json: Annotated[
        bool,
        typer.Option("--json", help="Emit machine-readable JSON output."),
    ] = False,
) -> None:
    """Analyze or annotate the swing factor of a musical composition.

    With no flags, detects and displays the swing factor for the current
    HEAD commit. Use ``--set`` to persist an explicit factor annotation.
    """
    # NOTE: Typer renders the docstring above as ``--help`` text, so it is
    # kept verbatim. ``ctx`` and ``detect`` are accepted for CLI
    # compatibility but are not read below. Detection is simply the fallback
    # mode when no other mode flag is given.
    root = require_repo()

    # Reject an out-of-range --set value before any DB session is opened.
    if set_factor is not None and not (FACTOR_MIN <= set_factor <= FACTOR_MAX):
        typer.echo(
            f"❌ --set value {set_factor!r} out of range "
            f"[{FACTOR_MIN}, {FACTOR_MAX}]"
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    async def _render(session: AsyncSession) -> str:
        """Produce the output text for whichever mode the flags select.

        Precedence: ``--set``, then ``--history``, then ``--compare``, then
        plain detection.
        """
        if set_factor is not None:
            label = swing_label(set_factor)
            # Nothing is written to the repository here yet. The annotation
            # is only echoed back in the result schema.
            annotation: SwingDetectResult = SwingDetectResult(
                factor=set_factor,
                label=label,
                commit="",
                branch="",
                track=track or "all",
                source="annotation",
            )
            if as_json:
                return json.dumps(dict(annotation), indent=2)
            track_note = f" track={track}" if track else ""
            return f"✅ Swing annotated: {set_factor} ({label})" + track_note

        if history:
            entries = await _swing_history_async(
                root=root, session=session, track=track
            )
            return _format_history(entries, as_json=as_json)

        if compare is not None:
            comparison = await _swing_compare_async(
                root=root,
                session=session,
                compare_commit=compare,
                track=track,
            )
            return _format_compare(comparison, as_json=as_json)

        detected = await _swing_detect_async(
            root=root, session=session, commit=commit, track=track
        )
        return _format_detect(detected, as_json=as_json)

    async def _run() -> None:
        # Echo while the session is still open, so any failure, including
        # one raised while printing, is handled by the try block below.
        async with open_session() as session:
            typer.echo(await _render(session))

    try:
        asyncio.run(_run())
    except typer.Exit:
        raise
    except Exception as exc:
        typer.echo(f"muse swing failed: {exc}")
        logger.error("❌ muse swing error: %s", exc, exc_info=True)
        raise typer.Exit(code=ExitCode.INTERNAL_ERROR)