recall.py
python
| 1 | """muse recall — keyword search over musical commit history. |
| 2 | |
| 3 | Accepts a natural-language query string and returns the top-N commits from |
| 4 | history ranked by keyword overlap against commit messages. |
| 5 | |
| 6 | Usage:: |
| 7 | |
| 8 | muse recall "dark jazz bassline" |
| 9 | muse recall "drum fill" --limit 3 --threshold 0.5 --branch main |
| 10 | muse recall "piano" --since 2026-01-01 --until 2026-02-01 --json |
| 11 | |
| 12 | Scoring algorithm (stub — vector search planned): |
| 13 | Each commit message is tokenized (lowercase, split on whitespace/punctuation). |
| 14 | The score is the fraction of query tokens that also appear among the |
| 15 | message tokens (query coverage): |
| 16 | |
| 17 | score = |query_tokens ∩ message_tokens| / |query_tokens| |
| 18 | |
| 19 | This gives 1.0 when every query word appears in the message, and 0.0 when |
| 20 | none do. Commits with score < ``--threshold`` are excluded. |
| 21 | |
| 22 | Note: |
| 23 | Full vector embedding search via Qdrant is a planned enhancement (see muse |
| 24 | context / issue backlog). When implemented, the scoring function will be |
| 25 | replaced by cosine similarity over pre-computed embeddings, with no change |
| 26 | to the CLI interface. |
| 27 | """ |
| 28 | from __future__ import annotations |
| 29 | |
| 30 | import asyncio |
| 31 | import json |
| 32 | import logging |
| 33 | import pathlib |
| 34 | import re |
| 35 | from datetime import datetime, timezone |
| 36 | from typing import Annotated, TypedDict |
| 37 | |
| 38 | import typer |
| 39 | from sqlalchemy.ext.asyncio import AsyncSession |
| 40 | from sqlalchemy.future import select |
| 41 | |
| 42 | from maestro.muse_cli._repo import require_repo |
| 43 | from maestro.muse_cli.db import open_session |
| 44 | from maestro.muse_cli.errors import ExitCode |
| 45 | from maestro.muse_cli.models import MuseCliCommit |
| 46 | |
logger = logging.getLogger(__name__)

app = typer.Typer()

# Word tokens used for keyword scoring: maximal runs of Unicode letters and
# digits.  ``[^\W_]`` is ``\w`` minus the underscore, so "drum_fill" still
# splits into "drum" and "fill" (as the former ASCII-only ``[a-zA-Z0-9]``
# class did), while non-ASCII words such as "café" or "Flügel" stay whole
# instead of being cut apart at the accented letter.
_TOKEN_RE = re.compile(r"[^\W_]+")
| 52 | |
| 53 | |
class RecallResult(TypedDict):
    """One ranked hit produced by ``muse recall``.

    This is also the shape of each object in the ``--json`` output list.
    """

    rank: int  # 1-based position in the ranked list
    score: float  # keyword score, rounded to 4 decimal places
    commit_id: str
    date: str  # committed_at rendered as "%Y-%m-%d %H:%M:%S"
    branch: str
    message: str  # the full commit message
| 63 | |
| 64 | |
| 65 | # --------------------------------------------------------------------------- |
| 66 | # Scoring helpers |
| 67 | # --------------------------------------------------------------------------- |
| 68 | |
| 69 | |
def _tokenize(text: str) -> set[str]:
    """Return the distinct lowercased tokens in *text* that ``_TOKEN_RE`` matches."""
    return set(map(str.lower, _TOKEN_RE.findall(text)))
| 73 | |
| 74 | |
def _score(query_tokens: set[str], message: str) -> float:
    """Score *message* against the pre-tokenized query, in the range [0, 1].

    The score is the fraction of query tokens that also occur in the message,
    ``|Q ∩ M| / |Q|`` (query coverage).  Only the query size appears in the
    denominator, so a short, precise query is not penalised for matching a
    long, wordy commit message.

    An empty *query_tokens* set scores 0.0 instead of dividing by zero.
    """
    if query_tokens:
        shared = query_tokens.intersection(_tokenize(message))
        return len(shared) / len(query_tokens)
    return 0.0
| 87 | |
| 88 | |
| 89 | # --------------------------------------------------------------------------- |
| 90 | # Testable async core |
| 91 | # --------------------------------------------------------------------------- |
| 92 | |
| 93 | |
async def _fetch_commits(
    session: AsyncSession,
    *,
    repo_id: str,
    branch: str | None,
    since: datetime | None,
    until: datetime | None,
) -> list[MuseCliCommit]:
    """Load the candidate commits of *repo_id*, newest first.

    Every filter that is not ``None`` becomes part of the SQL ``WHERE`` clause
    (all ANDed together), so the database discards non-matching rows:

    * *branch* -- exact match on ``MuseCliCommit.branch``;
    * *since* / *until* -- inclusive bounds on ``MuseCliCommit.committed_at``.

    Ranking and truncation to the result limit are left to the caller.
    """
    criteria = [MuseCliCommit.repo_id == repo_id]
    if branch is not None:
        criteria.append(MuseCliCommit.branch == branch)
    if since is not None:
        criteria.append(MuseCliCommit.committed_at >= since)
    if until is not None:
        criteria.append(MuseCliCommit.committed_at <= until)

    query = (
        select(MuseCliCommit)
        .where(*criteria)
        .order_by(MuseCliCommit.committed_at.desc())
    )
    rows = await session.execute(query)
    return list(rows.scalars().all())
| 120 | |
| 121 | |
def _current_branch(muse_dir: pathlib.Path) -> str:
    """Return the branch name recorded in ``<muse_dir>/HEAD``.

    HEAD is read as text and stripped.  When it has the form
    ``refs/heads/<name>``, everything after the prefix is returned, so branch
    names that themselves contain ``/`` (e.g. ``feature/drums``) are kept
    intact instead of being truncated to their last segment.  Any other
    content falls back to its last ``/``-separated segment (the previous
    behaviour).

    NOTE(review): assumes HEAD stores a ``refs/heads/...`` ref path, as the
    former ``rsplit`` parsing implied -- confirm against the code writing HEAD.
    """
    head_ref = (muse_dir / "HEAD").read_text().strip()
    prefix = "refs/heads/"
    if head_ref.startswith(prefix):
        return head_ref[len(prefix):]
    return head_ref.rsplit("/", 1)[-1]


async def _recall_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
    query: str,
    limit: int,
    threshold: float,
    branch: str | None,
    since: datetime | None,
    until: datetime | None,
    as_json: bool,
) -> list[RecallResult]:
    """Core recall logic — fully injectable for tests.

    Steps:

    1. Read ``repo_id`` from ``<root>/.muse/repo.json``.
    2. Fetch candidate commits of that repo, restricted to *branch* -- or, when
       *branch* is ``None``, to the branch named by ``.muse/HEAD`` -- and to
       the optional inclusive ``[since, until]`` window.
    3. Score each commit message against *query*; drop commits whose score is
       below *threshold*.
    4. Keep the best *limit* entries: highest score first, most recent commit
       first among equal scores.

    The results are echoed to stdout (indented JSON when *as_json* is true,
    otherwise the human-readable listing) and also returned.

    Returns:
        The ranked result entries, with ``rank`` starting at 1.
    """
    muse_dir = root / ".muse"
    repo_data: dict[str, str] = json.loads((muse_dir / "repo.json").read_text())
    repo_id = repo_data["repo_id"]

    # When no branch filter is given, default to the current branch from HEAD.
    effective_branch: str = branch if branch is not None else _current_branch(muse_dir)

    commits = await _fetch_commits(
        session,
        repo_id=repo_id,
        branch=effective_branch,
        since=since,
        until=until,
    )

    query_tokens = _tokenize(query)
    scored: list[tuple[float, MuseCliCommit]] = []
    for commit in commits:
        score = _score(query_tokens, commit.message)
        if score >= threshold:
            scored.append((score, commit))

    # Sort by score descending, then by recency (committed_at desc) for ties.
    scored.sort(key=lambda x: (x[0], x[1].committed_at.timestamp()), reverse=True)
    top = scored[:limit]

    results: list[RecallResult] = [
        RecallResult(
            rank=i + 1,
            score=round(score, 4),
            commit_id=commit.commit_id,
            date=commit.committed_at.strftime("%Y-%m-%d %H:%M:%S"),
            branch=commit.branch,
            message=commit.message,
        )
        for i, (score, commit) in enumerate(top)
    ]

    if as_json:
        typer.echo(json.dumps(results, indent=2))
    else:
        _render_results(query=query, results=results, threshold=threshold)

    return results
| 185 | |
| 186 | |
def _render_results(
    *,
    query: str,
    results: list[RecallResult],
    threshold: float,
) -> None:
    """Write the ranked *results* to stdout as a human-readable listing.

    Output layout:

    * a banner line with the quoted query, then a line stating that scores are
      keyword-overlap estimates at the given *threshold* (vector search via
      Qdrant is a planned enhancement), then a blank line;
    * for each entry, a rank/score/commit/date line, the commit message, and a
      blank separator line;
    * if *results* is empty, a single "No matching commits found." line
      replaces the entries.
    """
    banner = (
        f'Recall: "{query}"',
        f"(keyword match · threshold {threshold:.2f} · "
        "vector search is a planned enhancement)",
        "",
    )
    for line in banner:
        typer.echo(line)

    if not results:
        typer.echo(" No matching commits found.")
        return

    for entry in results:
        summary = (
            f" #{entry['rank']} score={entry['score']:.4f} "
            f"commit {entry['commit_id']} [{entry['date']}]"
        )
        for line in (summary, f" {entry['message']}", ""):
            typer.echo(line)
| 214 | |
| 215 | |
| 216 | # --------------------------------------------------------------------------- |
| 217 | # Typer command |
| 218 | # --------------------------------------------------------------------------- |
| 219 | |
| 220 | |
def _parse_date_option(
    value: str,
    *,
    option: str,
    end_of_day: bool = False,
) -> datetime:
    """Parse a ``YYYY-MM-DD`` CLI date into a UTC-aware datetime.

    With *end_of_day* the result is the last representable instant of that
    day (23:59:59.999999), so an inclusive ``<=`` upper bound covers the whole
    date rather than stopping at its first midnight.

    Raises:
        typer.Exit: with ``ExitCode.USER_ERROR`` (after printing an error
            naming *option*) when *value* is not a valid ``YYYY-MM-DD`` date.
    """
    try:
        parsed = datetime.strptime(value, "%Y-%m-%d")
    except ValueError:
        typer.echo(f"❌ {option}: invalid date format '{value}' — expected YYYY-MM-DD")
        raise typer.Exit(code=ExitCode.USER_ERROR) from None
    if end_of_day:
        parsed = parsed.replace(hour=23, minute=59, second=59, microsecond=999999)
    return parsed.replace(tzinfo=timezone.utc)


@app.callback(invoke_without_command=True)
def recall(
    ctx: typer.Context,
    query: Annotated[str, typer.Argument(help="Natural-language description to search for.")],
    limit: int = typer.Option(
        5,
        "--limit",
        "-n",
        help="Maximum number of results to return.",
        min=1,
    ),
    threshold: float = typer.Option(
        0.6,
        "--threshold",
        help="Minimum similarity score (0–1) to include a commit.",
        min=0.0,
        max=1.0,
    ),
    branch: str | None = typer.Option(
        None,
        "--branch",
        help="Filter by branch name (defaults to current branch).",
    ),
    since: str | None = typer.Option(
        None,
        "--since",
        help="Only include commits on or after this date (YYYY-MM-DD).",
    ),
    until: str | None = typer.Option(
        None,
        "--until",
        help="Only include commits on or before this date (YYYY-MM-DD).",
    ),
    as_json: bool = typer.Option(
        False,
        "--json",
        help="Output results as JSON.",
    ),
) -> None:
    """Search commit history by description (keyword match over messages).

    Returns the top ``--limit`` commits whose messages best match the query,
    sorted by keyword overlap score. Commits below ``--threshold`` are
    excluded.

    Note:
        Full semantic vector search via Qdrant is a planned enhancement
        (see muse context). Until then, scoring is based on keyword overlap
        between the query and commit messages.
    """
    # Validate date args first — fail fast before touching the filesystem.
    # Both dates are interpreted as UTC.  ``--until`` is stretched to the end
    # of its day: the commit filter compares ``committed_at <= until``, so a
    # bare midnight bound would drop every commit made later on that date,
    # contradicting the "on or before this date" help text.
    since_dt: datetime | None = (
        _parse_date_option(since, option="--since") if since else None
    )
    until_dt: datetime | None = (
        _parse_date_option(until, option="--until", end_of_day=True) if until else None
    )

    root = require_repo()

    async def _run() -> None:
        async with open_session() as session:
            await _recall_async(
                root=root,
                session=session,
                query=query,
                limit=limit,
                threshold=threshold,
                branch=branch,
                since=since_dt,
                until=until_dt,
                as_json=as_json,
            )

    try:
        asyncio.run(_run())
    except typer.Exit:
        raise
    except Exception as exc:
        # Top-level CLI boundary: report, log the traceback, exit non-zero.
        typer.echo(f"❌ muse recall failed: {exc}")
        logger.error("❌ muse recall error: %s", exc, exc_info=True)
        raise typer.Exit(code=ExitCode.INTERNAL_ERROR)