cgcardona / muse public
checkout_symbol.py python
211 lines 7.0 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 1d ago
1 """muse checkout-symbol — restore a historical version of a specific symbol.
2
3 Extracts a single named symbol from a historical committed snapshot and writes
4 it back into the current working-tree file, replacing the current version of
5 that symbol.
6
7 This is a **surgical** operation: only the target symbol's lines change.
8 All surrounding code — other symbols, comments, imports, blank lines outside
9 the symbol boundary — is left untouched.
10
11 Why this matters
12 ----------------
13 Git's ``checkout`` restores entire files. If you need to roll back a single
14 function while keeping everything else current, you need to manually cherry-
15 pick lines. ``muse checkout-symbol`` does this atomically against Muse's
16 content-addressed symbol index.
17
18 Usage::
19
20 muse checkout-symbol "src/billing.py::compute_invoice_total" --commit HEAD~3
21 muse checkout-symbol "src/auth.py::validate_token" --commit abc12345 --dry-run
22
23 Output (without --dry-run)::
24
25 Restoring: src/billing.py::compute_invoice_total
26 from commit: abc12345 (2026-02-15)
    lines 42–67 → replaced with 31 historical line(s)
28 ✅ Written to src/billing.py
29
30 Output (with --dry-run)::
31
32 Dry run — no files will be written.
33
34 Restoring: src/billing.py::compute_invoice_total
35 from commit: abc12345 (2026-02-15)
36
37 --- current
38 +++ historical
39 @@ -42,26 +42,20 @@
40 def compute_invoice_total(...):
41 - ...current body...
42 + ...historical body...
43
44 Flags:
45
46 ``--commit, -c REF``
47 Required. Commit to restore from.
48
49 ``--dry-run``
50 Print the diff without writing anything.
51 """
52
import difflib
import io
import json
import logging
import pathlib
57
58 import typer
59
60 from muse.core.errors import ExitCode
61 from muse.core.object_store import read_object
62 from muse.core.repo import require_repo
63 from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
64 from muse.plugins.code.ast_parser import parse_symbols
65
# Module-level logger (not referenced by the code visible in this module).
logger: logging.Logger = logging.getLogger(__name__)

# Typer application; ``checkout_symbol`` registers itself on it as the callback.
app: typer.Typer = typer.Typer()
69
70
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` recorded in ``<root>/.muse/repo.json`` as a string.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates)
    instead of the platform's locale encoding, so a repo.json containing
    non-ASCII text reads the same on every OS.

    Raises ``FileNotFoundError`` if repo.json is missing, ``json.JSONDecodeError``
    if it is not valid JSON, and ``KeyError`` if it has no ``repo_id`` entry.
    """
    metadata = json.loads((root / ".muse" / "repo.json").read_text(encoding="utf-8"))
    return str(metadata["repo_id"])
73
74
def _read_branch(root: pathlib.Path) -> str:
    """Return the branch name stored in ``<root>/.muse/HEAD``.

    Surrounding whitespace is trimmed and a leading ``refs/heads/`` is
    dropped; any other HEAD content is returned as-is (trimmed).
    """
    head_file = root / ".muse" / "HEAD"
    ref_text = head_file.read_text().strip()
    prefix = "refs/heads/"
    if ref_text.startswith(prefix):
        ref_text = ref_text[len(prefix):]
    return ref_text.strip()
78
79
def _extract_lines(source: bytes, lineno: int, end_lineno: int) -> list[str]:
    """Extract lines *lineno*..*end_lineno* (1-indexed, inclusive) as a list.

    *source* is decoded as UTF-8 (undecodable bytes become U+FFFD) and every
    returned line keeps its original line terminator.

    Lines are split only on CRLF, CR and LF. ``str.splitlines`` would also
    break on form feed, vertical tab, the file/group/record separators, NEL,
    U+2028 and U+2029. None of those start a new line for Python's tokenizer,
    so a single form feed in the file would shift every later index and the
    slice would return the wrong lines. (Assumes ``parse_symbols`` numbers
    lines the same way; TODO confirm for non-Python grammars.)

    A *lineno* below 1 is clamped to 1 instead of wrapping around to the
    end of the file via a negative slice index.
    """
    text = source.decode("utf-8", errors="replace")
    # newline="" = universal-newline splitting WITHOUT translating endings,
    # so CRLF / CR files come back byte-for-byte.
    all_lines = io.StringIO(text, newline="").readlines()
    return all_lines[max(lineno, 1) - 1:end_lineno]
84
85
def _find_current_symbol_lines(
    working_tree_file: pathlib.Path,
    address: str,
) -> tuple[int, int] | None:
    """Return ``(lineno, end_lineno)`` for *address* in the current working-tree file.

    *address* has the form ``"<repo-relative path>::<symbol>"``. The file is
    parsed under that repo-relative path (the part before ``::``), the same
    path form ``checkout_symbol`` passes when parsing the historical blob
    before looking the symbol up by *address*. Parsing under the absolute
    on-disk path would presumably yield symbol keys prefixed with that
    absolute path, which would never equal *address*.

    Returns ``None`` if the file does not exist or the symbol is not found.
    """
    try:
        raw = working_tree_file.read_bytes()
    except FileNotFoundError:
        return None
    file_rel = address.partition("::")[0]
    rec = parse_symbols(raw, file_rel).get(address)
    if rec is None:
        return None
    return rec["lineno"], rec["end_lineno"]
102
103
@app.callback(invoke_without_command=True)
def checkout_symbol(
    ctx: typer.Context,
    address: str = typer.Argument(
        ..., metavar="ADDRESS",
        help='Symbol address, e.g. "src/billing.py::compute_invoice_total".',
    ),
    ref: str = typer.Option(
        ..., "--commit", "-c", metavar="REF",
        help="Commit to restore the symbol from (required).",
    ),
    dry_run: bool = typer.Option(
        False, "--dry-run",
        help="Print the diff without writing anything.",
    ),
) -> None:
    """Restore a historical version of a specific symbol into the working tree.

    Extracts the symbol body from the given historical commit and splices it
    into the current working-tree file at the symbol's current location.
    Only the target symbol's lines change; everything else is left untouched,
    including the file's existing line endings.

    If the symbol does not exist at ``--commit``, the command exits with an
    error. If the symbol does not exist in the current working tree (perhaps
    it was deleted), the historical version is appended to the end of the file.
    A working-tree file that is not valid UTF-8 is refused rather than
    rewritten, because re-encoding it would alter bytes outside the symbol.
    """
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    if "::" not in address:
        typer.echo("❌ ADDRESS must be a symbol address like 'src/billing.py::func'.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Only the file part is needed here; symbols are looked up by the full address.
    file_rel = address.partition("::")[0]

    commit = resolve_commit_ref(root, repo_id, branch, ref)
    if commit is None:
        typer.echo(f"❌ Commit '{ref}' not found.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Read the historical blob.
    manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
    obj_id = manifest.get(file_rel)
    if obj_id is None:
        typer.echo(
            f"❌ '{file_rel}' is not in snapshot {commit.commit_id[:8]}.", err=True
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    historical_raw = read_object(root, obj_id)
    if historical_raw is None:
        typer.echo(f"❌ Blob {obj_id[:8]} missing from object store.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # Find the symbol in the historical blob.
    hist_rec = parse_symbols(historical_raw, file_rel).get(address)
    if hist_rec is None:
        typer.echo(
            f"❌ Symbol '{address}' not found in commit {commit.commit_id[:8]}.", err=True
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    historical_lines = _extract_lines(
        historical_raw, hist_rec["lineno"], hist_rec["end_lineno"]
    )

    # Load the current working-tree file. Decode strictly: with
    # errors="replace", any invalid byte elsewhere in the file would be
    # silently rewritten as U+FFFD when the file is saved.
    working_file = root / file_rel
    try:
        current_text = working_file.read_bytes().decode("utf-8")
    except FileNotFoundError:
        current_text = ""
    except UnicodeDecodeError:
        typer.echo(f"❌ '{file_rel}' is not valid UTF-8; refusing to rewrite it.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR) from None
    # Split only on CRLF / CR / LF (not str.splitlines' extra separators such
    # as form feed or U+2028) so list indices agree with parser line numbers;
    # newline="" keeps the original line endings untranslated.
    current_lines = io.StringIO(current_text, newline="").readlines()

    current_sym_range = _find_current_symbol_lines(working_file, address)

    if dry_run:
        typer.echo("Dry run — no files will be written.\n")

    typer.echo(f"Restoring: {address}")
    typer.echo(f" from commit: {commit.commit_id[:8]} ({commit.committed_at.date()})")

    if current_sym_range is not None:
        cur_start, cur_end = current_sym_range
        typer.echo(
            f" lines {cur_start}–{cur_end} → replaced with "
            f"{len(historical_lines)} historical line(s)"
        )
        before, after = current_lines[:cur_start - 1], current_lines[cur_end:]
    else:
        typer.echo(" symbol not found in working tree — appending at end of file")
        before, after = list(current_lines), []
        if before:
            # Terminate an unterminated final line, then leave one blank
            # line between the existing code and the appended symbol.
            if not before[-1].endswith(("\n", "\r")):
                before[-1] += "\n"
            before.append("\n")

    restored = list(historical_lines)
    # A symbol that ended the historical file may lack a trailing newline;
    # without one, its last line would fuse with the line that follows it.
    if after and restored and not restored[-1].endswith(("\n", "\r")):
        restored[-1] += "\n"
    new_lines = before + restored + after

    if dry_run:
        # The input lines already carry their terminators, so keep the default
        # lineterm="\n": with lineterm="" the ---/+++/@@ header lines had no
        # terminator and ran together on a single output line.
        diff = difflib.unified_diff(
            current_lines, new_lines, fromfile="current", tofile="historical",
        )
        rendered = "".join(
            line if line.endswith(("\n", "\r")) else line + "\n" for line in diff
        )
        typer.echo("\n" + rendered)
        return

    # Write bytes so text is not re-translated to os.linesep (write_text would
    # turn existing CRLF endings into CR CR LF on Windows).
    working_file.write_bytes("".join(new_lines).encode("utf-8"))
    typer.echo(f"✅ Written to {file_rel}")