cgcardona / muse public
_query.py python
205 lines 6.8 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """Shared query helpers for the code-domain CLI commands.
2
3 This module provides the low-level primitives that multiple code-domain
4 commands need — symbol extraction from snapshots, commit-graph walking,
5 and language classification — so each command can stay thin.
6
7 None of these functions are part of the public ``CodePlugin`` API. They
8 are internal helpers for the CLI layer and must not be imported by any
9 core module.
10 """
11
12 import itertools
13 import logging
14 import pathlib
15 from collections.abc import Iterator
16
17 from muse.core.object_store import read_object
18 from muse.core.store import CommitRecord, read_commit
19 from muse.domain import DomainOp
20 from muse.plugins.code.ast_parser import (
21 SEMANTIC_EXTENSIONS,
22 SymbolRecord,
23 SymbolTree,
24 parse_symbols,
25 )
26
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Language classification
# ---------------------------------------------------------------------------

# Lower-cased file suffix → display language name.
_SUFFIX_LANG: dict[str, str] = {
    # Python
    ".py": "Python",
    ".pyi": "Python",
    # TypeScript
    ".ts": "TypeScript",
    ".tsx": "TypeScript",
    # JavaScript
    ".js": "JavaScript",
    ".jsx": "JavaScript",
    ".mjs": "JavaScript",
    ".cjs": "JavaScript",
    # Single-suffix languages
    ".go": "Go",
    ".rs": "Rust",
    ".java": "Java",
    ".cs": "C#",
    # C and C++
    ".c": "C",
    ".h": "C",
    ".cpp": "C++",
    ".cc": "C++",
    ".cxx": "C++",
    ".hpp": "C++",
    ".hxx": "C++",
    # Ruby and Kotlin
    ".rb": "Ruby",
    ".kt": "Kotlin",
    ".kts": "Kotlin",
}


def language_of(file_path: str) -> str:
    """Map *file_path* to a display language name using its suffix.

    The suffix is lower-cased before the lookup, so matching is
    case-insensitive.

    Args:
        file_path: POSIX-style path of the file to classify.

    Returns:
        The name registered in ``_SUFFIX_LANG`` for the suffix. For an
        unregistered suffix, the lower-cased suffix itself (dot included).
        For a path without a suffix, the literal ``"(no ext)"``.
    """
    ext = pathlib.PurePosixPath(file_path).suffix.lower()
    if ext in _SUFFIX_LANG:
        return _SUFFIX_LANG[ext]
    if ext:
        # Unknown extension: show it verbatim rather than guessing a language.
        return ext
    return "(no ext)"
53
54
def is_semantic(file_path: str) -> bool:
    """Report whether *file_path* has a suffix listed in ``SEMANTIC_EXTENSIONS``.

    The suffix is lower-cased before the membership test, so the check is
    case-insensitive.

    Args:
        file_path: POSIX-style path of the file to check.

    Returns:
        ``True`` when the lower-cased suffix is in ``SEMANTIC_EXTENSIONS``,
        ``False`` otherwise.
    """
    ext = pathlib.PurePosixPath(file_path).suffix
    return ext.lower() in SEMANTIC_EXTENSIONS
59
60
61 # ---------------------------------------------------------------------------
62 # Symbol extraction from a snapshot manifest
63 # ---------------------------------------------------------------------------
64
65
def symbols_for_snapshot(
    root: pathlib.Path,
    manifest: dict[str, str],
    *,
    kind_filter: str | None = None,
    file_filter: str | None = None,
    language_filter: str | None = None,
) -> dict[str, SymbolTree]:
    """Build a ``file_path → SymbolTree`` map for the semantic files in *manifest*.

    Files are visited in sorted path order. A file is parsed only when it
    passes ``is_semantic`` and every filter that is set; a filter that is
    ``None`` or empty is treated as "no filter".

    Args:
        root: Repository root, forwarded to ``read_object``.
        manifest: Snapshot manifest mapping file path → object ID.
        kind_filter: If set, keep only symbols whose ``kind`` equals this value.
        file_filter: If set, keep only the file whose path equals this value.
        language_filter: If set, keep only files whose ``language_of`` name
            equals this value.

    Returns:
        Dict mapping file path to its (possibly kind-filtered) symbol tree.
        Files whose tree ends up empty are left out, as are files whose
        object cannot be read.
    """
    trees: dict[str, SymbolTree] = {}
    for path in sorted(manifest):
        blob_id = manifest[path]
        wanted = (
            is_semantic(path)
            and (not file_filter or path == file_filter)
            and (not language_filter or language_of(path) == language_filter)
        )
        if not wanted:
            continue

        blob = read_object(root, blob_id)
        if blob is None:
            # Best-effort: a missing object is noted and the file is skipped.
            logger.debug("Object %s missing for %s — skipping", blob_id[:8], path)
            continue

        symbols = parse_symbols(blob, path)
        if kind_filter:
            symbols = {
                address: record
                for address, record in symbols.items()
                if record["kind"] == kind_filter
            }
        if not symbols:
            continue
        trees[path] = symbols
    return trees
104
105
106 # ---------------------------------------------------------------------------
107 # Commit-graph walking
108 # ---------------------------------------------------------------------------
109
110
def walk_commits(
    root: pathlib.Path,
    start_commit_id: str,
    max_commits: int = 10_000,
) -> list[CommitRecord]:
    """Follow ``parent_commit_id`` links from *start_commit_id*, newest first.

    The walk ends at the first of: an empty/``None`` commit ID, an ID that
    was already visited, *max_commits* records collected, or a commit that
    ``read_commit`` cannot load.

    Args:
        root: Repository root, forwarded to ``read_commit``.
        start_commit_id: ID of the commit to start from.
        max_commits: Safety cap on the number of records returned.

    Returns:
        List of ``CommitRecord`` objects, newest first.
    """
    history: list[CommitRecord] = []
    visited: set[str] = set()
    cursor: str | None = start_commit_id
    while True:
        if not cursor or cursor in visited:
            # Ran off the end of the chain, or looped back onto it.
            break
        if len(history) >= max_commits:
            break
        visited.add(cursor)
        record = read_commit(root, cursor)
        if record is None:
            break
        history.append(record)
        cursor = record.parent_commit_id
    return history
137
138
def walk_commits_range(
    root: pathlib.Path,
    to_commit_id: str,
    from_commit_id: str | None,
    max_commits: int = 10_000,
) -> list[CommitRecord]:
    """Gather commits from *to_commit_id* back to, but excluding, *from_commit_id*.

    Follows ``parent_commit_id`` links. Besides reaching *from_commit_id*,
    the walk also ends on an empty/``None`` commit ID, an ID that was
    already visited, *max_commits* records collected, or a commit that
    ``read_commit`` cannot load.

    Args:
        root: Repository root, forwarded to ``read_commit``.
        to_commit_id: Inclusive end of the range (the newest commit).
        from_commit_id: Exclusive start of the range; with ``None`` the walk
            is not bounded by a start commit.
        max_commits: Safety cap on the number of records returned.

    Returns:
        List of ``CommitRecord`` objects, newest first.
    """
    collected: list[CommitRecord] = []
    visited: set[str] = set()
    cursor: str | None = to_commit_id
    while True:
        if not cursor or cursor in visited:
            break
        if len(collected) >= max_commits:
            break
        visited.add(cursor)
        if cursor == from_commit_id:
            # The exclusive boundary is checked before the commit is loaded.
            break
        record = read_commit(root, cursor)
        if record is None:
            break
        collected.append(record)
        cursor = record.parent_commit_id
    return collected
169
170
171 # ---------------------------------------------------------------------------
172 # Op traversal helpers
173 # ---------------------------------------------------------------------------
174
175
def flat_symbol_ops(ops: list[DomainOp]) -> Iterator[DomainOp]:
    """Yield every leaf op with a symbol-level address, descending into patches.

    An op whose ``"op"`` is ``"patch"`` is treated as a container: it is
    never yielded itself, and its ``"child_ops"`` are walked recursively, so
    patches nested at any depth are flattened. Any other op is a leaf and is
    yielded only when its ``"address"`` contains ``::`` (a symbol-level
    address).

    Args:
        ops: Ops to flatten; each needs an ``"op"`` key, patches also need
            ``"child_ops"``, and leaves need ``"address"``.

    Yields:
        Leaf ops with a symbol-level address, in depth-first input order.
    """
    for op in ops:
        if op["op"] == "patch":
            # Recurse instead of scanning a single level, so a patch nested
            # inside a patch contributes its leaves rather than being yielded
            # as a non-leaf or dropped.
            yield from flat_symbol_ops(op["child_ops"])
        elif "::" in op["address"]:
            yield op
188
189
def touched_files(ops: list[DomainOp]) -> frozenset[str]:
    """Collect the addresses of ``patch`` ops in *ops* that carry child ops.

    Only top-level ops are inspected. A ``patch`` op with an empty
    ``child_ops`` list, and every op that is not a ``patch`` (coarse
    file-level replace/insert/delete), is ignored.

    Args:
        ops: Ops to scan.

    Returns:
        Frozen set of the ``"address"`` values of the qualifying patch ops.
    """
    return frozenset(
        op["address"]
        for op in ops
        # Test the op type first so non-patch ops are never asked for child_ops.
        if op["op"] == "patch" and op["child_ops"]
    )
201
202
def file_pairs(files: frozenset[str]) -> Iterator[tuple[str, str]]:
    """Generate each unordered pair of *files* once, as ``(a, b)`` with ``a < b``.

    Args:
        files: Set of file paths to pair up.

    Yields:
        2-tuples drawn from the sorted paths, in lexicographic order of the
        pairs. Nothing is yielded for fewer than two files.
    """
    ordered = sorted(files)
    for first, second in itertools.combinations(ordered, 2):
        yield first, second