swing.py
python
| 1 | """muse swing — analyze or annotate the swing factor of a composition. |
| 2 | |
| 3 | Swing factor encodes the rhythmic feel of a MIDI performance on a |
| 4 | normalized 0.5–0.67 scale, where 0.5 is mathematically straight (no |
| 5 | swing) and 0.67 approximates a full triplet feel. Human-readable |
| 6 | labels map ranges to familiar production vocabulary: |
| 7 | |
| 8 | Straight factor < 0.53 |
| 9 | Light 0.53 ≤ factor < 0.58 |
| 10 | Medium 0.58 ≤ factor < 0.63 |
| 11 | Hard factor ≥ 0.63 |
| 12 | |
| 13 | Command forms |
| 14 | ------------- |
| 15 | |
| 16 | Detect swing on HEAD (default):: |
| 17 | |
| 18 | muse swing |
| 19 | |
| 20 | Detect swing at a specific commit:: |
| 21 | |
| 22 | muse swing a1b2c3d4 |
| 23 | |
| 24 | Annotate the current working tree with an explicit factor:: |
| 25 | |
| 26 | muse swing --set 0.6 |
| 27 | |
| 28 | Restrict analysis to a specific MIDI track:: |
| 29 | |
| 30 | muse swing --track bass |
| 31 | |
| 32 | Compare swing between two commits:: |
| 33 | |
| 34 | muse swing --compare HEAD~1 |
| 35 | |
| 36 | Show full swing history:: |
| 37 | |
| 38 | muse swing --history |
| 39 | |
| 40 | Machine-readable JSON output:: |
| 41 | |
| 42 | muse swing --json |
| 43 | """ |
| 44 | from __future__ import annotations |
| 45 | |
| 46 | import asyncio |
| 47 | import json |
| 48 | import logging |
| 49 | import pathlib |
| 50 | from typing import Optional |
| 51 | |
| 52 | import typer |
| 53 | from sqlalchemy.ext.asyncio import AsyncSession |
| 54 | from typing_extensions import Annotated, TypedDict |
| 55 | |
| 56 | from maestro.muse_cli._repo import require_repo |
| 57 | from maestro.muse_cli.db import open_session |
| 58 | from maestro.muse_cli.errors import ExitCode |
| 59 | |
# Module-level logger; ``swing`` uses it to record unexpected failures
# (with traceback) before exiting with an internal-error code.
logger = logging.getLogger(__name__)

# Typer sub-application for ``muse swing``. The ``swing`` function is
# registered on it via ``@app.callback(invoke_without_command=True)`` so it
# runs even when no subcommand is given.
app = typer.Typer()
| 63 | |
# ---------------------------------------------------------------------------
# Swing factor label thresholds
# ---------------------------------------------------------------------------

# Exclusive upper bounds of the "Straight", "Light" and "Medium" bands used
# by ``swing_label``; any factor at or above MEDIUM_MAX is labelled "Hard".
STRAIGHT_MAX, LIGHT_MAX, MEDIUM_MAX = 0.53, 0.58, 0.63

# Inclusive range accepted by ``muse swing --set``: 0.5 is mathematically
# straight, 0.67 approximates a full triplet feel.
FACTOR_MIN, FACTOR_MAX = 0.5, 0.67
| 74 | |
| 75 | |
| 76 | # --------------------------------------------------------------------------- |
| 77 | # Named result types (stable CLI contract) |
| 78 | # --------------------------------------------------------------------------- |
| 79 | |
| 80 | |
class SwingDetectResult(TypedDict):
    """Swing detection result for a single commit or working tree.

    Part of the stable CLI contract: ``--json`` output serialises this
    mapping directly via ``json.dumps(dict(result))``.
    """

    # Normalized swing factor (``--set`` only accepts FACTOR_MIN..FACTOR_MAX).
    factor: float
    # Band name produced by ``swing_label`` ("Straight"/"Light"/"Medium"/"Hard").
    label: str
    # Commit identifier: the abbreviated HEAD SHA or the user-supplied value;
    # empty string for ``--set`` annotations.
    commit: str
    # Branch name (last path component of the ref in ``.muse/HEAD``); empty
    # string for ``--set`` annotations.
    branch: str
    # Named MIDI track, or "all" when no ``--track`` filter was given.
    track: str
    # Provenance: "stub" for placeholder detection, "annotation" for ``--set``.
    source: str
| 90 | |
| 91 | |
class SwingCompareResult(TypedDict):
    """Swing comparison between HEAD and a reference commit."""

    # Detection result for HEAD.
    head: SwingDetectResult
    # Detection result for the commit passed to ``--compare``.
    compare: SwingDetectResult
    # ``head["factor"] - compare["factor"]`` rounded to 4 decimal places;
    # positive means HEAD has the higher (more swung) factor.
    delta: float
| 98 | |
| 99 | |
def swing_label(factor: float) -> str:
    """Map a swing factor onto its production-vocabulary label.

    Bands are closed below and open above: under ``STRAIGHT_MAX`` is
    "Straight", under ``LIGHT_MAX`` is "Light", under ``MEDIUM_MAX`` is
    "Medium", and everything else is "Hard". Keeping the thresholds in one
    place is meant to keep stored annotations labelled consistently
    regardless of which CLI version wrote them.

    Args:
        factor: Normalized swing factor, nominally in [0.5, 0.67].

    Returns:
        One of ``"Straight"``, ``"Light"``, ``"Medium"``, or ``"Hard"``.
    """
    bands = (
        (STRAIGHT_MAX, "Straight"),
        (LIGHT_MAX, "Light"),
        (MEDIUM_MAX, "Medium"),
    )
    for upper_bound, name in bands:
        if factor < upper_bound:
            return name
    # At or above MEDIUM_MAX (or a value no band accepts, e.g. NaN).
    return "Hard"
| 120 | |
| 121 | |
| 122 | # --------------------------------------------------------------------------- |
| 123 | # Testable async core |
| 124 | # --------------------------------------------------------------------------- |
| 125 | |
| 126 | |
async def _swing_detect_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    commit: Optional[str],
    track: Optional[str],
) -> SwingDetectResult:
    """Detect the swing factor for a commit (or the working tree).

    Placeholder implementation: it resolves commit/branch metadata from the
    repository's ``.muse`` directory but reports a fixed factor. Real
    MIDI-based analysis is pending a swing detection route on the Storpheus
    inference endpoint.

    Args:
        root: Repository root containing the ``.muse`` directory.
        session: Open async DB session (not used by the stub yet).
        commit: Commit to report on; ``None`` falls back to the first
            8 characters of the SHA that HEAD's ref points to.
        track: Named MIDI track to restrict analysis to, or ``None`` for all.

    Returns:
        A :class:`SwingDetectResult` whose ``source`` is ``"stub"``.
    """
    dot_muse = root / ".muse"
    ref_name = (dot_muse / "HEAD").read_text().strip()
    # e.g. "refs/heads/main" -> "main"; a value without "/" is used verbatim.
    branch_name = ref_name.rpartition("/")[2]

    ref_file = dot_muse / ref_name
    if ref_file.exists():
        tip_sha = ref_file.read_text().strip()
    else:
        # Ref file absent (presumably no commit on this branch yet).
        tip_sha = "0000000"

    # Placeholder factor until the Storpheus swing route exists.
    factor = 0.55
    return SwingDetectResult(
        factor=factor,
        label=swing_label(factor),
        commit=commit or tip_sha[:8],
        branch=branch_name,
        track=track or "all",
        source="stub",
    )
| 169 | |
| 170 | |
async def _swing_history_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    track: Optional[str],
) -> list[SwingDetectResult]:
    """Return the swing history for the current branch, newest first.

    Stub: the history currently contains exactly one entry, the detection
    result for HEAD. A full implementation is expected to walk the commit
    chain and collect per-commit swing annotations.

    Args:
        root: Repository root.
        session: Open async DB session.
        track: Named MIDI track to restrict to, or ``None`` for all.

    Returns:
        List of :class:`SwingDetectResult` entries.
    """
    return [
        await _swing_detect_async(
            root=root, session=session, commit=None, track=track
        )
    ]
| 195 | |
| 196 | |
async def _swing_compare_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    compare_commit: str,
    track: Optional[str],
) -> SwingCompareResult:
    """Compare swing between HEAD and *compare_commit*.

    Stub implementation: both sides come from ``_swing_detect_async``, which
    currently reports a placeholder factor. A full implementation is expected
    to load the stored annotation (or detect on the fly) for each commit.

    Args:
        root: Repository root.
        session: Open async DB session.
        compare_commit: SHA or ref to compare against.
        track: Named MIDI track to restrict to, or ``None``.

    Returns:
        A :class:`SwingCompareResult` with ``head``, ``compare`` and
        ``delta`` (``head - compare`` rounded to 4 decimals, never ``-0.0``).
    """
    head_result = await _swing_detect_async(
        root=root, session=session, commit=None, track=track
    )
    compare_result = await _swing_detect_async(
        root=root, session=session, commit=compare_commit, track=track
    )
    delta = round(head_result["factor"] - compare_result["factor"], 4)
    # round() keeps the sign of tiny negative differences (round(-4e-05, 4)
    # is -0.0), which would serialise as "-0.0" in --json output. Adding 0.0
    # maps -0.0 to 0.0 and leaves every other value unchanged.
    delta += 0.0
    return SwingCompareResult(
        head=head_result,
        compare=compare_result,
        delta=delta,
    )
| 231 | |
| 232 | |
| 233 | # --------------------------------------------------------------------------- |
| 234 | # Output formatters |
| 235 | # --------------------------------------------------------------------------- |
| 236 | |
| 237 | |
| 238 | def _format_detect(result: SwingDetectResult, *, as_json: bool) -> str: |
| 239 | """Render a detect result as human-readable text or JSON.""" |
| 240 | if as_json: |
| 241 | return json.dumps(dict(result), indent=2) |
| 242 | lines = [ |
| 243 | f"Swing factor: {result['factor']} ({result['label']})", |
| 244 | f"Commit: {result['commit']} Branch: {result['branch']}", |
| 245 | f"Track: {result['track']}", |
| 246 | ] |
| 247 | if result.get("source") == "stub": |
| 248 | lines.append("(stub — full MIDI analysis pending)") |
| 249 | return "\n".join(lines) |
| 250 | |
| 251 | |
| 252 | def _format_history( |
| 253 | entries: list[SwingDetectResult], *, as_json: bool |
| 254 | ) -> str: |
| 255 | """Render a history list as human-readable text or JSON.""" |
| 256 | if as_json: |
| 257 | return json.dumps([dict(e) for e in entries], indent=2) |
| 258 | lines: list[str] = [] |
| 259 | for entry in entries: |
| 260 | lines.append( |
| 261 | f"{entry['commit']} {entry['factor']} ({entry['label']})" |
| 262 | + (f" [{entry['track']}]" if entry.get("track") != "all" else "") |
| 263 | ) |
| 264 | return "\n".join(lines) if lines else "(no swing history found)" |
| 265 | |
| 266 | |
| 267 | def _format_compare(result: SwingCompareResult, *, as_json: bool) -> str: |
| 268 | """Render a compare result as human-readable text or JSON.""" |
| 269 | if as_json: |
| 270 | return json.dumps(dict(result), indent=2) |
| 271 | head = result["head"] |
| 272 | compare = result["compare"] |
| 273 | delta = result["delta"] |
| 274 | sign = "+" if delta >= 0 else "" |
| 275 | return ( |
| 276 | f"HEAD {head['factor']} ({head['label']})\n" |
| 277 | f"Compare {compare['factor']} ({compare['label']})\n" |
| 278 | f"Delta {sign}{delta}" |
| 279 | ) |
| 280 | |
| 281 | |
| 282 | # --------------------------------------------------------------------------- |
| 283 | # Typer command |
| 284 | # --------------------------------------------------------------------------- |
| 285 | |
| 286 | |
@app.callback(invoke_without_command=True)
def swing(
    ctx: typer.Context,
    commit: Annotated[
        Optional[str],
        typer.Argument(
            help="Commit SHA to analyse. Defaults to the working tree.",
            show_default=False,
        ),
    ] = None,
    set_factor: Annotated[
        Optional[float],
        typer.Option(
            "--set",
            help=(
                "Annotate the working tree with an explicit swing factor "
                f"({FACTOR_MIN}=straight, {FACTOR_MAX}=triplet)."
            ),
            show_default=False,
        ),
    ] = None,
    detect: Annotated[
        bool,
        typer.Option(
            "--detect",
            help="Detect and display the swing factor (default when no other flag given).",
        ),
    ] = True,
    track: Annotated[
        Optional[str],
        typer.Option(
            "--track",
            help="Restrict analysis to a named MIDI track (e.g. 'bass', 'drums').",
            show_default=False,
        ),
    ] = None,
    compare: Annotated[
        Optional[str],
        typer.Option(
            "--compare",
            metavar="COMMIT",
            help="Compare HEAD swing against another commit.",
            show_default=False,
        ),
    ] = None,
    history: Annotated[
        bool,
        typer.Option("--history", help="Display full swing history for the branch."),
    ] = False,
    as_json: Annotated[
        bool,
        typer.Option("--json", help="Emit machine-readable JSON output."),
    ] = False,
) -> None:
    """Analyze or annotate the swing factor of a musical composition.

    With no flags, detects and displays the swing factor for the current
    HEAD commit. Use ``--set`` to persist an explicit factor annotation.
    """
    # NOTE: the docstring above doubles as the command's --help text, so it
    # is kept verbatim. ``ctx`` and ``detect`` are accepted but unused:
    # detection is simply the fallback mode below.
    root = require_repo()

    # Reject an out-of-range --set before opening a DB session. NaN also
    # fails the chained comparison and is rejected here.
    if set_factor is not None and not (FACTOR_MIN <= set_factor <= FACTOR_MAX):
        typer.echo(
            f"❌ --set value {set_factor!r} out of range "
            f"[{FACTOR_MIN}, {FACTOR_MAX}]"
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    async def _dispatch() -> None:
        # The first matching mode wins: --set, --history, --compare, detect.
        async with open_session() as session:
            if set_factor is not None:
                # NOTE(review): this branch only echoes the annotation; nothing
                # is written to the session despite the "persist" wording.
                band = swing_label(set_factor)
                if as_json:
                    record = SwingDetectResult(
                        factor=set_factor,
                        label=band,
                        commit="",
                        branch="",
                        track=track or "all",
                        source="annotation",
                    )
                    rendered = json.dumps(dict(record), indent=2)
                else:
                    track_note = f" track={track}" if track else ""
                    rendered = f"✅ Swing annotated: {set_factor} ({band})" + track_note
            elif history:
                timeline = await _swing_history_async(
                    root=root, session=session, track=track
                )
                rendered = _format_history(timeline, as_json=as_json)
            elif compare is not None:
                comparison = await _swing_compare_async(
                    root=root,
                    session=session,
                    compare_commit=compare,
                    track=track,
                )
                rendered = _format_compare(comparison, as_json=as_json)
            else:
                detected = await _swing_detect_async(
                    root=root, session=session, commit=commit, track=track
                )
                rendered = _format_detect(detected, as_json=as_json)
            typer.echo(rendered)

    try:
        asyncio.run(_dispatch())
    except typer.Exit:
        # Deliberate exits propagate unchanged.
        raise
    except Exception as exc:
        typer.echo(f"muse swing failed: {exc}")
        # Equivalent to logger.error(..., exc_info=True).
        logger.exception("❌ muse swing error: %s", exc)
        raise typer.Exit(code=ExitCode.INTERNAL_ERROR)