_code_query.py
python
| 1 | """Code-domain query evaluator for the Muse generic query engine. |
| 2 | |
| 3 | Implements :data:`~muse.core.query_engine.CommitEvaluator` for the code domain. |
| 4 | Allows agents and humans to search the commit history for code changes:: |
| 5 | |
| 6 | muse code-query "symbol == 'my_function' and change == 'added'" |
| 7 | muse code-query "language == 'Python' and author == 'agent-x'" |
| 8 | muse code-query "agent_id == 'claude' and sem_ver_bump == 'major'" |
| 9 | muse code-query "file == 'src/core.py'" |
| 10 | muse code-query "change == 'added' and kind == 'class'" |
| 11 | |
| 12 | Query language |
| 13 | -------------- |
| 14 | |
| 15 | query = and_expr ( 'or' and_expr )* |
| 16 | and_expr = atom ( 'and' atom )* |
| 17 | atom = FIELD OP VALUE |
| 18 | FIELD = 'symbol' | 'file' | 'language' | 'kind' | 'change' |
| 19 | | 'author' | 'agent_id' | 'model_id' | 'toolchain_id' |
| 20 | | 'sem_ver_bump' | 'branch' |
| 21 | OP = '==' | '!=' | 'contains' | 'startswith' |
| 22 | VALUE = QUOTED_STRING | UNQUOTED_WORD |
| 23 | |
| 24 | Supported fields |
| 25 | ---------------- |
| 26 | |
| 27 | ``symbol`` Qualified symbol name (e.g. ``"MyClass.method"``). |
| 28 | ``file`` Workspace-relative file path. |
| 29 | ``language`` Language name (``"Python"``, ``"TypeScript"``…). |
| 30 | ``kind`` Symbol kind (``"function"``, ``"class"``, ``"method"``…). |
| 31 | ``change`` ``"added"``, ``"removed"``, or ``"modified"``. |
| 32 | ``author`` Commit author string. |
| 33 | ``agent_id`` Agent identity from commit provenance. |
| 34 | ``model_id`` Model ID from commit provenance. |
| 35 | ``toolchain_id`` Toolchain string from commit provenance. |
| 36 | ``sem_ver_bump`` Semantic version bump: ``"none"``, ``"patch"``, |
| 37 | ``"minor"``, ``"major"``. |
| 38 | ``branch`` Branch name. |
| 39 | """ |
| 40 | from __future__ import annotations |
| 41 | |
| 42 | import logging |
| 43 | import pathlib |
| 44 | import re |
| 45 | from dataclasses import dataclass |
| 46 | from typing import Literal, TypeGuard, get_args |
| 47 | |
| 48 | from muse.core.query_engine import CommitEvaluator, QueryMatch |
| 49 | from muse.core.store import CommitRecord |
| 50 | from muse.domain import DomainOp |
| 51 | from muse.plugins.code._query import language_of, symbols_for_snapshot |
| 52 | |
# Module-level logger.  NOTE(review): not referenced by any code visible in
# this module; presumably kept for diagnostics — confirm before removing.
logger = logging.getLogger(__name__)
| 54 | |
| 55 | |
| 56 | # --------------------------------------------------------------------------- |
| 57 | # Query AST types |
| 58 | # --------------------------------------------------------------------------- |
| 59 | |
| 60 | CodeField = Literal[ |
| 61 | "symbol", "file", "language", "kind", "change", |
| 62 | "author", "agent_id", "model_id", "toolchain_id", |
| 63 | "sem_ver_bump", "branch", |
| 64 | ] |
| 65 | |
| 66 | CodeOp = Literal["==", "!=", "contains", "startswith"] |
| 67 | |
| 68 | |
| 69 | @dataclass(frozen=True) |
| 70 | class Comparison: |
| 71 | """A single field OP value predicate.""" |
| 72 | |
| 73 | field: CodeField |
| 74 | op: CodeOp |
| 75 | value: str |
| 76 | |
| 77 | |
| 78 | @dataclass(frozen=True) |
| 79 | class AndExpr: |
| 80 | """Conjunction of predicates (all must match).""" |
| 81 | |
| 82 | clauses: list[Comparison] |
| 83 | |
| 84 | |
| 85 | @dataclass(frozen=True) |
| 86 | class OrExpr: |
| 87 | """Disjunction of AND-expressions (any must match).""" |
| 88 | |
| 89 | clauses: list[AndExpr] |
| 90 | |
| 91 | |
| 92 | # --------------------------------------------------------------------------- |
| 93 | # Tokeniser & parser |
| 94 | # --------------------------------------------------------------------------- |
| 95 | |
| 96 | _TOKEN_RE = re.compile( |
| 97 | r""" |
| 98 | (?P<keyword>(?:or|and|contains|startswith)(?![A-Za-z0-9_.])) |
| 99 | |(?P<op>==|!=) |
| 100 | |(?P<quoted>"[^"]*"|'[^']*') |
| 101 | |(?P<word>[A-Za-z_][A-Za-z0-9_.]*) |
| 102 | """, |
| 103 | re.VERBOSE, |
| 104 | ) |
| 105 | |
| 106 | _VALID_FIELDS: frozenset[str] = frozenset(get_args(CodeField)) |
| 107 | _VALID_OPS: frozenset[str] = frozenset(get_args(CodeOp)) |
| 108 | |
| 109 | |
| 110 | def _is_code_field(tok: str) -> TypeGuard[CodeField]: |
| 111 | return tok in _VALID_FIELDS |
| 112 | |
| 113 | |
| 114 | def _is_code_op(tok: str) -> TypeGuard[CodeOp]: |
| 115 | return tok in _VALID_OPS |
| 116 | |
| 117 | |
| 118 | def _as_code_field(tok: str) -> CodeField: |
| 119 | """Validate and narrow *tok* to :data:`CodeField`; raises :exc:`ValueError` if invalid.""" |
| 120 | if not _is_code_field(tok): |
| 121 | raise ValueError(f"Unknown field: {tok!r}. Valid: {sorted(_VALID_FIELDS)}") |
| 122 | return tok |
| 123 | |
| 124 | |
| 125 | def _as_code_op(tok: str) -> CodeOp: |
| 126 | """Validate and narrow *tok* to :data:`CodeOp`; raises :exc:`ValueError` if invalid.""" |
| 127 | if not _is_code_op(tok): |
| 128 | raise ValueError(f"Unknown operator: {tok!r}. Valid: {sorted(_VALID_OPS)}") |
| 129 | return tok |
| 130 | |
| 131 | |
| 132 | def _tokenize(query: str) -> list[str]: |
| 133 | return [m.group() for m in _TOKEN_RE.finditer(query)] |
| 134 | |
| 135 | |
| 136 | def _parse_query(query: str) -> OrExpr: |
| 137 | """Parse a query string into an :class:`OrExpr` AST.""" |
| 138 | tokens = _tokenize(query.strip()) |
| 139 | pos = 0 |
| 140 | |
| 141 | def peek() -> str | None: |
| 142 | return tokens[pos] if pos < len(tokens) else None |
| 143 | |
| 144 | def consume() -> str: |
| 145 | nonlocal pos |
| 146 | tok = tokens[pos] |
| 147 | pos += 1 |
| 148 | return tok |
| 149 | |
| 150 | def parse_atom() -> Comparison: |
| 151 | field_tok = consume() |
| 152 | validated_field = _as_code_field(field_tok) |
| 153 | op_tok = consume() |
| 154 | validated_op = _as_code_op(op_tok) |
| 155 | val_tok = consume() |
| 156 | if val_tok.startswith(("'", '"')): |
| 157 | val_tok = val_tok[1:-1] |
| 158 | return Comparison( |
| 159 | field=validated_field, |
| 160 | op=validated_op, |
| 161 | value=val_tok, |
| 162 | ) |
| 163 | |
| 164 | def parse_and() -> AndExpr: |
| 165 | clauses: list[Comparison] = [parse_atom()] |
| 166 | while peek() == "and": |
| 167 | consume() |
| 168 | clauses.append(parse_atom()) |
| 169 | return AndExpr(clauses=clauses) |
| 170 | |
| 171 | def parse_or() -> OrExpr: |
| 172 | clauses: list[AndExpr] = [parse_and()] |
| 173 | while peek() == "or": |
| 174 | consume() |
| 175 | clauses.append(parse_and()) |
| 176 | return OrExpr(clauses=clauses) |
| 177 | |
| 178 | return parse_or() |
| 179 | |
| 180 | |
| 181 | # --------------------------------------------------------------------------- |
| 182 | # Evaluator |
| 183 | # --------------------------------------------------------------------------- |
| 184 | |
| 185 | |
| 186 | def _match_op(actual: str, op: CodeOp, expected: str) -> bool: |
| 187 | """Apply *op* to *actual* and *expected* strings.""" |
| 188 | if op == "==": |
| 189 | return actual == expected |
| 190 | if op == "!=": |
| 191 | return actual != expected |
| 192 | if op == "contains": |
| 193 | return expected.lower() in actual.lower() |
| 194 | # op == "startswith" |
| 195 | return actual.lower().startswith(expected.lower()) |
| 196 | |
| 197 | |
def _commit_matches_comparison(
    comparison: Comparison,
    commit: CommitRecord,
    manifest: dict[str, str],
    root: pathlib.Path,
    symbol_matches: list[dict[str, str]],
) -> bool:
    """Return True if *commit* satisfies *comparison*.

    Commit-level fields (``author``, ``agent_id``, ``model_id``,
    ``toolchain_id``, ``sem_ver_bump``, ``branch``) are compared against the
    commit record; a falsy value is compared as ``""``.

    Symbol-level fields (``symbol``, ``file``, ``language``, ``kind``,
    ``change``) are checked against every op in ``commit.structured_delta``,
    including the ``child_ops`` of patch ops; the comparison holds if any op
    matches.  Each matching op is appended to *symbol_matches* as a dict with
    keys ``file``, ``symbol``, ``kind``, ``change`` and ``language``.  A
    commit without a structured delta never matches a symbol-level field.

    *manifest* and *root* are part of the evaluator signature but are not
    used here.
    """
    field = comparison.field
    cmp_op = comparison.op
    expected = comparison.value

    if field in ("author", "agent_id", "model_id", "toolchain_id", "sem_ver_bump", "branch"):
        # These field names double as CommitRecord attribute names.  ``or ""``
        # keeps the string operators from failing if a provenance field is
        # unset — NOTE(review): confirm whether CommitRecord allows None here.
        return _match_op(getattr(commit, field) or "", cmp_op, expected)

    delta = commit.structured_delta
    if delta is None:
        return False

    hit = False
    for op_rec in delta.get("ops", []):
        # Addresses are either "path" or "path::qualified.symbol".
        parent_address = str(op_rec.get("address", ""))
        parent_file, _, parent_symbol = parent_address.partition("::")

        # Patch ops carry their per-symbol changes in ``child_ops``; inspect
        # the parent and every child.  The ``.get`` guard avoids a KeyError;
        # the subscripted test presumably narrows the DomainOp union for the
        # type checker.
        records: list[DomainOp] = [op_rec]
        if op_rec.get("op") == "patch" and op_rec["op"] == "patch":
            records = [op_rec] + op_rec["child_ops"]

        for rec in records:
            rec_file, _, rec_symbol = str(rec.get("address", parent_address)).partition("::")
            file_path = rec_file or parent_file
            op_type = str(rec.get("op", ""))
            record: dict[str, str] = {
                "file": file_path,
                "symbol": rec_symbol or parent_symbol,
                "kind": str(rec.get("kind", "")),
                "change": (
                    "added" if op_type == "insert"
                    else "removed" if op_type == "delete"
                    else "modified"
                ),
                # Derived from this record's own file so child ops are
                # classified by the file they actually touch.
                "language": language_of(file_path),
            }
            # The record's keys are exactly the symbol-level field names.
            if _match_op(record.get(field, ""), cmp_op, expected):
                hit = True
                symbol_matches.append(record)

    return hit
| 293 | |
| 294 | |
# Fields read from the commit record itself rather than from its structured
# delta.  A clause made only of these yields one commit-level match.
_COMMIT_LEVEL_FIELDS: frozenset[str] = frozenset(
    {"author", "agent_id", "model_id", "toolchain_id", "sem_ver_bump", "branch"}
)

# Maximum number of symbol-level matches reported for a single commit.
_MAX_SYMBOL_MATCHES_PER_COMMIT = 20


def _symbol_record_key(record: dict[str, str]) -> tuple[tuple[str, str], ...]:
    """Return a hashable identity for a symbol-match record (for de-duplication)."""
    return tuple(sorted(record.items()))


def _match_and_clause(
    and_expr: AndExpr,
    commit: CommitRecord,
    manifest: dict[str, str],
    root: pathlib.Path,
) -> list[dict[str, str]] | None:
    """Evaluate one AND clause against *commit*.

    Commit-level comparisons must all hold for the commit.  Symbol-level
    comparisons must all hold for the *same* delta record.  So
    ``change == 'added' and kind == 'class'`` selects added classes, not
    "some addition plus some class".

    Returns:
        ``None`` if the clause does not match.  ``[]`` if it matches and has
        no symbol-level comparisons.  Otherwise, the distinct records (in
        first-seen order) that satisfy every symbol-level comparison.
    """
    symbol_cmps: list[Comparison] = []
    for cmp in and_expr.clauses:
        if cmp.field not in _COMMIT_LEVEL_FIELDS:
            symbol_cmps.append(cmp)
        elif not _commit_matches_comparison(cmp, commit, manifest, root, []):
            return None

    if not symbol_cmps:
        return []

    candidates: dict[tuple[tuple[str, str], ...], dict[str, str]] | None = None
    for cmp in symbol_cmps:
        found: list[dict[str, str]] = []
        if not _commit_matches_comparison(cmp, commit, manifest, root, found):
            return None
        found_by_key = {_symbol_record_key(rec): rec for rec in found}
        if candidates is None:
            candidates = found_by_key  # first comparison seeds the set
        else:
            candidates = {k: rec for k, rec in candidates.items() if k in found_by_key}
        if not candidates:
            return None
    return list(candidates.values()) if candidates else None


def build_evaluator(query: str) -> CommitEvaluator:
    """Parse *query* and return a :data:`CommitEvaluator` for :func:`~muse.core.query_engine.walk_history`.

    Matching rules for each commit:

    * ``or`` clauses are tried left to right, and the first clause that
      matches determines the result.
    * A clause made only of commit-level fields (``author``, ``agent_id``,
      ``model_id``, ``toolchain_id``, ``sem_ver_bump``, ``branch``) produces a
      single match whose ``detail`` is the first 80 characters of the commit
      message.
    * Otherwise every symbol-level comparison in the clause must hold for the
      same delta record.  One match is produced per distinct matching record,
      at most ``_MAX_SYMBOL_MATCHES_PER_COMMIT`` per commit.
    * ``agent_id`` and ``model_id`` are copied onto each match when set on the
      commit.

    Args:
        query: A query string in the code query DSL.

    Returns:
        A callable that can be passed to :func:`~muse.core.query_engine.walk_history`.

    Raises:
        ValueError: If the query cannot be parsed.
    """
    parsed = _parse_query(query)

    def evaluator(
        commit: CommitRecord,
        manifest: dict[str, str],
        root: pathlib.Path,
    ) -> list[QueryMatch]:
        records: list[dict[str, str]] | None = None
        for and_expr in parsed.clauses:
            records = _match_and_clause(and_expr, commit, manifest, root)
            if records is not None:
                break  # or-short-circuit: first matching clause wins
        if records is None:
            return []

        def new_match(detail: str, extra: dict[str, str]) -> QueryMatch:
            m = QueryMatch(
                commit_id=commit.commit_id,
                author=commit.author,
                committed_at=commit.committed_at.isoformat(),
                branch=commit.branch,
                detail=detail,
                extra=extra,
            )
            if commit.agent_id:
                m["agent_id"] = commit.agent_id
            if commit.model_id:
                m["model_id"] = commit.model_id
            return m

        if not records:
            # The winning clause used commit-level fields only.
            return [new_match(commit.message[:80], {})]

        matches: list[QueryMatch] = []
        for record in records[:_MAX_SYMBOL_MATCHES_PER_COMMIT]:
            detail = record.get("symbol") or record.get("file") or "?"
            change = record.get("change", "")
            if change:
                detail = f"{detail} ({change})"
            matches.append(new_match(detail, dict(record)))
        return matches

    return evaluator