groove_check.py
python
| 1 | """muse groove-check — analyze rhythmic drift across commits. |
| 2 | |
| 3 | Detects which commit in a range "killed the groove" by measuring how much |
| 4 | the average note-onset deviation from the quantization grid shifted between |
| 5 | adjacent commits. A large drift_delta signals a quantize operation that was |
| 6 | too aggressive, a tempo map change that made the pocket feel stiff, or any |
| 7 | edit that disrupted rhythmic consistency. |
| 8 | |
| 9 | Output (default tabular):: |
| 10 | |
| 11 | Groove-check — range HEAD~6..HEAD threshold 0.1 beats |
| 12 | |
| 13 | Commit Groove Score Drift Δ Status |
| 14 | -------- ------------ ------- ------ |
| 15 | a1b2c3d4 0.0400 0.0000 OK |
| 16 | e5f6a7b8 0.0500 0.0100 OK |
| 17 | c9d0e1f2 0.0600 0.0100 OK |
| 18 | a3b4c5d6 0.0900 0.0300 OK |
| 19 | e7f8a9b0 0.1500 0.0600 WARN |
| 20 | c1d2e3f4 0.1300 0.0200 OK |
| 21 | |
| 22 | Flagged: 1 / 6 commits (worst: e7f8a9b0) |
| 23 | |
| 24 | Flags |
| 25 | ----- |
| 26 | [range] Commit range to analyze (e.g. HEAD~5..HEAD). Default: last 10. |
| 27 | --track TEXT Scope analysis to a specific instrument track. |
| 28 | --section TEXT Scope analysis to a specific musical section. |
| 29 | --threshold FLOAT Drift threshold in beats (default 0.1). Commits that exceed |
| 30 | this value are flagged WARN; >2× threshold = FAIL. |
| 31 | --json Emit machine-readable JSON output. |
| 32 | """ |
| 33 | from __future__ import annotations |
| 34 | |
| 35 | import asyncio |
| 36 | import json |
| 37 | import logging |
| 38 | import pathlib |
| 39 | from typing import Optional |
| 40 | |
| 41 | import typer |
| 42 | from sqlalchemy.ext.asyncio import AsyncSession |
| 43 | from typing_extensions import Annotated |
| 44 | |
| 45 | from maestro.muse_cli._repo import require_repo |
| 46 | from maestro.muse_cli.db import open_session |
| 47 | from maestro.muse_cli.errors import ExitCode |
| 48 | from maestro.services.muse_groove_check import ( |
| 49 | DEFAULT_THRESHOLD, |
| 50 | GrooveCheckResult, |
| 51 | GrooveStatus, |
| 52 | compute_groove_check, |
| 53 | ) |
| 54 | |
# Module-level logger named after this module; the command's generic error
# handler logs unexpected failures through it.
logger = logging.getLogger(__name__)

# Typer application object; the ``groove_check`` callback registers itself on
# it via ``@app.callback``.
app = typer.Typer()
| 58 | |
| 59 | # --------------------------------------------------------------------------- |
| 60 | # Column widths |
| 61 | # --------------------------------------------------------------------------- |
| 62 | |
| 63 | _COL_COMMIT = 8 |
| 64 | _COL_SCORE = 12 |
| 65 | _COL_DELTA = 7 |
| 66 | _COL_STATUS = 6 |
| 67 | |
| 68 | # --------------------------------------------------------------------------- |
| 69 | # Rendering |
| 70 | # --------------------------------------------------------------------------- |
| 71 | |
| 72 | |
def _render_table(result: GrooveCheckResult) -> None:
    """Print the groove-check report to stdout as a fixed-width ASCII table.

    The report is, in order: a title line with the commit range and the
    threshold, a blank line, the column header, a dashed separator, one row
    per item of ``result.entries``, a blank line, and a summary line with the
    flagged/total commit counts. The worst commit is appended to the summary
    only when ``result.worst_commit`` is truthy.

    Args:
        result: Completed :class:`GrooveCheckResult` from the analysis.
    """

    def report_lines():
        # Lines are produced lazily so each one is formatted immediately
        # before it is echoed, exactly as a straight-line sequence would.
        yield (
            f"Groove-check — range {result.commit_range}"
            f" threshold {result.threshold} beats"
        )
        yield ""

        yield (
            f"{'Commit':<{_COL_COMMIT}} "
            f"{'Groove Score':>{_COL_SCORE}} "
            f"{'Drift Δ':>{_COL_DELTA}} "
            f"{'Status':<{_COL_STATUS}}"
        )
        # One dash run per column, separated like the header cells.
        yield " ".join(
            "-" * width
            for width in (_COL_COMMIT, _COL_SCORE, _COL_DELTA, _COL_STATUS)
        )

        for row in result.entries:
            # Numeric columns are right-aligned with four decimals; the
            # commit id and status text are left-aligned.
            yield (
                f"{row.commit:<{_COL_COMMIT}} "
                f"{row.groove_score:>{_COL_SCORE}.4f} "
                f"{row.drift_delta:>{_COL_DELTA}.4f} "
                f"{row.status.value:<{_COL_STATUS}}"
            )

        yield ""
        if result.worst_commit:
            worst_suffix = f" (worst: {result.worst_commit})"
        else:
            worst_suffix = ""
        yield (
            f"Flagged: {result.flagged_commits} / {result.total_commits} commits"
            f"{worst_suffix}"
        )

    for line in report_lines():
        typer.echo(line)
| 118 | |
| 119 | |
def _render_json(result: GrooveCheckResult) -> None:
    """Print the groove-check result to stdout as indented JSON.

    The document carries the range-level summary fields followed by an
    ``entries`` list holding one object per item of ``result.entries``.
    Each entry's status is serialized through its ``.value``.

    Args:
        result: Completed :class:`GrooveCheckResult` from the analysis.
    """
    # Summary fields first; "entries" is added last so the key order of the
    # emitted document stays summary-then-entries.
    document = {
        "commit_range": result.commit_range,
        "threshold": result.threshold,
        "total_commits": result.total_commits,
        "flagged_commits": result.flagged_commits,
        "worst_commit": result.worst_commit,
    }

    entry_records = []
    for entry in result.entries:
        entry_records.append(
            {
                "commit": entry.commit,
                "groove_score": entry.groove_score,
                "drift_delta": entry.drift_delta,
                "status": entry.status.value,
                "track": entry.track,
                "section": entry.section,
                "midi_files": entry.midi_files,
            }
        )
    document["entries"] = entry_records

    serialized = json.dumps(document, indent=2)
    typer.echo(serialized)
| 146 | |
| 147 | |
| 148 | # --------------------------------------------------------------------------- |
| 149 | # Async core (injectable for tests) |
| 150 | # --------------------------------------------------------------------------- |
| 151 | |
| 152 | |
async def _groove_check_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    commit_range: Optional[str],
    track: Optional[str],
    section: Optional[str],
    threshold: float,
    as_json: bool,
) -> GrooveCheckResult:
    """Core groove-check logic — fully injectable for unit tests.

    Validates the threshold, reads ``.muse/HEAD`` and the ref file it names
    to find out whether the current branch has any commit, picks the
    effective commit range, delegates to :func:`compute_groove_check`, and
    renders the result as JSON or as an ASCII table.

    Args:
        root: Repository root (directory containing ``.muse/``).
        session: Open async DB session. It is accepted but not used by this
            function's body (reserved for the full implementation).
        commit_range: Explicit range string, or None to fall back to
            ``HEAD~10..HEAD``.
        track: Restrict analysis to a named instrument track.
        section: Restrict analysis to a named musical section.
        threshold: Drift threshold in beats; must be a positive number.
        as_json: Emit JSON instead of the ASCII table.

    Returns:
        The :class:`GrooveCheckResult` produced by the analysis.

    Raises:
        typer.Exit: With ``ExitCode.USER_ERROR`` when ``threshold`` is not a
            positive number (zero, negative, or NaN); with
            ``ExitCode.SUCCESS`` when the HEAD ref resolves to no commit and
            no explicit ``commit_range`` was given.
        OSError: Propagated unchanged if ``.muse/HEAD`` (or an existing ref
            file) cannot be read.
    """
    # ``not threshold > 0`` instead of ``threshold <= 0``: every ordered
    # comparison with NaN is false, so the latter would let NaN through.
    if not threshold > 0:
        typer.echo("❌ --threshold must be a positive number.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    muse_dir = root / ".muse"
    head_ref = (muse_dir / "HEAD").read_text().strip()
    # Last path component of the ref; when the ref holds no "/" this is the
    # ref itself.
    branch = head_ref.rsplit("/", 1)[-1]
    ref_path = muse_dir / head_ref
    # A missing ref file is treated as "no commit yet" rather than an error.
    head_sha = ref_path.read_text().strip() if ref_path.exists() else ""

    # An explicit range is still analysed even when HEAD has no commit.
    if not head_sha and not commit_range:
        typer.echo(f"No commits yet on branch {branch} — nothing to analyse.")
        raise typer.Exit(code=ExitCode.SUCCESS)

    # Number of commits covered when the caller supplies no range.
    default_window = 10
    effective_range = commit_range or f"HEAD~{default_window}..HEAD"

    result = compute_groove_check(
        commit_range=effective_range,
        threshold=threshold,
        track=track,
        section=section,
    )

    if as_json:
        _render_json(result)
    else:
        _render_table(result)

    return result
| 210 | |
| 211 | |
| 212 | # --------------------------------------------------------------------------- |
| 213 | # Typer command |
| 214 | # --------------------------------------------------------------------------- |
| 215 | |
| 216 | |
# NOTE: the docstring of ``groove_check`` doubles as the command's ``--help``
# text in Typer, so implementation notes live in ``#`` comments instead.
@app.callback(invoke_without_command=True)
def groove_check(
    ctx: typer.Context,
    commit_range: Annotated[
        Optional[str],
        typer.Argument(
            help=(
                "Commit range to analyze (e.g. HEAD~5..HEAD). "
                "Defaults to the last 10 commits."
            ),
            show_default=False,
            metavar="RANGE",
        ),
    ] = None,
    track: Annotated[
        Optional[str],
        typer.Option(
            "--track",
            help="Scope analysis to a specific instrument track (e.g. 'drums').",
            show_default=False,
            metavar="TEXT",
        ),
    ] = None,
    section: Annotated[
        Optional[str],
        typer.Option(
            "--section",
            help="Scope analysis to a specific musical section (e.g. 'verse').",
            show_default=False,
            metavar="TEXT",
        ),
    ] = None,
    threshold: Annotated[
        float,
        typer.Option(
            "--threshold",
            help=(
                "Drift threshold in beats (default 0.1). Commits whose "
                "drift_delta exceeds this are flagged WARN; >2× = FAIL."
            ),
        ),
    ] = DEFAULT_THRESHOLD,
    as_json: Annotated[
        bool,
        typer.Option("--json", help="Emit machine-readable JSON output."),
    ] = False,
) -> None:
    """Analyze rhythmic drift across commits to find groove regressions.

    Computes note-onset deviation from the quantization grid for each commit
    in the range, then flags commits where the deviation shifted significantly
    relative to their neighbors. Use this after a session to spot which
    commit made the pocket feel stiff.
    """
    # Runs outside the ``try`` below, so whatever ``require_repo`` raises is
    # not rewritten into the generic internal-error exit.
    root = require_repo()

    async def _run() -> None:
        # The session is opened only for the duration of one analysis and is
        # closed by the context manager on both success and failure.
        async with open_session() as session:
            await _groove_check_async(
                root=root,
                session=session,
                commit_range=commit_range,
                track=track,
                section=section,
                threshold=threshold,
                as_json=as_json,
            )

    try:
        asyncio.run(_run())
    except typer.Exit:
        # Deliberate exits from the core keep their own exit code; without
        # this clause the generic handler below would catch them too.
        raise
    except Exception as exc:
        typer.echo(f"❌ muse groove-check failed: {exc}")
        # ``logger.exception`` logs at ERROR level with the active traceback.
        logger.exception("❌ muse groove-check error: %s", exc)
        # Chain explicitly so the original failure stays attached as the cause.
        raise typer.Exit(code=ExitCode.INTERNAL_ERROR) from exc