cgcardona / muse public
recall.py python
312 lines 9.4 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """muse recall — keyword search over musical commit history.
2
3 Accepts a natural-language query string and returns the top-N commits from
4 history ranked by keyword overlap against commit messages.
5
6 Usage::
7
8 muse recall "dark jazz bassline"
9 muse recall "drum fill" --limit 3 --threshold 0.5 --branch main
10 muse recall "piano" --since 2026-01-01 --until 2026-02-01 --json
11
12 Scoring algorithm (stub — vector search planned):
    Each commit message is tokenized (lowercased, split on non-alphanumeric
    characters). The score is the fraction of query tokens that also appear
    among the message tokens (a one-sided coverage score):
16
17 score = |query_tokens ∩ message_tokens| / |query_tokens|
18
19 This gives 1.0 when every query word appears in the message, and 0.0 when
20 none do. Commits with score < ``--threshold`` are excluded.
21
22 Note:
23 Full vector embedding search via Qdrant is a planned enhancement (see muse
24 context / issue backlog). When implemented, the scoring function will be
25 replaced by cosine similarity over pre-computed embeddings, with no change
26 to the CLI interface.
27 """
28 from __future__ import annotations
29
30 import asyncio
31 import json
32 import logging
33 import pathlib
34 import re
35 from datetime import datetime, timezone
36 from typing import Annotated, TypedDict
37
38 import typer
39 from sqlalchemy.ext.asyncio import AsyncSession
40 from sqlalchemy.future import select
41
42 from maestro.muse_cli._repo import require_repo
43 from maestro.muse_cli.db import open_session
44 from maestro.muse_cli.errors import ExitCode
45 from maestro.muse_cli.models import MuseCliCommit
46
47 logger = logging.getLogger(__name__)
48
49 app = typer.Typer()
50
51 _TOKEN_RE = re.compile(r"[a-zA-Z0-9]+")
52
53
54 class RecallResult(TypedDict):
55 """A single ranked recall result entry."""
56
57 rank: int
58 score: float
59 commit_id: str
60 date: str
61 branch: str
62 message: str
63
64
65 # ---------------------------------------------------------------------------
66 # Scoring helpers
67 # ---------------------------------------------------------------------------
68
69
70 def _tokenize(text: str) -> set[str]:
71 """Return a set of lowercase word tokens from *text*."""
72 return {m.group().lower() for m in _TOKEN_RE.finditer(text)}
73
74
75 def _score(query_tokens: set[str], message: str) -> float:
76 """Return a [0, 1] keyword overlap score.
77
78 Uses the overlap coefficient: |Q ∩ M| / |Q| so that a short, precise
79 query can match a verbose commit message without penalty.
80
81 Returns 0.0 when *query_tokens* is empty (avoids division by zero).
82 """
83 if not query_tokens:
84 return 0.0
85 message_tokens = _tokenize(message)
86 return len(query_tokens & message_tokens) / len(query_tokens)
87
88
89 # ---------------------------------------------------------------------------
90 # Testable async core
91 # ---------------------------------------------------------------------------
92
93
94 async def _fetch_commits(
95 session: AsyncSession,
96 *,
97 repo_id: str,
98 branch: str | None,
99 since: datetime | None,
100 until: datetime | None,
101 ) -> list[MuseCliCommit]:
102 """Fetch all candidate commits from the DB, optionally filtered.
103
104 Filters are applied at the SQL level to minimise in-memory work. The
105 caller ranks and limits the result set.
106 """
107 stmt = select(MuseCliCommit).where(MuseCliCommit.repo_id == repo_id)
108
109 if branch is not None:
110 stmt = stmt.where(MuseCliCommit.branch == branch)
111 if since is not None:
112 stmt = stmt.where(MuseCliCommit.committed_at >= since)
113 if until is not None:
114 stmt = stmt.where(MuseCliCommit.committed_at <= until)
115
116 stmt = stmt.order_by(MuseCliCommit.committed_at.desc())
117
118 result = await session.execute(stmt)
119 return list(result.scalars().all())
120
121
122 async def _recall_async(
123 *,
124 root: pathlib.Path,
125 session: AsyncSession,
126 query: str,
127 limit: int,
128 threshold: float,
129 branch: str | None,
130 since: datetime | None,
131 until: datetime | None,
132 as_json: bool,
133 ) -> list[RecallResult]:
134 """Core recall logic — fully injectable for tests.
135
136 Returns the list of ranked result dicts (also echoed to stdout).
137 """
138 muse_dir = root / ".muse"
139 repo_data: dict[str, str] = json.loads((muse_dir / "repo.json").read_text())
140 repo_id = repo_data["repo_id"]
141
142 # When no branch filter is given, default to current branch from HEAD.
143 effective_branch: str | None = branch
144 if effective_branch is None:
145 head_ref = (muse_dir / "HEAD").read_text().strip()
146 effective_branch = head_ref.rsplit("/", 1)[-1]
147
148 commits = await _fetch_commits(
149 session,
150 repo_id=repo_id,
151 branch=effective_branch,
152 since=since,
153 until=until,
154 )
155
156 query_tokens = _tokenize(query)
157 scored: list[tuple[float, MuseCliCommit]] = []
158 for commit in commits:
159 score = _score(query_tokens, commit.message)
160 if score >= threshold:
161 scored.append((score, commit))
162
163 # Sort by score descending, then by recency (committed_at desc) for ties.
164 scored.sort(key=lambda x: (x[0], x[1].committed_at.timestamp()), reverse=True)
165 top = scored[:limit]
166
167 results: list[RecallResult] = [
168 RecallResult(
169 rank=i + 1,
170 score=round(score, 4),
171 commit_id=commit.commit_id,
172 date=commit.committed_at.strftime("%Y-%m-%d %H:%M:%S"),
173 branch=commit.branch,
174 message=commit.message,
175 )
176 for i, (score, commit) in enumerate(top)
177 ]
178
179 if as_json:
180 typer.echo(json.dumps(results, indent=2))
181 else:
182 _render_results(query=query, results=results, threshold=threshold)
183
184 return results
185
186
187 def _render_results(
188 *,
189 query: str,
190 results: list[RecallResult],
191 threshold: float,
192 ) -> None:
193 """Print ranked recall results in human-readable format.
194
195 Note: similarity scores are keyword-overlap estimates, not semantic
196 embeddings. Vector search via Qdrant is a planned enhancement.
197 """
198 typer.echo(f'Recall: "{query}"')
199 typer.echo(f"(keyword match · threshold {threshold:.2f} · "
200 "vector search is a planned enhancement)")
201 typer.echo("")
202
203 if not results:
204 typer.echo(" No matching commits found.")
205 return
206
207 for entry in results:
208 typer.echo(
209 f" #{entry['rank']} score={entry['score']:.4f} "
210 f"commit {entry['commit_id']} [{entry['date']}]"
211 )
212 typer.echo(f" {entry['message']}")
213 typer.echo("")
214
215
216 # ---------------------------------------------------------------------------
217 # Typer command
218 # ---------------------------------------------------------------------------
219
220
221 @app.callback(invoke_without_command=True)
222 def recall(
223 ctx: typer.Context,
224 query: Annotated[str, typer.Argument(help="Natural-language description to search for.")],
225 limit: int = typer.Option(
226 5,
227 "--limit",
228 "-n",
229 help="Maximum number of results to return.",
230 min=1,
231 ),
232 threshold: float = typer.Option(
233 0.6,
234 "--threshold",
235 help="Minimum similarity score (0–1) to include a commit.",
236 min=0.0,
237 max=1.0,
238 ),
239 branch: str | None = typer.Option(
240 None,
241 "--branch",
242 help="Filter by branch name (defaults to current branch).",
243 ),
244 since: str | None = typer.Option(
245 None,
246 "--since",
247 help="Only include commits on or after this date (YYYY-MM-DD).",
248 ),
249 until: str | None = typer.Option(
250 None,
251 "--until",
252 help="Only include commits on or before this date (YYYY-MM-DD).",
253 ),
254 as_json: bool = typer.Option(
255 False,
256 "--json",
257 help="Output results as JSON.",
258 ),
259 ) -> None:
260 """Search commit history by description (keyword match over messages).
261
262 Returns the top ``--limit`` commits whose messages best match the query,
263 sorted by keyword overlap score. Commits below ``--threshold`` are
264 excluded.
265
266 Note:
267 Full semantic vector search via Qdrant is a planned enhancement
268 (see muse context). Until then, scoring is based on keyword overlap
269 between the query and commit messages.
270 """
271 # Validate date args first — fail fast before touching the filesystem.
272 since_dt: datetime | None = None
273 until_dt: datetime | None = None
274
275 if since:
276 try:
277 since_dt = datetime.strptime(since, "%Y-%m-%d").replace(tzinfo=timezone.utc)
278 except ValueError:
279 typer.echo(f"❌ --since: invalid date format '{since}' — expected YYYY-MM-DD")
280 raise typer.Exit(code=ExitCode.USER_ERROR)
281
282 if until:
283 try:
284 until_dt = datetime.strptime(until, "%Y-%m-%d").replace(tzinfo=timezone.utc)
285 except ValueError:
286 typer.echo(f"❌ --until: invalid date format '{until}' — expected YYYY-MM-DD")
287 raise typer.Exit(code=ExitCode.USER_ERROR)
288
289 root = require_repo()
290
291 async def _run() -> None:
292 async with open_session() as session:
293 await _recall_async(
294 root=root,
295 session=session,
296 query=query,
297 limit=limit,
298 threshold=threshold,
299 branch=branch,
300 since=since_dt,
301 until=until_dt,
302 as_json=as_json,
303 )
304
305 try:
306 asyncio.run(_run())
307 except typer.Exit:
308 raise
309 except Exception as exc:
310 typer.echo(f"❌ muse recall failed: {exc}")
311 logger.error("❌ muse recall error: %s", exc, exc_info=True)
312 raise typer.Exit(code=ExitCode.INTERNAL_ERROR)