cgcardona / muse public
muse_rebase.py python
1246 lines 44.6 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Muse Rebase Service — replay commits onto a new base.
2
3 Algorithm
4 ---------
5 1. Find the merge-base (LCA) of the current branch HEAD and ``<upstream>``.
6 2. Collect the commits on the current branch that are *not* ancestors of
7 ``<upstream>`` (i.e., commits added since the LCA), ordered oldest first.
8 3. For each such commit, compute the snapshot delta relative to its own parent,
9 then apply that delta on top of the current ``onto`` tip (which starts as
10 ``<upstream>`` HEAD and advances after each successful replay).
11 4. Insert a new commit record (new commit_id because the parent has changed).
12 5. Advance the branch pointer to the final replayed commit.
13
14 Because Muse snapshots are content-addressed manifests (``{path: object_id}``),
15 ``apply_delta`` is a pure dict operation — no byte-level merge required.
16
17 State file: ``.muse/REBASE_STATE.json``
18 -----------------------------------------
19 Written when a conflict is detected mid-replay so ``--continue`` and ``--abort``
20 can resume or abandon the operation:
21
22 .. code-block:: json
23
24 {
25 "upstream_commit": "abc123...",
26 "base_commit": "def456...",
27 "original_branch": "feature",
28 "original_head": "ghi789...",
    "commits_to_replay": ["cid3"],
30 "current_onto": "abc123...",
31 "completed_pairs": [["cid1", "new_cid1"]],
32 "current_commit": "cid2",
33 "conflict_paths": ["beat.mid"]
34 }
35
36 Boundary rules:
37 - Must NOT import StateStore, EntityRegistry, or get_or_create_store.
38 - Must NOT import executor modules or maestro_* handlers.
- May import muse_cli.db, muse_cli.errors, muse_cli.merge_engine,
  muse_cli.models, muse_cli.snapshot.
41
42 Domain analogy: a producer has 10 "fixup" commits from a late-night session.
43 ``muse rebase dev`` replays them cleanly onto the ``dev`` tip, producing a
44 linear history — the musical narrative stays readable as a sequence of
45 intentional variations.
46 """
47 from __future__ import annotations
48
49 import datetime
50 import json
51 import logging
52 import pathlib
53 from collections import deque
54 from dataclasses import dataclass, field
55 from typing import Optional
56
57 import typer
58 from sqlalchemy.ext.asyncio import AsyncSession
59
60 from maestro.muse_cli.db import (
61 get_commit_snapshot_manifest,
62 insert_commit,
63 upsert_snapshot,
64 )
65 from maestro.muse_cli.errors import ExitCode
66 from maestro.muse_cli.merge_engine import (
67 detect_conflicts,
68 diff_snapshots,
69 )
70 from maestro.muse_cli.models import MuseCliCommit
71 from maestro.muse_cli.snapshot import compute_commit_id, compute_snapshot_id
72
73 logger = logging.getLogger(__name__)
74
75 _REBASE_STATE_FILENAME = "REBASE_STATE.json"
76
77
78 # ---------------------------------------------------------------------------
79 # Result types
80 # ---------------------------------------------------------------------------
81
82
83 @dataclass(frozen=True)
84 class RebaseCommitPair:
85 """Maps an original commit to its replayed replacement.
86
87 Attributes:
88 original_commit_id: SHA of the commit that existed before the rebase.
89 new_commit_id: SHA of the freshly-replayed commit with the new parent.
90 """
91
92 original_commit_id: str
93 new_commit_id: str
94
95
96 @dataclass(frozen=True)
97 class RebaseResult:
98 """Outcome of a ``muse rebase`` operation.
99
100 Attributes:
101 branch: The branch that was rebased.
102 upstream: The upstream ref used as the new base.
103 upstream_commit_id: Resolved commit ID of the upstream tip.
104 base_commit_id: LCA commit where the histories diverged.
105 replayed: Ordered list of (original, new) commit pairs.
106 conflict_paths: Paths with unresolved conflicts (empty on success).
107 aborted: True when ``--abort`` cleared an in-progress rebase.
108 noop: True when no commits needed to be replayed.
109 autosquash_applied: True when ``--autosquash`` reordered commits.
110 """
111
112 branch: str
113 upstream: str
114 upstream_commit_id: str
115 base_commit_id: str
116 replayed: tuple[RebaseCommitPair, ...]
117 conflict_paths: tuple[str, ...]
118 aborted: bool
119 noop: bool
120 autosquash_applied: bool
121
122
123 # ---------------------------------------------------------------------------
124 # RebaseState — on-disk session record
125 # ---------------------------------------------------------------------------
126
127
128 @dataclass
129 class RebaseState:
130 """Describes an in-progress rebase with optional conflict information.
131
132 Attributes:
133 upstream_commit: The tip of the upstream branch used as the new base.
134 base_commit: LCA where ours and upstream diverged.
135 original_branch: Name of the branch being rebased.
136 original_head: Branch HEAD before the rebase started (for ``--abort``).
137 commits_to_replay: All original commits to replay, oldest first.
138 current_onto: The commit ID of the current "onto" tip.
139 completed_pairs: Pairs of (original, new) commit IDs already replayed.
140 current_commit: The commit being applied when a conflict was detected.
141 conflict_paths: Paths with unresolved conflicts (empty when none).
142 """
143
144 upstream_commit: str
145 base_commit: str
146 original_branch: str
147 original_head: str
148 commits_to_replay: list[str] = field(default_factory=list)
149 current_onto: str = ""
150 completed_pairs: list[list[str]] = field(default_factory=list)
151 current_commit: str = ""
152 conflict_paths: list[str] = field(default_factory=list)
153
154
155 # ---------------------------------------------------------------------------
156 # Filesystem helpers
157 # ---------------------------------------------------------------------------
158
159
160 def read_rebase_state(root: pathlib.Path) -> RebaseState | None:
161 """Return :class:`RebaseState` if a rebase is in progress, else ``None``.
162
163 Reads ``.muse/REBASE_STATE.json``. Returns ``None`` when the file does
164 not exist or cannot be parsed.
165
166 Args:
167 root: Repository root (directory containing ``.muse/``).
168 """
169 state_path = root / ".muse" / _REBASE_STATE_FILENAME
170 if not state_path.exists():
171 return None
172 try:
173 raw: dict[str, object] = json.loads(state_path.read_text())
174 except (json.JSONDecodeError, OSError) as exc:
175 logger.warning("⚠️ Failed to read %s: %s", _REBASE_STATE_FILENAME, exc)
176 return None
177
178 def _str(key: str, default: str = "") -> str:
179 v = raw.get(key, default)
180 return str(v) if v is not None else default
181
182 def _strlist(key: str) -> list[str]:
183 v = raw.get(key, [])
184 return [str(x) for x in v] if isinstance(v, list) else []
185
186 def _pairlist(key: str) -> list[list[str]]:
187 v = raw.get(key, [])
188 if not isinstance(v, list):
189 return []
190 result: list[list[str]] = []
191 for item in v:
192 if isinstance(item, list) and len(item) == 2:
193 result.append([str(item[0]), str(item[1])])
194 return result
195
196 return RebaseState(
197 upstream_commit=_str("upstream_commit"),
198 base_commit=_str("base_commit"),
199 original_branch=_str("original_branch"),
200 original_head=_str("original_head"),
201 commits_to_replay=_strlist("commits_to_replay"),
202 current_onto=_str("current_onto"),
203 completed_pairs=_pairlist("completed_pairs"),
204 current_commit=_str("current_commit"),
205 conflict_paths=_strlist("conflict_paths"),
206 )
207
208
209 def write_rebase_state(root: pathlib.Path, state: RebaseState) -> None:
210 """Persist *state* to ``.muse/REBASE_STATE.json``.
211
212 Args:
213 root: Repository root.
214 state: Current rebase session state.
215 """
216 state_path = root / ".muse" / _REBASE_STATE_FILENAME
217 data: dict[str, object] = {
218 "upstream_commit": state.upstream_commit,
219 "base_commit": state.base_commit,
220 "original_branch": state.original_branch,
221 "original_head": state.original_head,
222 "commits_to_replay": state.commits_to_replay,
223 "current_onto": state.current_onto,
224 "completed_pairs": state.completed_pairs,
225 "current_commit": state.current_commit,
226 "conflict_paths": state.conflict_paths,
227 }
228 state_path.write_text(json.dumps(data, indent=2))
229 logger.info(
230 "✅ Wrote REBASE_STATE.json (%d remaining, %d done)",
231 len(state.commits_to_replay),
232 len(state.completed_pairs),
233 )
234
235
236 def clear_rebase_state(root: pathlib.Path) -> None:
237 """Remove ``.muse/REBASE_STATE.json`` after a successful or aborted rebase."""
238 state_path = root / ".muse" / _REBASE_STATE_FILENAME
239 if state_path.exists():
240 state_path.unlink()
241 logger.debug("✅ Cleared REBASE_STATE.json")
242
243
244 # ---------------------------------------------------------------------------
245 # Pure helpers
246 # ---------------------------------------------------------------------------
247
248
249 def compute_delta(
250 parent_manifest: dict[str, str],
251 commit_manifest: dict[str, str],
252 ) -> tuple[dict[str, str], set[str]]:
253 """Compute the file-level changes introduced by a single commit.
254
255 Args:
256 parent_manifest: Snapshot of the commit's parent.
257 commit_manifest: Snapshot of the commit itself.
258
259 Returns:
260 Tuple of (additions_and_modifications, deletions):
261 - additions_and_modifications: ``{path: object_id}`` for paths added
262 or changed in *commit_manifest* relative to *parent_manifest*.
263 - deletions: Set of paths present in *parent_manifest* but absent from
264 *commit_manifest*.
265 """
266 changed_paths = diff_snapshots(parent_manifest, commit_manifest)
267 additions: dict[str, str] = {}
268 deletions: set[str] = set()
269 for path in changed_paths:
270 if path in commit_manifest:
271 additions[path] = commit_manifest[path]
272 else:
273 deletions.add(path)
274 return additions, deletions
275
276
277 def apply_delta(
278 onto_manifest: dict[str, str],
279 additions: dict[str, str],
280 deletions: set[str],
281 ) -> dict[str, str]:
282 """Apply a commit delta onto an ``onto`` snapshot manifest.
283
284 Produces a new manifest that represents the onto-manifest with the
285 same file changes that the original commit introduced over its parent.
286
287 Args:
288 onto_manifest: The current tip manifest to patch.
289 additions: Paths added or changed by the original commit.
290 deletions: Paths removed by the original commit.
291
292 Returns:
293 New manifest dict (copy of *onto_manifest* with delta applied).
294 """
295 result = dict(onto_manifest)
296 result.update(additions)
297 for path in deletions:
298 result.pop(path, None)
299 return result
300
301
302 def detect_rebase_conflicts(
303 onto_manifest: dict[str, str],
304 prev_onto_manifest: dict[str, str],
305 additions: dict[str, str],
306 deletions: set[str],
307 ) -> set[str]:
308 """Identify conflicts between the commit delta and changes on ``onto``.
309
310 A conflict occurs when a path was changed both in the commit being replayed
311 (relative to its parent) and in the onto branch (relative to the base).
312
313 Args:
314 onto_manifest: Current onto tip.
315 prev_onto_manifest: The onto state just before this replay step
316 (i.e. the merge base or the previous onto tip).
317 additions: Paths added/modified by the commit being replayed.
318 deletions: Paths deleted by the commit being replayed.
319
320 Returns:
321 Set of conflicting paths.
322 """
323 onto_changed = diff_snapshots(prev_onto_manifest, onto_manifest)
324 commit_changed = set(additions.keys()) | deletions
325 return detect_conflicts(onto_changed, commit_changed)
326
327
328 async def _collect_branch_commits_since_base(
329 session: AsyncSession,
330 head_commit_id: str,
331 base_commit_id: str,
332 ) -> list[MuseCliCommit]:
333 """Collect commits reachable from *head_commit_id* but not from *base_commit_id*.
334
335 Returns them in topological order, oldest first (replay order). Merge
336 commits (two parents) are included as single units; their second parent is
337 not traversed — i.e. only the primary-parent chain is followed.
338
339 Args:
340 session: Open async DB session.
341 head_commit_id: The current branch HEAD.
342 base_commit_id: The LCA — commits at or before this are excluded.
343
344 Returns:
345 List of :class:`MuseCliCommit` rows in replay order.
346 """
347 commits_reversed: list[MuseCliCommit] = []
348 seen: set[str] = set()
349 queue: deque[str] = deque([head_commit_id])
350
351 while queue:
352 cid = queue.popleft()
353 if cid in seen or cid == base_commit_id:
354 continue
355 seen.add(cid)
356 commit = await session.get(MuseCliCommit, cid)
357 if commit is None:
358 break
359 commits_reversed.append(commit)
360 if commit.parent_commit_id and commit.parent_commit_id != base_commit_id:
361 queue.append(commit.parent_commit_id)
362 elif commit.parent_commit_id == base_commit_id:
363 # Include this commit but stop traversal here
364 pass
365
366 # BFS gives newest-first; reverse to get oldest-first replay order.
367 return list(reversed(commits_reversed))
368
369
370 async def _find_merge_base_rebase(
371 session: AsyncSession,
372 commit_id_a: str,
373 commit_id_b: str,
374 ) -> str | None:
375 """Lowest common ancestor of two commits — thin wrapper used by the rebase.
376
377 Args:
378 session: Open async DB session.
379 commit_id_a: First commit ID (current branch HEAD).
380 commit_id_b: Second commit ID (upstream tip).
381
382 Returns:
383 LCA commit ID, or ``None`` if histories are disjoint.
384 """
385 from maestro.muse_cli.merge_engine import find_merge_base
386
387 return await find_merge_base(session, commit_id_a, commit_id_b)
388
389
390 def apply_autosquash(commits: list[MuseCliCommit]) -> tuple[list[MuseCliCommit], bool]:
391 """Reorder and flag fixup commits for autosquash.
392
393 Detects commits whose message starts with ``fixup! <msg>`` and moves them
394 immediately after the matching commit (matched by prefix of ``<msg>``).
395
396 Args:
397 commits: Commits in replay order (oldest first).
398
399 Returns:
400 Tuple of (reordered_commits, was_reordered).
401 """
402 # Build index of message → position for non-fixup commits
403 reordered: list[MuseCliCommit] = []
404 fixups: dict[str, list[MuseCliCommit]] = {}
405
406 for commit in commits:
407 if commit.message.startswith("fixup! "):
408 target_msg = commit.message[len("fixup! "):]
409 fixups.setdefault(target_msg, []).append(commit)
410 else:
411 reordered.append(commit)
412
413 if not fixups:
414 return commits, False
415
416 # Insert fixup commits immediately after their targets
417 result: list[MuseCliCommit] = []
418 for commit in reordered:
419 result.append(commit)
420 # Match by prefix of commit message
421 for target_msg, fixup_list in list(fixups.items()):
422 if commit.message.startswith(target_msg):
423 result.extend(fixup_list)
424 del fixups[target_msg]
425
426 # Any unmatched fixups go at the end
427 for fixup_list in fixups.values():
428 result.extend(fixup_list)
429
430 return result, True
431
432
433 # ---------------------------------------------------------------------------
434 # Interactive plan
435 # ---------------------------------------------------------------------------
436
437
438 class InteractivePlan:
439 """A parsed interactive rebase plan.
440
441 Plan lines have the format::
442
443 <action> <short-sha> <message>
444
445 Supported actions: ``pick``, ``squash``, ``drop``.
446
447 Attributes:
448 entries: List of (action, commit_id, message) tuples in plan order.
449 """
450
451 VALID_ACTIONS = frozenset({"pick", "squash", "drop", "fixup", "reword"})
452
453 def __init__(
454 self,
455 entries: list[tuple[str, str, str]],
456 ) -> None:
457 """Create a plan from parsed entries.
458
459 Args:
460 entries: List of (action, commit_id_prefix, message) tuples.
461 """
462 self.entries = entries
463
464 @classmethod
465 def from_text(cls, text: str) -> InteractivePlan:
466 """Parse a plan from the editor text.
467
468 Lines starting with ``#`` are comments and are ignored. Blank lines
469 are ignored. Each non-comment line must be ``<action> <sha> <msg>``.
470
471 Args:
472 text: Raw plan text as produced by :meth:`to_text`.
473
474 Returns:
475 Parsed :class:`InteractivePlan`.
476
477 Raises:
478 ValueError: If a line has an unrecognised action or missing fields.
479 """
480 entries: list[tuple[str, str, str]] = []
481 for raw_line in text.splitlines():
482 line = raw_line.strip()
483 if not line or line.startswith("#"):
484 continue
485 parts = line.split(None, 2)
486 if len(parts) < 2:
487 raise ValueError(f"Invalid plan line: {raw_line!r}")
488 action = parts[0].lower()
489 sha = parts[1]
490 msg = parts[2] if len(parts) > 2 else ""
491 if action not in cls.VALID_ACTIONS:
492 raise ValueError(f"Unknown action {action!r} in line: {raw_line!r}")
493 entries.append((action, sha, msg))
494 return cls(entries)
495
496 @classmethod
497 def from_commits(cls, commits: list[MuseCliCommit]) -> InteractivePlan:
498 """Build a default plan (all ``pick``) from a list of commits.
499
500 Args:
501 commits: Commits in replay order.
502
503 Returns:
504 :class:`InteractivePlan` with one ``pick`` entry per commit.
505 """
506 entries: list[tuple[str, str, str]] = []
507 for commit in commits:
508 entries.append(("pick", commit.commit_id[:8], commit.message))
509 return cls(entries)
510
511 def to_text(self) -> str:
512 """Render the plan to a human-editable text format."""
513 lines = [
514 "# Interactive rebase plan.",
515 "# Actions: pick, squash (fold into previous), drop (skip), fixup (squash no msg), reword",
516 "# Lines starting with '#' are ignored.",
517 "",
518 ]
519 for action, sha, msg in self.entries:
520 lines.append(f"{action} {sha} {msg}")
521 return "\n".join(lines) + "\n"
522
523 def resolve_against(
524 self, commits: list[MuseCliCommit]
525 ) -> list[tuple[str, MuseCliCommit]]:
526 """Match plan entries to the full commit list by SHA prefix.
527
528 Args:
529 commits: Original commits list (source of truth for full SHA).
530
531 Returns:
532 List of (action, commit) pairs in plan order, excluding dropped
533 commits.
534
535 Raises:
536 ValueError: If a plan SHA prefix matches no commit or is ambiguous.
537 """
538 resolved: list[tuple[str, MuseCliCommit]] = []
539 for action, sha_prefix, _msg in self.entries:
540 if action == "drop":
541 continue
542 matches = [c for c in commits if c.commit_id.startswith(sha_prefix)]
543 if not matches:
544 raise ValueError(
545 f"Plan SHA {sha_prefix!r} does not match any commit in the rebase range."
546 )
547 if len(matches) > 1:
548 raise ValueError(
549 f"Plan SHA {sha_prefix!r} is ambiguous — matches {len(matches)} commits."
550 )
551 resolved.append((action, matches[0]))
552 return resolved
553
554
555 # ---------------------------------------------------------------------------
556 # Async rebase core — single-step replay
557 # ---------------------------------------------------------------------------
558
559
560 async def _replay_one_commit(
561 *,
562 session: AsyncSession,
563 commit: MuseCliCommit,
564 onto_manifest: dict[str, str],
565 prev_onto_manifest: dict[str, str],
566 onto_commit_id: str,
567 branch: str,
568 ) -> tuple[str, dict[str, str], list[str]]:
569 """Replay a single commit onto the current onto tip.
570
571 Computes the delta the original commit introduced over its parent, applies
572 it to *onto_manifest*, detects conflicts, persists the new snapshot, and
573 inserts a new commit record.
574
575 Args:
576 session: Open async DB session.
577 commit: The original commit to replay.
578 onto_manifest: Snapshot manifest of the current onto tip.
579 prev_onto_manifest: Manifest of the onto base (for conflict detection).
580 onto_commit_id: Commit ID of the current onto tip.
581 branch: Branch being rebased (for the new commit record).
582
583 Returns:
584 Tuple of (new_commit_id, new_onto_manifest, conflict_paths):
585 - ``new_commit_id``: SHA of the newly inserted commit.
586 - ``new_onto_manifest``: Updated manifest for the next step.
587 - ``conflict_paths``: Empty list on success; non-empty on conflict.
588 """
589 # Resolve parent manifest for the original commit
590 parent_manifest: dict[str, str] = {}
591 if commit.parent_commit_id:
592 loaded = await get_commit_snapshot_manifest(session, commit.parent_commit_id)
593 if loaded is not None:
594 parent_manifest = loaded
595
596 commit_manifest = await get_commit_snapshot_manifest(session, commit.commit_id)
597 if commit_manifest is None:
598 commit_manifest = {}
599
600 additions, deletions = compute_delta(parent_manifest, commit_manifest)
601
602 conflict_paths = detect_rebase_conflicts(
603 onto_manifest=onto_manifest,
604 prev_onto_manifest=prev_onto_manifest,
605 additions=additions,
606 deletions=deletions,
607 )
608 if conflict_paths:
609 return "", onto_manifest, sorted(conflict_paths)
610
611 new_manifest = apply_delta(onto_manifest, additions, deletions)
612 new_snapshot_id = compute_snapshot_id(new_manifest)
613 await upsert_snapshot(session, manifest=new_manifest, snapshot_id=new_snapshot_id)
614 await session.flush()
615
616 committed_at = datetime.datetime.now(datetime.timezone.utc)
617 new_commit_id = compute_commit_id(
618 parent_ids=[onto_commit_id],
619 snapshot_id=new_snapshot_id,
620 message=commit.message,
621 committed_at_iso=committed_at.isoformat(),
622 )
623
624 new_commit = MuseCliCommit(
625 commit_id=new_commit_id,
626 repo_id=commit.repo_id,
627 branch=branch,
628 parent_commit_id=onto_commit_id,
629 snapshot_id=new_snapshot_id,
630 message=commit.message,
631 author=commit.author,
632 committed_at=committed_at,
633 commit_metadata=commit.commit_metadata,
634 )
635 await insert_commit(session, new_commit)
636
637 return new_commit_id, new_manifest, []
638
639
640 # ---------------------------------------------------------------------------
641 # Async rebase core — full pipeline
642 # ---------------------------------------------------------------------------
643
644
645 async def _rebase_async(
646 *,
647 upstream: str,
648 root: pathlib.Path,
649 session: AsyncSession,
650 interactive: bool = False,
651 autosquash: bool = False,
652 rebase_merges: bool = False,
653 ) -> RebaseResult:
654 """Run the rebase pipeline.
655
656 All filesystem and DB side-effects are isolated here so tests can inject
657 an in-memory SQLite session and a ``tmp_path`` root without touching a
658 real database.
659
660 Args:
661 upstream: Branch name or commit ID to rebase onto.
662 root: Repository root (directory containing ``.muse/``).
663 session: Open async DB session.
664 interactive: When ``True``, open $EDITOR with the rebase plan before
665 executing. The edited plan controls action, order, and squash
666 behaviour.
667 autosquash: When ``True``, automatically detect ``fixup!`` commits and
668 move them after their matching target commit.
669 rebase_merges: When ``True``, preserve merge commits during replay
670 (stub — see implementation note below).
671
672 Returns:
673 :class:`RebaseResult` describing what happened.
674
675 Raises:
676 ``typer.Exit`` with an appropriate exit code on user-facing errors.
677 """
678 import json as _json
679
680 muse_dir = root / ".muse"
681
682 # ── Guard: rebase already in progress ───────────────────────────────
683 if read_rebase_state(root) is not None:
684 typer.echo(
685 "❌ Rebase in progress. Use --continue to resume or --abort to cancel."
686 )
687 raise typer.Exit(code=ExitCode.USER_ERROR)
688
689 # ── Repo identity ────────────────────────────────────────────────────
690 repo_data: dict[str, str] = _json.loads((muse_dir / "repo.json").read_text())
691 repo_id = repo_data["repo_id"]
692
693 # ── Current branch ───────────────────────────────────────────────────
694 head_ref = (muse_dir / "HEAD").read_text().strip()
695 current_branch = head_ref.rsplit("/", 1)[-1]
696 our_ref_path = muse_dir / pathlib.Path(head_ref)
697
698 ours_commit_id = our_ref_path.read_text().strip() if our_ref_path.exists() else ""
699 if not ours_commit_id:
700 typer.echo("❌ Current branch has no commits. Cannot rebase.")
701 raise typer.Exit(code=ExitCode.USER_ERROR)
702
703 # ── Resolve upstream ─────────────────────────────────────────────────
704 # Try as a branch name first, then as a raw commit ID
705 upstream_ref_path = muse_dir / "refs" / "heads" / upstream
706 if upstream_ref_path.exists():
707 upstream_commit_id = upstream_ref_path.read_text().strip()
708 else:
709 # Might be a raw commit ID
710 candidate = await session.get(MuseCliCommit, upstream)
711 if candidate is None:
712 typer.echo(
713 f"❌ Upstream {upstream!r} is not a known branch or commit ID."
714 )
715 raise typer.Exit(code=ExitCode.USER_ERROR)
716 upstream_commit_id = candidate.commit_id
717
718 # ── Already up-to-date guard ─────────────────────────────────────────
719 if ours_commit_id == upstream_commit_id:
720 typer.echo("Already up-to-date.")
721 return RebaseResult(
722 branch=current_branch,
723 upstream=upstream,
724 upstream_commit_id=upstream_commit_id,
725 base_commit_id=upstream_commit_id,
726 replayed=(),
727 conflict_paths=(),
728 aborted=False,
729 noop=True,
730 autosquash_applied=False,
731 )
732
733 # ── Find merge base ──────────────────────────────────────────────────
734 base_commit_id = await _find_merge_base_rebase(
735 session, ours_commit_id, upstream_commit_id
736 )
737 if base_commit_id is None:
738 typer.echo(
739 f"❌ Cannot find a common ancestor between current branch and {upstream!r}. "
740 "Histories are disjoint — use 'muse merge' instead."
741 )
742 raise typer.Exit(code=ExitCode.USER_ERROR)
743
744 # ── Fast-forward: current branch IS the base → nothing to replay ─────
745 if base_commit_id == ours_commit_id:
746 # Current branch is behind upstream — just advance the pointer
747 our_ref_path.write_text(upstream_commit_id)
748 typer.echo(
749 f"✅ Fast-forward: {current_branch} → {upstream_commit_id[:8]}"
750 )
751 return RebaseResult(
752 branch=current_branch,
753 upstream=upstream,
754 upstream_commit_id=upstream_commit_id,
755 base_commit_id=base_commit_id,
756 replayed=(),
757 conflict_paths=(),
758 aborted=False,
759 noop=True,
760 autosquash_applied=False,
761 )
762
763 # ── Already up-to-date: upstream IS the base → we are ahead ──────────
764 if base_commit_id == upstream_commit_id:
765 typer.echo("Already up-to-date.")
766 return RebaseResult(
767 branch=current_branch,
768 upstream=upstream,
769 upstream_commit_id=upstream_commit_id,
770 base_commit_id=base_commit_id,
771 replayed=(),
772 conflict_paths=(),
773 aborted=False,
774 noop=True,
775 autosquash_applied=False,
776 )
777
778 # ── Collect commits to replay ─────────────────────────────────────────
779 commits_to_replay = await _collect_branch_commits_since_base(
780 session, ours_commit_id, base_commit_id
781 )
782
783 if not commits_to_replay:
784 typer.echo("Nothing to rebase.")
785 return RebaseResult(
786 branch=current_branch,
787 upstream=upstream,
788 upstream_commit_id=upstream_commit_id,
789 base_commit_id=base_commit_id,
790 replayed=(),
791 conflict_paths=(),
792 aborted=False,
793 noop=True,
794 autosquash_applied=False,
795 )
796
797 autosquash_applied = False
798
799 # ── Autosquash ────────────────────────────────────────────────────────
800 if autosquash:
801 commits_to_replay, autosquash_applied = apply_autosquash(commits_to_replay)
802
803 # ── Interactive plan ──────────────────────────────────────────────────
804 plan_actions: list[tuple[str, MuseCliCommit]] = [
805 ("pick", c) for c in commits_to_replay
806 ]
807
808 if interactive:
809 import os
810 import subprocess
811 import tempfile
812
813 plan = InteractivePlan.from_commits(commits_to_replay)
814 with tempfile.NamedTemporaryFile(
815 mode="w", suffix=".rebase-plan", delete=False
816 ) as tf:
817 tf.write(plan.to_text())
818 tf_path = tf.name
819
820 editor = os.environ.get("EDITOR", os.environ.get("VISUAL", "vi"))
821 result = subprocess.run([editor, tf_path])
822 if result.returncode != 0:
823 typer.echo("⚠️ Editor exited with non-zero code — rebase aborted.")
824 raise typer.Exit(code=ExitCode.USER_ERROR)
825
826 edited_text = pathlib.Path(tf_path).read_text()
827 pathlib.Path(tf_path).unlink(missing_ok=True)
828
829 try:
830 edited_plan = InteractivePlan.from_text(edited_text)
831 plan_actions = edited_plan.resolve_against(commits_to_replay)
832 except ValueError as exc:
833 typer.echo(f"❌ Invalid rebase plan: {exc}")
834 raise typer.Exit(code=ExitCode.USER_ERROR)
835
836 if not plan_actions:
837 typer.echo("Nothing to do.")
838 raise typer.Exit(code=ExitCode.SUCCESS)
839
840 # ── Resolve base and upstream manifests ───────────────────────────────
841 base_manifest = await get_commit_snapshot_manifest(session, base_commit_id) or {}
842 onto_manifest = (
843 await get_commit_snapshot_manifest(session, upstream_commit_id) or {}
844 )
845 onto_commit_id = upstream_commit_id
846 prev_onto_manifest = base_manifest
847
848 completed_pairs: list[RebaseCommitPair] = []
849 pending_squash_manifest: dict[str, str] | None = None
850 pending_squash_commits: list[MuseCliCommit] = []
851
852 for action, commit in plan_actions:
853 if action == "drop":
854 continue
855
856 if action in ("squash", "fixup"):
857 # Accumulate into squash group
858 pending_squash_commits.append(commit)
859 if pending_squash_manifest is None:
860 # First in group — compute its delta
861 parent_manifest: dict[str, str] = {}
862 if commit.parent_commit_id:
863 loaded = await get_commit_snapshot_manifest(
864 session, commit.parent_commit_id
865 )
866 if loaded is not None:
867 parent_manifest = loaded
868 commit_manifest = (
869 await get_commit_snapshot_manifest(session, commit.commit_id) or {}
870 )
871 additions, deletions = compute_delta(parent_manifest, commit_manifest)
872 pending_squash_manifest = apply_delta(
873 onto_manifest, additions, deletions
874 )
875 else:
876 # Add this commit's changes on top of the squash in progress
877 parent_manifest_sq: dict[str, str] = {}
878 if commit.parent_commit_id:
879 loaded_sq = await get_commit_snapshot_manifest(
880 session, commit.parent_commit_id
881 )
882 if loaded_sq is not None:
883 parent_manifest_sq = loaded_sq
884 commit_manifest_sq = (
885 await get_commit_snapshot_manifest(session, commit.commit_id) or {}
886 )
887 additions_sq, deletions_sq = compute_delta(
888 parent_manifest_sq, commit_manifest_sq
889 )
890 pending_squash_manifest = apply_delta(
891 pending_squash_manifest, additions_sq, deletions_sq
892 )
893 continue
894
895 # ── Flush pending squash group (if any) ───────────────────────
896 if pending_squash_commits and pending_squash_manifest is not None:
897 squash_message = pending_squash_commits[0].message
898 squash_snap_id = compute_snapshot_id(pending_squash_manifest)
899 await upsert_snapshot(
900 session, manifest=pending_squash_manifest, snapshot_id=squash_snap_id
901 )
902 await session.flush()
903
904 squash_at = datetime.datetime.now(datetime.timezone.utc)
905 squash_commit_id = compute_commit_id(
906 parent_ids=[onto_commit_id],
907 snapshot_id=squash_snap_id,
908 message=squash_message,
909 committed_at_iso=squash_at.isoformat(),
910 )
911 squash_commit = MuseCliCommit(
912 commit_id=squash_commit_id,
913 repo_id=pending_squash_commits[0].repo_id,
914 branch=current_branch,
915 parent_commit_id=onto_commit_id,
916 snapshot_id=squash_snap_id,
917 message=squash_message,
918 author=pending_squash_commits[0].author,
919 committed_at=squash_at,
920 )
921 await insert_commit(session, squash_commit)
922
923 for orig in pending_squash_commits:
924 completed_pairs.append(
925 RebaseCommitPair(
926 original_commit_id=orig.commit_id,
927 new_commit_id=squash_commit_id,
928 )
929 )
930 onto_manifest = pending_squash_manifest
931 onto_commit_id = squash_commit_id
932 pending_squash_commits = []
933 pending_squash_manifest = None
934
935 # ── Normal pick ────────────────────────────────────────────────
936 new_commit_id, new_manifest, conflict_paths_list = await _replay_one_commit(
937 session=session,
938 commit=commit,
939 onto_manifest=onto_manifest,
940 prev_onto_manifest=prev_onto_manifest,
941 onto_commit_id=onto_commit_id,
942 branch=current_branch,
943 )
944
945 if conflict_paths_list:
946 # Persist state and exit with conflict
947 remaining_ids = [
948 c.commit_id
949 for _, c in plan_actions[
950 plan_actions.index((action, commit)) + 1 :
951 ]
952 ]
953 state = RebaseState(
954 upstream_commit=upstream_commit_id,
955 base_commit=base_commit_id,
956 original_branch=current_branch,
957 original_head=ours_commit_id,
958 commits_to_replay=remaining_ids,
959 current_onto=onto_commit_id,
960 completed_pairs=[
961 [p.original_commit_id, p.new_commit_id] for p in completed_pairs
962 ],
963 current_commit=commit.commit_id,
964 conflict_paths=conflict_paths_list,
965 )
966 write_rebase_state(root, state)
967
968 typer.echo(
969 f"❌ Conflict while replaying {commit.commit_id[:8]} ({commit.message!r}):\n"
970 + "\n".join(f"\tboth modified: {p}" for p in conflict_paths_list)
971 + "\nResolve conflicts, then run 'muse rebase --continue'."
972 )
973 raise typer.Exit(code=ExitCode.USER_ERROR)
974
975 completed_pairs.append(
976 RebaseCommitPair(
977 original_commit_id=commit.commit_id,
978 new_commit_id=new_commit_id,
979 )
980 )
981 prev_onto_manifest = onto_manifest
982 onto_manifest = new_manifest
983 onto_commit_id = new_commit_id
984
985 # ── Flush any trailing squash group ──────────────────────────────────
986 if pending_squash_commits and pending_squash_manifest is not None:
987 squash_message = pending_squash_commits[0].message
988 squash_snap_id = compute_snapshot_id(pending_squash_manifest)
989 await upsert_snapshot(
990 session, manifest=pending_squash_manifest, snapshot_id=squash_snap_id
991 )
992 await session.flush()
993
994 squash_at = datetime.datetime.now(datetime.timezone.utc)
995 squash_commit_id = compute_commit_id(
996 parent_ids=[onto_commit_id],
997 snapshot_id=squash_snap_id,
998 message=squash_message,
999 committed_at_iso=squash_at.isoformat(),
1000 )
1001 squash_commit = MuseCliCommit(
1002 commit_id=squash_commit_id,
1003 repo_id=pending_squash_commits[0].repo_id,
1004 branch=current_branch,
1005 parent_commit_id=onto_commit_id,
1006 snapshot_id=squash_snap_id,
1007 message=squash_message,
1008 author=pending_squash_commits[0].author,
1009 committed_at=squash_at,
1010 )
1011 await insert_commit(session, squash_commit)
1012
1013 for orig in pending_squash_commits:
1014 completed_pairs.append(
1015 RebaseCommitPair(
1016 original_commit_id=orig.commit_id,
1017 new_commit_id=squash_commit_id,
1018 )
1019 )
1020 onto_commit_id = squash_commit_id
1021
1022 # ── Advance branch pointer ────────────────────────────────────────────
1023 our_ref_path.write_text(onto_commit_id)
1024
1025 typer.echo(
1026 f"✅ Rebased {len(completed_pairs)} commit(s) onto {upstream!r} "
1027 f"[{current_branch} {onto_commit_id[:8]}]"
1028 )
1029 logger.info(
1030 "✅ muse rebase: %d commit(s) replayed onto %r (%s), branch %r now at %s",
1031 len(completed_pairs),
1032 upstream,
1033 upstream_commit_id[:8],
1034 current_branch,
1035 onto_commit_id[:8],
1036 )
1037
1038 return RebaseResult(
1039 branch=current_branch,
1040 upstream=upstream,
1041 upstream_commit_id=upstream_commit_id,
1042 base_commit_id=base_commit_id,
1043 replayed=tuple(completed_pairs),
1044 conflict_paths=(),
1045 aborted=False,
1046 noop=False,
1047 autosquash_applied=autosquash_applied,
1048 )
1049
1050
async def _rebase_continue_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
) -> RebaseResult:
    """Resume a rebase that was paused due to a conflict.

    Reads ``REBASE_STATE.json`` and refuses to proceed while its
    ``conflict_paths`` list is non-empty.

    Otherwise it replays, oldest first, the commit that was being replayed
    when the conflict was detected (``current_commit``), followed by every
    commit in ``commits_to_replay``. The paused commit is skipped if it
    already appears in ``completed_pairs``. Each commit's delta (relative to
    its own parent) is applied on top of the advancing ``onto`` manifest,
    which starts at ``current_onto``. Finally, the branch ref is moved to the
    last new commit and the state file is removed.

    NOTE(review): the paused commit is replayed by applying its own delta, so
    for previously conflicting paths the replayed commit's version wins. If
    manual resolutions are meant to be picked up from elsewhere (e.g. the
    working tree), they would need to be snapshotted here instead — confirm.

    NOTE(review): ``commits_to_replay`` stores bare commit IDs, not plan
    actions, so squash/fixup commits after the pause point are replayed as
    ordinary picks.

    Args:
        root: Repository root.
        session: Open async DB session.

    Returns:
        :class:`RebaseResult` describing the completed rebase. ``replayed``
        includes pairs completed before the pause as well as those created
        here.

    Raises:
        ``typer.Exit``: If no rebase is in progress or conflicts remain.
    """
    import json as _json

    rebase_state = read_rebase_state(root)
    if rebase_state is None:
        typer.echo("❌ No rebase in progress. Nothing to continue.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    if rebase_state.conflict_paths:
        typer.echo(
            f"❌ {len(rebase_state.conflict_paths)} conflict(s) not yet resolved:\n"
            + "\n".join(
                f"\tboth modified: {p}" for p in rebase_state.conflict_paths
            )
            + "\nResolve conflicts manually, then run 'muse rebase --continue'."
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    muse_dir = root / ".muse"
    repo_data: dict[str, str] = _json.loads((muse_dir / "repo.json").read_text())
    repo_id = repo_data["repo_id"]

    current_branch = rebase_state.original_branch
    head_ref = f"refs/heads/{current_branch}"
    our_ref_path = muse_dir / pathlib.Path(head_ref)

    completed_pairs: list[RebaseCommitPair] = [
        RebaseCommitPair(original_commit_id=p[0], new_commit_id=p[1])
        for p in rebase_state.completed_pairs
    ]

    # When the main rebase loop pauses, it records only the commits *after*
    # the conflicted one in ``commits_to_replay``; the conflicted commit
    # itself is kept in ``current_commit``. Replay it first so its changes
    # are not silently dropped from the rebased history.
    already_replayed = {pair.original_commit_id for pair in completed_pairs}
    replay_queue: list[str] = []
    paused_commit_id = rebase_state.current_commit
    if paused_commit_id and paused_commit_id not in already_replayed:
        replay_queue.append(paused_commit_id)
    replay_queue.extend(rebase_state.commits_to_replay)

    onto_commit_id = rebase_state.current_onto
    onto_manifest = (
        await get_commit_snapshot_manifest(session, onto_commit_id) or {}
    )

    # Replay the remaining commits.
    for orig_cid in replay_queue:
        commit = await session.get(MuseCliCommit, orig_cid)
        if commit is None:
            # Best-effort: a missing commit is reported and left out.
            typer.echo(f"⚠️ Commit {orig_cid[:8]} not found — skipping.")
            continue

        # Delta of this commit relative to its own parent (root commits diff
        # against an empty manifest).
        parent_manifest: dict[str, str] = {}
        if commit.parent_commit_id:
            loaded = await get_commit_snapshot_manifest(
                session, commit.parent_commit_id
            )
            if loaded is not None:
                parent_manifest = loaded

        commit_manifest = (
            await get_commit_snapshot_manifest(session, commit.commit_id) or {}
        )
        additions, deletions = compute_delta(parent_manifest, commit_manifest)
        new_manifest = apply_delta(onto_manifest, additions, deletions)
        new_snapshot_id = compute_snapshot_id(new_manifest)
        await upsert_snapshot(
            session, manifest=new_manifest, snapshot_id=new_snapshot_id
        )
        await session.flush()

        # The parent changed, so the replayed commit gets a fresh ID.
        committed_at = datetime.datetime.now(datetime.timezone.utc)
        new_commit_id = compute_commit_id(
            parent_ids=[onto_commit_id],
            snapshot_id=new_snapshot_id,
            message=commit.message,
            committed_at_iso=committed_at.isoformat(),
        )
        new_commit = MuseCliCommit(
            commit_id=new_commit_id,
            repo_id=repo_id,
            branch=current_branch,
            parent_commit_id=onto_commit_id,
            snapshot_id=new_snapshot_id,
            message=commit.message,
            author=commit.author,
            committed_at=committed_at,
            commit_metadata=commit.commit_metadata,
        )
        await insert_commit(session, new_commit)

        completed_pairs.append(
            RebaseCommitPair(
                original_commit_id=orig_cid,
                new_commit_id=new_commit_id,
            )
        )
        onto_manifest = new_manifest
        onto_commit_id = new_commit_id

    # Advance the branch pointer and clear state. The ref directory is created
    # if needed, matching --abort.
    our_ref_path.parent.mkdir(parents=True, exist_ok=True)
    our_ref_path.write_text(onto_commit_id)
    clear_rebase_state(root)

    upstream_commit_id = rebase_state.upstream_commit
    base_commit_id = rebase_state.base_commit

    typer.echo(
        f"✅ Rebase continued: {len(completed_pairs)} commit(s) applied "
        f"[{current_branch} {onto_commit_id[:8]}]"
    )
    logger.info(
        "✅ muse rebase --continue: %d commit(s) on %r, now at %s",
        len(completed_pairs),
        current_branch,
        onto_commit_id[:8],
    )

    return RebaseResult(
        branch=current_branch,
        upstream=rebase_state.upstream_commit,
        upstream_commit_id=upstream_commit_id,
        base_commit_id=base_commit_id,
        replayed=tuple(completed_pairs),
        conflict_paths=(),
        aborted=False,
        noop=False,
        autosquash_applied=False,
    )
1191
1192
async def _rebase_abort_async(
    *,
    root: pathlib.Path,
) -> RebaseResult:
    """Cancel the in-progress rebase and put the branch back where it was.

    Loads ``REBASE_STATE.json`` and points ``refs/heads/<original_branch>``
    back at ``original_head``, creating the ref's parent directory if
    necessary. It then deletes the state file.

    This coroutine takes no DB session. Any commits written during the partial
    replay are therefore left as they are; only the branch pointer moves.

    Args:
        root: Repository root.

    Returns:
        :class:`RebaseResult` with ``aborted=True`` and nothing in
        ``replayed``.

    Raises:
        ``typer.Exit``: If no rebase is in progress.
    """
    state = read_rebase_state(root)
    if state is None:
        typer.echo("❌ No rebase in progress. Nothing to abort.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    branch = state.original_branch
    restored_head = state.original_head
    short_head = restored_head[:8]

    # Same path construction as the rest of the module: .muse / "refs/heads/<branch>".
    ref_file = root / ".muse" / pathlib.Path(f"refs/heads/{branch}")
    ref_file.parent.mkdir(parents=True, exist_ok=True)
    ref_file.write_text(restored_head)

    clear_rebase_state(root)

    typer.echo(f"✅ Rebase aborted. Branch {branch!r} restored to {short_head}.")
    logger.info("✅ muse rebase --abort: %r restored to %s", branch, short_head)

    return RebaseResult(
        branch=branch,
        upstream="",
        upstream_commit_id=state.upstream_commit,
        base_commit_id=state.base_commit,
        replayed=(),
        conflict_paths=(),
        aborted=True,
        noop=False,
        autosquash_applied=False,
    )