muse_rebase.py
python
| 1 | """Muse Rebase Service — replay commits onto a new base. |
| 2 | |
| 3 | Algorithm |
| 4 | --------- |
| 5 | 1. Find the merge-base (LCA) of the current branch HEAD and ``<upstream>``. |
| 6 | 2. Collect the commits on the current branch that are *not* ancestors of |
| 7 | ``<upstream>`` (i.e., commits added since the LCA), ordered oldest first. |
| 8 | 3. For each such commit, compute the snapshot delta relative to its own parent, |
| 9 | then apply that delta on top of the current ``onto`` tip (which starts as |
| 10 | ``<upstream>`` HEAD and advances after each successful replay). |
| 11 | 4. Insert a new commit record (new commit_id because the parent has changed). |
| 12 | 5. Advance the branch pointer to the final replayed commit. |
| 13 | |
| 14 | Because Muse snapshots are content-addressed manifests (``{path: object_id}``), |
| 15 | ``apply_delta`` is a pure dict operation — no byte-level merge required. |
| 16 | |
| 17 | State file: ``.muse/REBASE_STATE.json`` |
| 18 | ----------------------------------------- |
| 19 | Written when a conflict is detected mid-replay so ``--continue`` and ``--abort`` |
| 20 | can resume or abandon the operation: |
| 21 | |
| 22 | .. code-block:: json |
| 23 | |
| 24 | { |
| 25 | "upstream_commit": "abc123...", |
| 26 | "base_commit": "def456...", |
| 27 | "original_branch": "feature", |
| 28 | "original_head": "ghi789...", |
| 29 | "commits_to_replay": ["cid1", "cid2", "cid3"], |
| 30 | "current_onto": "abc123...", |
| 31 | "completed_pairs": [["cid1", "new_cid1"]], |
| 32 | "current_commit": "cid2", |
| 33 | "conflict_paths": ["beat.mid"] |
| 34 | } |
| 35 | |
| 36 | Boundary rules: |
| 37 | - Must NOT import StateStore, EntityRegistry, or get_or_create_store. |
| 38 | - Must NOT import executor modules or maestro_* handlers. |
| 39 | - May import muse_cli.db, muse_cli.models, muse_cli.merge_engine, |
| 40 | muse_cli.snapshot. |
| 41 | |
| 42 | Domain analogy: a producer has 10 "fixup" commits from a late-night session. |
| 43 | ``muse rebase dev`` replays them cleanly onto the ``dev`` tip, producing a |
| 44 | linear history — the musical narrative stays readable as a sequence of |
| 45 | intentional variations. |
| 46 | """ |
| 47 | from __future__ import annotations |
| 48 | |
| 49 | import datetime |
| 50 | import json |
| 51 | import logging |
| 52 | import pathlib |
| 53 | from collections import deque |
| 54 | from dataclasses import dataclass, field |
| 55 | from typing import Optional |
| 56 | |
| 57 | import typer |
| 58 | from sqlalchemy.ext.asyncio import AsyncSession |
| 59 | |
| 60 | from maestro.muse_cli.db import ( |
| 61 | get_commit_snapshot_manifest, |
| 62 | insert_commit, |
| 63 | upsert_snapshot, |
| 64 | ) |
| 65 | from maestro.muse_cli.errors import ExitCode |
| 66 | from maestro.muse_cli.merge_engine import ( |
| 67 | detect_conflicts, |
| 68 | diff_snapshots, |
| 69 | ) |
| 70 | from maestro.muse_cli.models import MuseCliCommit |
| 71 | from maestro.muse_cli.snapshot import compute_commit_id, compute_snapshot_id |
| 72 | |
logger = logging.getLogger(name=__name__)

# Name of the session file (inside ``.muse/``) that records an interrupted rebase.
_REBASE_STATE_FILENAME: str = "REBASE_STATE.json"
| 76 | |
| 77 | |
| 78 | # --------------------------------------------------------------------------- |
| 79 | # Result types |
| 80 | # --------------------------------------------------------------------------- |
| 81 | |
| 82 | |
@dataclass(frozen=True)
class RebaseCommitPair:
    """Link between a pre-rebase commit and the commit that replaced it.

    The dataclass is frozen, so instances are immutable and hashable and can
    be collected in tuples.

    Attributes:
        original_commit_id: ID of the commit as it existed before the rebase.
        new_commit_id: ID of the replayed commit created on top of the new base.
    """

    # Commit that was replayed (pre-rebase identity).
    original_commit_id: str
    # Replacement commit produced by the replay (new parent → new ID).
    new_commit_id: str
| 93 | new_commit_id: str |
| 94 | |
| 95 | |
@dataclass(frozen=True)
class RebaseResult:
    """Immutable summary of what a ``muse rebase`` invocation did.

    Attributes:
        branch: Name of the branch that was rebased.
        upstream: The upstream ref exactly as the user supplied it.
        upstream_commit_id: Commit ID that *upstream* resolved to.
        base_commit_id: Merge base (LCA) where the two histories diverged.
        replayed: (original, new) commit pairs, in replay order.
        conflict_paths: Paths left in conflict; empty when the rebase succeeded.
        aborted: ``True`` when ``--abort`` cancelled an in-progress rebase.
        noop: ``True`` when there was nothing to replay.
        autosquash_applied: ``True`` when ``--autosquash`` reordered commits.
    """

    # Refs / commit identities involved in the rebase.
    branch: str
    upstream: str
    upstream_commit_id: str
    base_commit_id: str
    # Outcome details.
    replayed: tuple[RebaseCommitPair, ...]
    conflict_paths: tuple[str, ...]
    # Outcome flags.
    aborted: bool
    noop: bool
    autosquash_applied: bool
| 121 | |
| 122 | |
| 123 | # --------------------------------------------------------------------------- |
| 124 | # RebaseState — on-disk session record |
| 125 | # --------------------------------------------------------------------------- |
| 126 | |
| 127 | |
@dataclass
class RebaseState:
    """Mutable record of a rebase session, mirrored to ``REBASE_STATE.json``.

    Attributes:
        upstream_commit: Tip of the upstream branch used as the new base.
        base_commit: Merge base (LCA) of the rebased branch and upstream.
        original_branch: Branch being rebased.
        original_head: Branch tip before the rebase began (used by ``--abort``).
        commits_to_replay: Original commits to replay, oldest first.
        current_onto: Commit ID of the current "onto" tip.
        completed_pairs: ``[original, new]`` ID pairs already replayed.
        current_commit: Commit being applied when a conflict stopped the rebase.
        conflict_paths: Paths with unresolved conflicts; empty when there are none.
    """

    # Required identity fields — always known when a session starts.
    upstream_commit: str
    base_commit: str
    original_branch: str
    original_head: str
    # Progress fields — each instance gets its own fresh list.
    commits_to_replay: list[str] = field(default_factory=list)
    current_onto: str = ""
    completed_pairs: list[list[str]] = field(default_factory=list)
    current_commit: str = ""
    conflict_paths: list[str] = field(default_factory=list)
| 153 | |
| 154 | |
| 155 | # --------------------------------------------------------------------------- |
| 156 | # Filesystem helpers |
| 157 | # --------------------------------------------------------------------------- |
| 158 | |
| 159 | |
def read_rebase_state(root: pathlib.Path) -> RebaseState | None:
    """Return :class:`RebaseState` if a rebase is in progress, else ``None``.

    Reads ``.muse/REBASE_STATE.json``. Returns ``None`` when the file does
    not exist (silently), or when it cannot be read, is not valid UTF-8 /
    JSON, or does not hold a JSON object at the top level (each logged as a
    warning). Individual fields that are missing or have the wrong type fall
    back to empty defaults instead of failing the whole read.

    Args:
        root: Repository root (directory containing ``.muse/``).

    Returns:
        The parsed :class:`RebaseState`, or ``None``.
    """
    state_path = root / ".muse" / _REBASE_STATE_FILENAME
    try:
        # EAFP: no separate exists() probe that could race with a concurrent clear.
        loaded = json.loads(state_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("⚠️ Failed to read %s: %s", _REBASE_STATE_FILENAME, exc)
        return None

    if not isinstance(loaded, dict):
        # e.g. a bare list or string: ``.get`` below would raise AttributeError.
        logger.warning(
            "⚠️ Failed to read %s: expected a JSON object, got %s",
            _REBASE_STATE_FILENAME,
            type(loaded).__name__,
        )
        return None
    raw: dict[str, object] = loaded

    def _str(key: str, default: str = "") -> str:
        v = raw.get(key, default)
        return str(v) if v is not None else default

    def _strlist(key: str) -> list[str]:
        v = raw.get(key, [])
        return [str(x) for x in v] if isinstance(v, list) else []

    def _pairlist(key: str) -> list[list[str]]:
        v = raw.get(key, [])
        if not isinstance(v, list):
            return []
        # Keep only well-formed two-element entries; skip anything else.
        return [
            [str(item[0]), str(item[1])]
            for item in v
            if isinstance(item, list) and len(item) == 2
        ]

    return RebaseState(
        upstream_commit=_str("upstream_commit"),
        base_commit=_str("base_commit"),
        original_branch=_str("original_branch"),
        original_head=_str("original_head"),
        commits_to_replay=_strlist("commits_to_replay"),
        current_onto=_str("current_onto"),
        completed_pairs=_pairlist("completed_pairs"),
        current_commit=_str("current_commit"),
        conflict_paths=_strlist("conflict_paths"),
    )
| 207 | |
| 208 | |
def write_rebase_state(root: pathlib.Path, state: RebaseState) -> None:
    """Persist *state* to ``.muse/REBASE_STATE.json`` atomically.

    The JSON is first written to a sibling ``REBASE_STATE.json.tmp`` file and
    then renamed over the real path with :meth:`pathlib.Path.replace`. A crash
    part-way through therefore leaves either the previous state file or the
    new one, never a truncated file. A truncated file would matter because
    ``read_rebase_state`` treats unparsable JSON as "no rebase in progress",
    losing ``original_head`` for ``--abort``.

    Args:
        root: Repository root.
        state: Current rebase session state.

    Raises:
        OSError: If the state file cannot be written; the temp file is removed.
    """
    state_path = root / ".muse" / _REBASE_STATE_FILENAME
    tmp_path = state_path.with_name(state_path.name + ".tmp")
    data: dict[str, object] = {
        "upstream_commit": state.upstream_commit,
        "base_commit": state.base_commit,
        "original_branch": state.original_branch,
        "original_head": state.original_head,
        "commits_to_replay": state.commits_to_replay,
        "current_onto": state.current_onto,
        "completed_pairs": state.completed_pairs,
        "current_commit": state.current_commit,
        "conflict_paths": state.conflict_paths,
    }
    try:
        tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp_path.replace(state_path)
    except OSError:
        # Don't leave a half-written temp file lying around.
        tmp_path.unlink(missing_ok=True)
        raise
    logger.info(
        "✅ Wrote REBASE_STATE.json (%d remaining, %d done)",
        len(state.commits_to_replay),
        len(state.completed_pairs),
    )
| 233 | ) |
| 234 | |
| 235 | |
| 236 | def clear_rebase_state(root: pathlib.Path) -> None: |
| 237 | """Remove ``.muse/REBASE_STATE.json`` after a successful or aborted rebase.""" |
| 238 | state_path = root / ".muse" / _REBASE_STATE_FILENAME |
| 239 | if state_path.exists(): |
| 240 | state_path.unlink() |
| 241 | logger.debug("✅ Cleared REBASE_STATE.json") |
| 242 | |
| 243 | |
| 244 | # --------------------------------------------------------------------------- |
| 245 | # Pure helpers |
| 246 | # --------------------------------------------------------------------------- |
| 247 | |
| 248 | |
| 249 | def compute_delta( |
| 250 | parent_manifest: dict[str, str], |
| 251 | commit_manifest: dict[str, str], |
| 252 | ) -> tuple[dict[str, str], set[str]]: |
| 253 | """Compute the file-level changes introduced by a single commit. |
| 254 | |
| 255 | Args: |
| 256 | parent_manifest: Snapshot of the commit's parent. |
| 257 | commit_manifest: Snapshot of the commit itself. |
| 258 | |
| 259 | Returns: |
| 260 | Tuple of (additions_and_modifications, deletions): |
| 261 | - additions_and_modifications: ``{path: object_id}`` for paths added |
| 262 | or changed in *commit_manifest* relative to *parent_manifest*. |
| 263 | - deletions: Set of paths present in *parent_manifest* but absent from |
| 264 | *commit_manifest*. |
| 265 | """ |
| 266 | changed_paths = diff_snapshots(parent_manifest, commit_manifest) |
| 267 | additions: dict[str, str] = {} |
| 268 | deletions: set[str] = set() |
| 269 | for path in changed_paths: |
| 270 | if path in commit_manifest: |
| 271 | additions[path] = commit_manifest[path] |
| 272 | else: |
| 273 | deletions.add(path) |
| 274 | return additions, deletions |
| 275 | |
| 276 | |
| 277 | def apply_delta( |
| 278 | onto_manifest: dict[str, str], |
| 279 | additions: dict[str, str], |
| 280 | deletions: set[str], |
| 281 | ) -> dict[str, str]: |
| 282 | """Apply a commit delta onto an ``onto`` snapshot manifest. |
| 283 | |
| 284 | Produces a new manifest that represents the onto-manifest with the |
| 285 | same file changes that the original commit introduced over its parent. |
| 286 | |
| 287 | Args: |
| 288 | onto_manifest: The current tip manifest to patch. |
| 289 | additions: Paths added or changed by the original commit. |
| 290 | deletions: Paths removed by the original commit. |
| 291 | |
| 292 | Returns: |
| 293 | New manifest dict (copy of *onto_manifest* with delta applied). |
| 294 | """ |
| 295 | result = dict(onto_manifest) |
| 296 | result.update(additions) |
| 297 | for path in deletions: |
| 298 | result.pop(path, None) |
| 299 | return result |
| 300 | |
| 301 | |
| 302 | def detect_rebase_conflicts( |
| 303 | onto_manifest: dict[str, str], |
| 304 | prev_onto_manifest: dict[str, str], |
| 305 | additions: dict[str, str], |
| 306 | deletions: set[str], |
| 307 | ) -> set[str]: |
| 308 | """Identify conflicts between the commit delta and changes on ``onto``. |
| 309 | |
| 310 | A conflict occurs when a path was changed both in the commit being replayed |
| 311 | (relative to its parent) and in the onto branch (relative to the base). |
| 312 | |
| 313 | Args: |
| 314 | onto_manifest: Current onto tip. |
| 315 | prev_onto_manifest: The onto state just before this replay step |
| 316 | (i.e. the merge base or the previous onto tip). |
| 317 | additions: Paths added/modified by the commit being replayed. |
| 318 | deletions: Paths deleted by the commit being replayed. |
| 319 | |
| 320 | Returns: |
| 321 | Set of conflicting paths. |
| 322 | """ |
| 323 | onto_changed = diff_snapshots(prev_onto_manifest, onto_manifest) |
| 324 | commit_changed = set(additions.keys()) | deletions |
| 325 | return detect_conflicts(onto_changed, commit_changed) |
| 326 | |
| 327 | |
| 328 | async def _collect_branch_commits_since_base( |
| 329 | session: AsyncSession, |
| 330 | head_commit_id: str, |
| 331 | base_commit_id: str, |
| 332 | ) -> list[MuseCliCommit]: |
| 333 | """Collect commits reachable from *head_commit_id* but not from *base_commit_id*. |
| 334 | |
| 335 | Returns them in topological order, oldest first (replay order). Merge |
| 336 | commits (two parents) are included as single units; their second parent is |
| 337 | not traversed — i.e. only the primary-parent chain is followed. |
| 338 | |
| 339 | Args: |
| 340 | session: Open async DB session. |
| 341 | head_commit_id: The current branch HEAD. |
| 342 | base_commit_id: The LCA — commits at or before this are excluded. |
| 343 | |
| 344 | Returns: |
| 345 | List of :class:`MuseCliCommit` rows in replay order. |
| 346 | """ |
| 347 | commits_reversed: list[MuseCliCommit] = [] |
| 348 | seen: set[str] = set() |
| 349 | queue: deque[str] = deque([head_commit_id]) |
| 350 | |
| 351 | while queue: |
| 352 | cid = queue.popleft() |
| 353 | if cid in seen or cid == base_commit_id: |
| 354 | continue |
| 355 | seen.add(cid) |
| 356 | commit = await session.get(MuseCliCommit, cid) |
| 357 | if commit is None: |
| 358 | break |
| 359 | commits_reversed.append(commit) |
| 360 | if commit.parent_commit_id and commit.parent_commit_id != base_commit_id: |
| 361 | queue.append(commit.parent_commit_id) |
| 362 | elif commit.parent_commit_id == base_commit_id: |
| 363 | # Include this commit but stop traversal here |
| 364 | pass |
| 365 | |
| 366 | # BFS gives newest-first; reverse to get oldest-first replay order. |
| 367 | return list(reversed(commits_reversed)) |
| 368 | |
| 369 | |
| 370 | async def _find_merge_base_rebase( |
| 371 | session: AsyncSession, |
| 372 | commit_id_a: str, |
| 373 | commit_id_b: str, |
| 374 | ) -> str | None: |
| 375 | """Lowest common ancestor of two commits — thin wrapper used by the rebase. |
| 376 | |
| 377 | Args: |
| 378 | session: Open async DB session. |
| 379 | commit_id_a: First commit ID (current branch HEAD). |
| 380 | commit_id_b: Second commit ID (upstream tip). |
| 381 | |
| 382 | Returns: |
| 383 | LCA commit ID, or ``None`` if histories are disjoint. |
| 384 | """ |
| 385 | from maestro.muse_cli.merge_engine import find_merge_base |
| 386 | |
| 387 | return await find_merge_base(session, commit_id_a, commit_id_b) |
| 388 | |
| 389 | |
| 390 | def apply_autosquash(commits: list[MuseCliCommit]) -> tuple[list[MuseCliCommit], bool]: |
| 391 | """Reorder and flag fixup commits for autosquash. |
| 392 | |
| 393 | Detects commits whose message starts with ``fixup! <msg>`` and moves them |
| 394 | immediately after the matching commit (matched by prefix of ``<msg>``). |
| 395 | |
| 396 | Args: |
| 397 | commits: Commits in replay order (oldest first). |
| 398 | |
| 399 | Returns: |
| 400 | Tuple of (reordered_commits, was_reordered). |
| 401 | """ |
| 402 | # Build index of message → position for non-fixup commits |
| 403 | reordered: list[MuseCliCommit] = [] |
| 404 | fixups: dict[str, list[MuseCliCommit]] = {} |
| 405 | |
| 406 | for commit in commits: |
| 407 | if commit.message.startswith("fixup! "): |
| 408 | target_msg = commit.message[len("fixup! "):] |
| 409 | fixups.setdefault(target_msg, []).append(commit) |
| 410 | else: |
| 411 | reordered.append(commit) |
| 412 | |
| 413 | if not fixups: |
| 414 | return commits, False |
| 415 | |
| 416 | # Insert fixup commits immediately after their targets |
| 417 | result: list[MuseCliCommit] = [] |
| 418 | for commit in reordered: |
| 419 | result.append(commit) |
| 420 | # Match by prefix of commit message |
| 421 | for target_msg, fixup_list in list(fixups.items()): |
| 422 | if commit.message.startswith(target_msg): |
| 423 | result.extend(fixup_list) |
| 424 | del fixups[target_msg] |
| 425 | |
| 426 | # Any unmatched fixups go at the end |
| 427 | for fixup_list in fixups.values(): |
| 428 | result.extend(fixup_list) |
| 429 | |
| 430 | return result, True |
| 431 | |
| 432 | |
| 433 | # --------------------------------------------------------------------------- |
| 434 | # Interactive plan |
| 435 | # --------------------------------------------------------------------------- |
| 436 | |
| 437 | |
| 438 | class InteractivePlan: |
| 439 | """A parsed interactive rebase plan. |
| 440 | |
| 441 | Plan lines have the format:: |
| 442 | |
| 443 | <action> <short-sha> <message> |
| 444 | |
| 445 | Supported actions: ``pick``, ``squash``, ``drop``. |
| 446 | |
| 447 | Attributes: |
| 448 | entries: List of (action, commit_id, message) tuples in plan order. |
| 449 | """ |
| 450 | |
| 451 | VALID_ACTIONS = frozenset({"pick", "squash", "drop", "fixup", "reword"}) |
| 452 | |
| 453 | def __init__( |
| 454 | self, |
| 455 | entries: list[tuple[str, str, str]], |
| 456 | ) -> None: |
| 457 | """Create a plan from parsed entries. |
| 458 | |
| 459 | Args: |
| 460 | entries: List of (action, commit_id_prefix, message) tuples. |
| 461 | """ |
| 462 | self.entries = entries |
| 463 | |
| 464 | @classmethod |
| 465 | def from_text(cls, text: str) -> InteractivePlan: |
| 466 | """Parse a plan from the editor text. |
| 467 | |
| 468 | Lines starting with ``#`` are comments and are ignored. Blank lines |
| 469 | are ignored. Each non-comment line must be ``<action> <sha> <msg>``. |
| 470 | |
| 471 | Args: |
| 472 | text: Raw plan text as produced by :meth:`to_text`. |
| 473 | |
| 474 | Returns: |
| 475 | Parsed :class:`InteractivePlan`. |
| 476 | |
| 477 | Raises: |
| 478 | ValueError: If a line has an unrecognised action or missing fields. |
| 479 | """ |
| 480 | entries: list[tuple[str, str, str]] = [] |
| 481 | for raw_line in text.splitlines(): |
| 482 | line = raw_line.strip() |
| 483 | if not line or line.startswith("#"): |
| 484 | continue |
| 485 | parts = line.split(None, 2) |
| 486 | if len(parts) < 2: |
| 487 | raise ValueError(f"Invalid plan line: {raw_line!r}") |
| 488 | action = parts[0].lower() |
| 489 | sha = parts[1] |
| 490 | msg = parts[2] if len(parts) > 2 else "" |
| 491 | if action not in cls.VALID_ACTIONS: |
| 492 | raise ValueError(f"Unknown action {action!r} in line: {raw_line!r}") |
| 493 | entries.append((action, sha, msg)) |
| 494 | return cls(entries) |
| 495 | |
| 496 | @classmethod |
| 497 | def from_commits(cls, commits: list[MuseCliCommit]) -> InteractivePlan: |
| 498 | """Build a default plan (all ``pick``) from a list of commits. |
| 499 | |
| 500 | Args: |
| 501 | commits: Commits in replay order. |
| 502 | |
| 503 | Returns: |
| 504 | :class:`InteractivePlan` with one ``pick`` entry per commit. |
| 505 | """ |
| 506 | entries: list[tuple[str, str, str]] = [] |
| 507 | for commit in commits: |
| 508 | entries.append(("pick", commit.commit_id[:8], commit.message)) |
| 509 | return cls(entries) |
| 510 | |
| 511 | def to_text(self) -> str: |
| 512 | """Render the plan to a human-editable text format.""" |
| 513 | lines = [ |
| 514 | "# Interactive rebase plan.", |
| 515 | "# Actions: pick, squash (fold into previous), drop (skip), fixup (squash no msg), reword", |
| 516 | "# Lines starting with '#' are ignored.", |
| 517 | "", |
| 518 | ] |
| 519 | for action, sha, msg in self.entries: |
| 520 | lines.append(f"{action} {sha} {msg}") |
| 521 | return "\n".join(lines) + "\n" |
| 522 | |
| 523 | def resolve_against( |
| 524 | self, commits: list[MuseCliCommit] |
| 525 | ) -> list[tuple[str, MuseCliCommit]]: |
| 526 | """Match plan entries to the full commit list by SHA prefix. |
| 527 | |
| 528 | Args: |
| 529 | commits: Original commits list (source of truth for full SHA). |
| 530 | |
| 531 | Returns: |
| 532 | List of (action, commit) pairs in plan order, excluding dropped |
| 533 | commits. |
| 534 | |
| 535 | Raises: |
| 536 | ValueError: If a plan SHA prefix matches no commit or is ambiguous. |
| 537 | """ |
| 538 | resolved: list[tuple[str, MuseCliCommit]] = [] |
| 539 | for action, sha_prefix, _msg in self.entries: |
| 540 | if action == "drop": |
| 541 | continue |
| 542 | matches = [c for c in commits if c.commit_id.startswith(sha_prefix)] |
| 543 | if not matches: |
| 544 | raise ValueError( |
| 545 | f"Plan SHA {sha_prefix!r} does not match any commit in the rebase range." |
| 546 | ) |
| 547 | if len(matches) > 1: |
| 548 | raise ValueError( |
| 549 | f"Plan SHA {sha_prefix!r} is ambiguous — matches {len(matches)} commits." |
| 550 | ) |
| 551 | resolved.append((action, matches[0])) |
| 552 | return resolved |
| 553 | |
| 554 | |
| 555 | # --------------------------------------------------------------------------- |
| 556 | # Async rebase core — single-step replay |
| 557 | # --------------------------------------------------------------------------- |
| 558 | |
| 559 | |
| 560 | async def _replay_one_commit( |
| 561 | *, |
| 562 | session: AsyncSession, |
| 563 | commit: MuseCliCommit, |
| 564 | onto_manifest: dict[str, str], |
| 565 | prev_onto_manifest: dict[str, str], |
| 566 | onto_commit_id: str, |
| 567 | branch: str, |
| 568 | ) -> tuple[str, dict[str, str], list[str]]: |
| 569 | """Replay a single commit onto the current onto tip. |
| 570 | |
| 571 | Computes the delta the original commit introduced over its parent, applies |
| 572 | it to *onto_manifest*, detects conflicts, persists the new snapshot, and |
| 573 | inserts a new commit record. |
| 574 | |
| 575 | Args: |
| 576 | session: Open async DB session. |
| 577 | commit: The original commit to replay. |
| 578 | onto_manifest: Snapshot manifest of the current onto tip. |
| 579 | prev_onto_manifest: Manifest of the onto base (for conflict detection). |
| 580 | onto_commit_id: Commit ID of the current onto tip. |
| 581 | branch: Branch being rebased (for the new commit record). |
| 582 | |
| 583 | Returns: |
| 584 | Tuple of (new_commit_id, new_onto_manifest, conflict_paths): |
| 585 | - ``new_commit_id``: SHA of the newly inserted commit. |
| 586 | - ``new_onto_manifest``: Updated manifest for the next step. |
| 587 | - ``conflict_paths``: Empty list on success; non-empty on conflict. |
| 588 | """ |
| 589 | # Resolve parent manifest for the original commit |
| 590 | parent_manifest: dict[str, str] = {} |
| 591 | if commit.parent_commit_id: |
| 592 | loaded = await get_commit_snapshot_manifest(session, commit.parent_commit_id) |
| 593 | if loaded is not None: |
| 594 | parent_manifest = loaded |
| 595 | |
| 596 | commit_manifest = await get_commit_snapshot_manifest(session, commit.commit_id) |
| 597 | if commit_manifest is None: |
| 598 | commit_manifest = {} |
| 599 | |
| 600 | additions, deletions = compute_delta(parent_manifest, commit_manifest) |
| 601 | |
| 602 | conflict_paths = detect_rebase_conflicts( |
| 603 | onto_manifest=onto_manifest, |
| 604 | prev_onto_manifest=prev_onto_manifest, |
| 605 | additions=additions, |
| 606 | deletions=deletions, |
| 607 | ) |
| 608 | if conflict_paths: |
| 609 | return "", onto_manifest, sorted(conflict_paths) |
| 610 | |
| 611 | new_manifest = apply_delta(onto_manifest, additions, deletions) |
| 612 | new_snapshot_id = compute_snapshot_id(new_manifest) |
| 613 | await upsert_snapshot(session, manifest=new_manifest, snapshot_id=new_snapshot_id) |
| 614 | await session.flush() |
| 615 | |
| 616 | committed_at = datetime.datetime.now(datetime.timezone.utc) |
| 617 | new_commit_id = compute_commit_id( |
| 618 | parent_ids=[onto_commit_id], |
| 619 | snapshot_id=new_snapshot_id, |
| 620 | message=commit.message, |
| 621 | committed_at_iso=committed_at.isoformat(), |
| 622 | ) |
| 623 | |
| 624 | new_commit = MuseCliCommit( |
| 625 | commit_id=new_commit_id, |
| 626 | repo_id=commit.repo_id, |
| 627 | branch=branch, |
| 628 | parent_commit_id=onto_commit_id, |
| 629 | snapshot_id=new_snapshot_id, |
| 630 | message=commit.message, |
| 631 | author=commit.author, |
| 632 | committed_at=committed_at, |
| 633 | commit_metadata=commit.commit_metadata, |
| 634 | ) |
| 635 | await insert_commit(session, new_commit) |
| 636 | |
| 637 | return new_commit_id, new_manifest, [] |
| 638 | |
| 639 | |
| 640 | # --------------------------------------------------------------------------- |
| 641 | # Async rebase core — full pipeline |
| 642 | # --------------------------------------------------------------------------- |
| 643 | |
| 644 | |
| 645 | async def _rebase_async( |
| 646 | *, |
| 647 | upstream: str, |
| 648 | root: pathlib.Path, |
| 649 | session: AsyncSession, |
| 650 | interactive: bool = False, |
| 651 | autosquash: bool = False, |
| 652 | rebase_merges: bool = False, |
| 653 | ) -> RebaseResult: |
| 654 | """Run the rebase pipeline. |
| 655 | |
| 656 | All filesystem and DB side-effects are isolated here so tests can inject |
| 657 | an in-memory SQLite session and a ``tmp_path`` root without touching a |
| 658 | real database. |
| 659 | |
| 660 | Args: |
| 661 | upstream: Branch name or commit ID to rebase onto. |
| 662 | root: Repository root (directory containing ``.muse/``). |
| 663 | session: Open async DB session. |
| 664 | interactive: When ``True``, open $EDITOR with the rebase plan before |
| 665 | executing. The edited plan controls action, order, and squash |
| 666 | behaviour. |
| 667 | autosquash: When ``True``, automatically detect ``fixup!`` commits and |
| 668 | move them after their matching target commit. |
| 669 | rebase_merges: When ``True``, preserve merge commits during replay |
| 670 | (stub — see implementation note below). |
| 671 | |
| 672 | Returns: |
| 673 | :class:`RebaseResult` describing what happened. |
| 674 | |
| 675 | Raises: |
| 676 | ``typer.Exit`` with an appropriate exit code on user-facing errors. |
| 677 | """ |
| 678 | import json as _json |
| 679 | |
| 680 | muse_dir = root / ".muse" |
| 681 | |
| 682 | # ── Guard: rebase already in progress ─────────────────────────────── |
| 683 | if read_rebase_state(root) is not None: |
| 684 | typer.echo( |
| 685 | "❌ Rebase in progress. Use --continue to resume or --abort to cancel." |
| 686 | ) |
| 687 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 688 | |
| 689 | # ── Repo identity ──────────────────────────────────────────────────── |
| 690 | repo_data: dict[str, str] = _json.loads((muse_dir / "repo.json").read_text()) |
| 691 | repo_id = repo_data["repo_id"] |
| 692 | |
| 693 | # ── Current branch ─────────────────────────────────────────────────── |
| 694 | head_ref = (muse_dir / "HEAD").read_text().strip() |
| 695 | current_branch = head_ref.rsplit("/", 1)[-1] |
| 696 | our_ref_path = muse_dir / pathlib.Path(head_ref) |
| 697 | |
| 698 | ours_commit_id = our_ref_path.read_text().strip() if our_ref_path.exists() else "" |
| 699 | if not ours_commit_id: |
| 700 | typer.echo("❌ Current branch has no commits. Cannot rebase.") |
| 701 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 702 | |
| 703 | # ── Resolve upstream ───────────────────────────────────────────────── |
| 704 | # Try as a branch name first, then as a raw commit ID |
| 705 | upstream_ref_path = muse_dir / "refs" / "heads" / upstream |
| 706 | if upstream_ref_path.exists(): |
| 707 | upstream_commit_id = upstream_ref_path.read_text().strip() |
| 708 | else: |
| 709 | # Might be a raw commit ID |
| 710 | candidate = await session.get(MuseCliCommit, upstream) |
| 711 | if candidate is None: |
| 712 | typer.echo( |
| 713 | f"❌ Upstream {upstream!r} is not a known branch or commit ID." |
| 714 | ) |
| 715 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 716 | upstream_commit_id = candidate.commit_id |
| 717 | |
| 718 | # ── Already up-to-date guard ───────────────────────────────────────── |
| 719 | if ours_commit_id == upstream_commit_id: |
| 720 | typer.echo("Already up-to-date.") |
| 721 | return RebaseResult( |
| 722 | branch=current_branch, |
| 723 | upstream=upstream, |
| 724 | upstream_commit_id=upstream_commit_id, |
| 725 | base_commit_id=upstream_commit_id, |
| 726 | replayed=(), |
| 727 | conflict_paths=(), |
| 728 | aborted=False, |
| 729 | noop=True, |
| 730 | autosquash_applied=False, |
| 731 | ) |
| 732 | |
| 733 | # ── Find merge base ────────────────────────────────────────────────── |
| 734 | base_commit_id = await _find_merge_base_rebase( |
| 735 | session, ours_commit_id, upstream_commit_id |
| 736 | ) |
| 737 | if base_commit_id is None: |
| 738 | typer.echo( |
| 739 | f"❌ Cannot find a common ancestor between current branch and {upstream!r}. " |
| 740 | "Histories are disjoint — use 'muse merge' instead." |
| 741 | ) |
| 742 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 743 | |
| 744 | # ── Fast-forward: current branch IS the base → nothing to replay ───── |
| 745 | if base_commit_id == ours_commit_id: |
| 746 | # Current branch is behind upstream — just advance the pointer |
| 747 | our_ref_path.write_text(upstream_commit_id) |
| 748 | typer.echo( |
| 749 | f"✅ Fast-forward: {current_branch} → {upstream_commit_id[:8]}" |
| 750 | ) |
| 751 | return RebaseResult( |
| 752 | branch=current_branch, |
| 753 | upstream=upstream, |
| 754 | upstream_commit_id=upstream_commit_id, |
| 755 | base_commit_id=base_commit_id, |
| 756 | replayed=(), |
| 757 | conflict_paths=(), |
| 758 | aborted=False, |
| 759 | noop=True, |
| 760 | autosquash_applied=False, |
| 761 | ) |
| 762 | |
| 763 | # ── Already up-to-date: upstream IS the base → we are ahead ────────── |
| 764 | if base_commit_id == upstream_commit_id: |
| 765 | typer.echo("Already up-to-date.") |
| 766 | return RebaseResult( |
| 767 | branch=current_branch, |
| 768 | upstream=upstream, |
| 769 | upstream_commit_id=upstream_commit_id, |
| 770 | base_commit_id=base_commit_id, |
| 771 | replayed=(), |
| 772 | conflict_paths=(), |
| 773 | aborted=False, |
| 774 | noop=True, |
| 775 | autosquash_applied=False, |
| 776 | ) |
| 777 | |
| 778 | # ── Collect commits to replay ───────────────────────────────────────── |
| 779 | commits_to_replay = await _collect_branch_commits_since_base( |
| 780 | session, ours_commit_id, base_commit_id |
| 781 | ) |
| 782 | |
| 783 | if not commits_to_replay: |
| 784 | typer.echo("Nothing to rebase.") |
| 785 | return RebaseResult( |
| 786 | branch=current_branch, |
| 787 | upstream=upstream, |
| 788 | upstream_commit_id=upstream_commit_id, |
| 789 | base_commit_id=base_commit_id, |
| 790 | replayed=(), |
| 791 | conflict_paths=(), |
| 792 | aborted=False, |
| 793 | noop=True, |
| 794 | autosquash_applied=False, |
| 795 | ) |
| 796 | |
| 797 | autosquash_applied = False |
| 798 | |
| 799 | # ── Autosquash ──────────────────────────────────────────────────────── |
| 800 | if autosquash: |
| 801 | commits_to_replay, autosquash_applied = apply_autosquash(commits_to_replay) |
| 802 | |
| 803 | # ── Interactive plan ────────────────────────────────────────────────── |
| 804 | plan_actions: list[tuple[str, MuseCliCommit]] = [ |
| 805 | ("pick", c) for c in commits_to_replay |
| 806 | ] |
| 807 | |
| 808 | if interactive: |
| 809 | import os |
| 810 | import subprocess |
| 811 | import tempfile |
| 812 | |
| 813 | plan = InteractivePlan.from_commits(commits_to_replay) |
| 814 | with tempfile.NamedTemporaryFile( |
| 815 | mode="w", suffix=".rebase-plan", delete=False |
| 816 | ) as tf: |
| 817 | tf.write(plan.to_text()) |
| 818 | tf_path = tf.name |
| 819 | |
| 820 | editor = os.environ.get("EDITOR", os.environ.get("VISUAL", "vi")) |
| 821 | result = subprocess.run([editor, tf_path]) |
| 822 | if result.returncode != 0: |
| 823 | typer.echo("⚠️ Editor exited with non-zero code — rebase aborted.") |
| 824 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 825 | |
| 826 | edited_text = pathlib.Path(tf_path).read_text() |
| 827 | pathlib.Path(tf_path).unlink(missing_ok=True) |
| 828 | |
| 829 | try: |
| 830 | edited_plan = InteractivePlan.from_text(edited_text) |
| 831 | plan_actions = edited_plan.resolve_against(commits_to_replay) |
| 832 | except ValueError as exc: |
| 833 | typer.echo(f"❌ Invalid rebase plan: {exc}") |
| 834 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 835 | |
| 836 | if not plan_actions: |
| 837 | typer.echo("Nothing to do.") |
| 838 | raise typer.Exit(code=ExitCode.SUCCESS) |
| 839 | |
| 840 | # ── Resolve base and upstream manifests ─────────────────────────────── |
| 841 | base_manifest = await get_commit_snapshot_manifest(session, base_commit_id) or {} |
| 842 | onto_manifest = ( |
| 843 | await get_commit_snapshot_manifest(session, upstream_commit_id) or {} |
| 844 | ) |
| 845 | onto_commit_id = upstream_commit_id |
| 846 | prev_onto_manifest = base_manifest |
| 847 | |
| 848 | completed_pairs: list[RebaseCommitPair] = [] |
| 849 | pending_squash_manifest: dict[str, str] | None = None |
| 850 | pending_squash_commits: list[MuseCliCommit] = [] |
| 851 | |
| 852 | for action, commit in plan_actions: |
| 853 | if action == "drop": |
| 854 | continue |
| 855 | |
| 856 | if action in ("squash", "fixup"): |
| 857 | # Accumulate into squash group |
| 858 | pending_squash_commits.append(commit) |
| 859 | if pending_squash_manifest is None: |
| 860 | # First in group — compute its delta |
| 861 | parent_manifest: dict[str, str] = {} |
| 862 | if commit.parent_commit_id: |
| 863 | loaded = await get_commit_snapshot_manifest( |
| 864 | session, commit.parent_commit_id |
| 865 | ) |
| 866 | if loaded is not None: |
| 867 | parent_manifest = loaded |
| 868 | commit_manifest = ( |
| 869 | await get_commit_snapshot_manifest(session, commit.commit_id) or {} |
| 870 | ) |
| 871 | additions, deletions = compute_delta(parent_manifest, commit_manifest) |
| 872 | pending_squash_manifest = apply_delta( |
| 873 | onto_manifest, additions, deletions |
| 874 | ) |
| 875 | else: |
| 876 | # Add this commit's changes on top of the squash in progress |
| 877 | parent_manifest_sq: dict[str, str] = {} |
| 878 | if commit.parent_commit_id: |
| 879 | loaded_sq = await get_commit_snapshot_manifest( |
| 880 | session, commit.parent_commit_id |
| 881 | ) |
| 882 | if loaded_sq is not None: |
| 883 | parent_manifest_sq = loaded_sq |
| 884 | commit_manifest_sq = ( |
| 885 | await get_commit_snapshot_manifest(session, commit.commit_id) or {} |
| 886 | ) |
| 887 | additions_sq, deletions_sq = compute_delta( |
| 888 | parent_manifest_sq, commit_manifest_sq |
| 889 | ) |
| 890 | pending_squash_manifest = apply_delta( |
| 891 | pending_squash_manifest, additions_sq, deletions_sq |
| 892 | ) |
| 893 | continue |
| 894 | |
| 895 | # ── Flush pending squash group (if any) ─────────────────────── |
| 896 | if pending_squash_commits and pending_squash_manifest is not None: |
| 897 | squash_message = pending_squash_commits[0].message |
| 898 | squash_snap_id = compute_snapshot_id(pending_squash_manifest) |
| 899 | await upsert_snapshot( |
| 900 | session, manifest=pending_squash_manifest, snapshot_id=squash_snap_id |
| 901 | ) |
| 902 | await session.flush() |
| 903 | |
| 904 | squash_at = datetime.datetime.now(datetime.timezone.utc) |
| 905 | squash_commit_id = compute_commit_id( |
| 906 | parent_ids=[onto_commit_id], |
| 907 | snapshot_id=squash_snap_id, |
| 908 | message=squash_message, |
| 909 | committed_at_iso=squash_at.isoformat(), |
| 910 | ) |
| 911 | squash_commit = MuseCliCommit( |
| 912 | commit_id=squash_commit_id, |
| 913 | repo_id=pending_squash_commits[0].repo_id, |
| 914 | branch=current_branch, |
| 915 | parent_commit_id=onto_commit_id, |
| 916 | snapshot_id=squash_snap_id, |
| 917 | message=squash_message, |
| 918 | author=pending_squash_commits[0].author, |
| 919 | committed_at=squash_at, |
| 920 | ) |
| 921 | await insert_commit(session, squash_commit) |
| 922 | |
| 923 | for orig in pending_squash_commits: |
| 924 | completed_pairs.append( |
| 925 | RebaseCommitPair( |
| 926 | original_commit_id=orig.commit_id, |
| 927 | new_commit_id=squash_commit_id, |
| 928 | ) |
| 929 | ) |
| 930 | onto_manifest = pending_squash_manifest |
| 931 | onto_commit_id = squash_commit_id |
| 932 | pending_squash_commits = [] |
| 933 | pending_squash_manifest = None |
| 934 | |
| 935 | # ── Normal pick ──────────────────────────────────────────────── |
| 936 | new_commit_id, new_manifest, conflict_paths_list = await _replay_one_commit( |
| 937 | session=session, |
| 938 | commit=commit, |
| 939 | onto_manifest=onto_manifest, |
| 940 | prev_onto_manifest=prev_onto_manifest, |
| 941 | onto_commit_id=onto_commit_id, |
| 942 | branch=current_branch, |
| 943 | ) |
| 944 | |
| 945 | if conflict_paths_list: |
| 946 | # Persist state and exit with conflict |
| 947 | remaining_ids = [ |
| 948 | c.commit_id |
| 949 | for _, c in plan_actions[ |
| 950 | plan_actions.index((action, commit)) + 1 : |
| 951 | ] |
| 952 | ] |
| 953 | state = RebaseState( |
| 954 | upstream_commit=upstream_commit_id, |
| 955 | base_commit=base_commit_id, |
| 956 | original_branch=current_branch, |
| 957 | original_head=ours_commit_id, |
| 958 | commits_to_replay=remaining_ids, |
| 959 | current_onto=onto_commit_id, |
| 960 | completed_pairs=[ |
| 961 | [p.original_commit_id, p.new_commit_id] for p in completed_pairs |
| 962 | ], |
| 963 | current_commit=commit.commit_id, |
| 964 | conflict_paths=conflict_paths_list, |
| 965 | ) |
| 966 | write_rebase_state(root, state) |
| 967 | |
| 968 | typer.echo( |
| 969 | f"❌ Conflict while replaying {commit.commit_id[:8]} ({commit.message!r}):\n" |
| 970 | + "\n".join(f"\tboth modified: {p}" for p in conflict_paths_list) |
| 971 | + "\nResolve conflicts, then run 'muse rebase --continue'." |
| 972 | ) |
| 973 | raise typer.Exit(code=ExitCode.USER_ERROR) |
| 974 | |
| 975 | completed_pairs.append( |
| 976 | RebaseCommitPair( |
| 977 | original_commit_id=commit.commit_id, |
| 978 | new_commit_id=new_commit_id, |
| 979 | ) |
| 980 | ) |
| 981 | prev_onto_manifest = onto_manifest |
| 982 | onto_manifest = new_manifest |
| 983 | onto_commit_id = new_commit_id |
| 984 | |
| 985 | # ── Flush any trailing squash group ────────────────────────────────── |
| 986 | if pending_squash_commits and pending_squash_manifest is not None: |
| 987 | squash_message = pending_squash_commits[0].message |
| 988 | squash_snap_id = compute_snapshot_id(pending_squash_manifest) |
| 989 | await upsert_snapshot( |
| 990 | session, manifest=pending_squash_manifest, snapshot_id=squash_snap_id |
| 991 | ) |
| 992 | await session.flush() |
| 993 | |
| 994 | squash_at = datetime.datetime.now(datetime.timezone.utc) |
| 995 | squash_commit_id = compute_commit_id( |
| 996 | parent_ids=[onto_commit_id], |
| 997 | snapshot_id=squash_snap_id, |
| 998 | message=squash_message, |
| 999 | committed_at_iso=squash_at.isoformat(), |
| 1000 | ) |
| 1001 | squash_commit = MuseCliCommit( |
| 1002 | commit_id=squash_commit_id, |
| 1003 | repo_id=pending_squash_commits[0].repo_id, |
| 1004 | branch=current_branch, |
| 1005 | parent_commit_id=onto_commit_id, |
| 1006 | snapshot_id=squash_snap_id, |
| 1007 | message=squash_message, |
| 1008 | author=pending_squash_commits[0].author, |
| 1009 | committed_at=squash_at, |
| 1010 | ) |
| 1011 | await insert_commit(session, squash_commit) |
| 1012 | |
| 1013 | for orig in pending_squash_commits: |
| 1014 | completed_pairs.append( |
| 1015 | RebaseCommitPair( |
| 1016 | original_commit_id=orig.commit_id, |
| 1017 | new_commit_id=squash_commit_id, |
| 1018 | ) |
| 1019 | ) |
| 1020 | onto_commit_id = squash_commit_id |
| 1021 | |
| 1022 | # ── Advance branch pointer ──────────────────────────────────────────── |
| 1023 | our_ref_path.write_text(onto_commit_id) |
| 1024 | |
| 1025 | typer.echo( |
| 1026 | f"✅ Rebased {len(completed_pairs)} commit(s) onto {upstream!r} " |
| 1027 | f"[{current_branch} {onto_commit_id[:8]}]" |
| 1028 | ) |
| 1029 | logger.info( |
| 1030 | "✅ muse rebase: %d commit(s) replayed onto %r (%s), branch %r now at %s", |
| 1031 | len(completed_pairs), |
| 1032 | upstream, |
| 1033 | upstream_commit_id[:8], |
| 1034 | current_branch, |
| 1035 | onto_commit_id[:8], |
| 1036 | ) |
| 1037 | |
| 1038 | return RebaseResult( |
| 1039 | branch=current_branch, |
| 1040 | upstream=upstream, |
| 1041 | upstream_commit_id=upstream_commit_id, |
| 1042 | base_commit_id=base_commit_id, |
| 1043 | replayed=tuple(completed_pairs), |
| 1044 | conflict_paths=(), |
| 1045 | aborted=False, |
| 1046 | noop=False, |
| 1047 | autosquash_applied=autosquash_applied, |
| 1048 | ) |
| 1049 | |
| 1050 | |
async def _rebase_continue_async(
    *,
    root: pathlib.Path,
    session: AsyncSession,
) -> RebaseResult:
    """Resume a rebase that was paused due to a conflict.

    Reads ``REBASE_STATE.json`` and refuses to proceed while the state still
    lists any ``conflict_paths``.  Otherwise it replays, in order:

    1. the commit that was interrupted by the conflict (``current_commit``),
       unless it already appears in ``completed_pairs`` or
       ``commits_to_replay``;
    2. every commit id in ``commits_to_replay``.

    Each commit is replayed by computing its delta against its own parent and
    applying that delta on top of the advancing ``onto`` manifest.  A new
    snapshot and a new commit row are written for each one.  No conflict
    detection is performed on this path.  Finally, the branch ref is moved to
    the last replayed commit and the state file is removed.

    Args:
        root: Repository root.
        session: Open async DB session.

    Returns:
        :class:`RebaseResult` describing the completed rebase.

    Raises:
        ``typer.Exit``: If no rebase is in progress or conflicts remain.
    """
    import json as _json

    rebase_state = read_rebase_state(root)
    if rebase_state is None:
        typer.echo("❌ No rebase in progress. Nothing to continue.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    if rebase_state.conflict_paths:
        typer.echo(
            f"❌ {len(rebase_state.conflict_paths)} conflict(s) not yet resolved:\n"
            + "\n".join(
                f"\tboth modified: {p}" for p in rebase_state.conflict_paths
            )
            + "\nResolve conflicts manually, then run 'muse rebase --continue'."
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    muse_dir = root / ".muse"
    repo_data: dict[str, str] = _json.loads((muse_dir / "repo.json").read_text())
    repo_id = repo_data["repo_id"]

    current_branch = rebase_state.original_branch
    head_ref = f"refs/heads/{current_branch}"
    our_ref_path = muse_dir / pathlib.Path(head_ref)

    completed_pairs: list[RebaseCommitPair] = [
        RebaseCommitPair(original_commit_id=p[0], new_commit_id=p[1])
        for p in rebase_state.completed_pairs
    ]

    # When the rebase paused, the conflicting commit was stored in
    # ``current_commit``; ``commits_to_replay`` only holds the commits that
    # come after it.  Replay the interrupted commit first, otherwise its
    # changes would be silently dropped from the rebased branch.
    replay_ids: list[str] = list(rebase_state.commits_to_replay)
    already_replayed = {p.original_commit_id for p in completed_pairs}
    interrupted_cid = rebase_state.current_commit
    if (
        interrupted_cid
        and interrupted_cid not in already_replayed
        and interrupted_cid not in replay_ids
    ):
        replay_ids.insert(0, interrupted_cid)

    onto_commit_id = rebase_state.current_onto
    onto_manifest = (
        await get_commit_snapshot_manifest(session, onto_commit_id) or {}
    )

    # Replay the remaining commits
    for orig_cid in replay_ids:
        commit = await session.get(MuseCliCommit, orig_cid)
        if commit is None:
            typer.echo(f"⚠️ Commit {orig_cid[:8]} not found — skipping.")
            continue

        # A root commit (no parent) is diffed against an empty manifest.
        parent_manifest: dict[str, str] = {}
        if commit.parent_commit_id:
            loaded = await get_commit_snapshot_manifest(
                session, commit.parent_commit_id
            )
            if loaded is not None:
                parent_manifest = loaded

        commit_manifest = (
            await get_commit_snapshot_manifest(session, commit.commit_id) or {}
        )
        additions, deletions = compute_delta(parent_manifest, commit_manifest)
        new_manifest = apply_delta(onto_manifest, additions, deletions)
        new_snapshot_id = compute_snapshot_id(new_manifest)
        await upsert_snapshot(
            session, manifest=new_manifest, snapshot_id=new_snapshot_id
        )
        await session.flush()

        # The parent changes, so the replayed commit gets a fresh id.
        committed_at = datetime.datetime.now(datetime.timezone.utc)
        new_commit_id = compute_commit_id(
            parent_ids=[onto_commit_id],
            snapshot_id=new_snapshot_id,
            message=commit.message,
            committed_at_iso=committed_at.isoformat(),
        )
        new_commit = MuseCliCommit(
            commit_id=new_commit_id,
            repo_id=repo_id,
            branch=current_branch,
            parent_commit_id=onto_commit_id,
            snapshot_id=new_snapshot_id,
            message=commit.message,
            author=commit.author,
            committed_at=committed_at,
            commit_metadata=commit.commit_metadata,
        )
        await insert_commit(session, new_commit)

        completed_pairs.append(
            RebaseCommitPair(
                original_commit_id=orig_cid,
                new_commit_id=new_commit_id,
            )
        )
        onto_manifest = new_manifest
        onto_commit_id = new_commit_id

    # Advance branch pointer and clear state.  Make sure the ref's directory
    # exists, as ``_rebase_abort_async`` does before writing the same ref.
    our_ref_path.parent.mkdir(parents=True, exist_ok=True)
    our_ref_path.write_text(onto_commit_id)
    clear_rebase_state(root)

    upstream_commit_id = rebase_state.upstream_commit
    base_commit_id = rebase_state.base_commit

    typer.echo(
        f"✅ Rebase continued: {len(completed_pairs)} commit(s) applied "
        f"[{current_branch} {onto_commit_id[:8]}]"
    )
    logger.info(
        "✅ muse rebase --continue: %d commit(s) on %r, now at %s",
        len(completed_pairs),
        current_branch,
        onto_commit_id[:8],
    )

    return RebaseResult(
        branch=current_branch,
        upstream=rebase_state.upstream_commit,
        upstream_commit_id=upstream_commit_id,
        base_commit_id=base_commit_id,
        replayed=tuple(completed_pairs),
        conflict_paths=(),
        aborted=False,
        noop=False,
        autosquash_applied=False,
    )
| 1191 | |
| 1192 | |
async def _rebase_abort_async(
    *,
    root: pathlib.Path,
) -> RebaseResult:
    """Cancel the in-progress rebase and put the branch back where it started.

    Loads ``REBASE_STATE.json``, rewrites the ref file of ``original_branch``
    so that it points at ``original_head`` again, then deletes the state file.
    This function does not remove any commits that the partial replay may
    already have written.

    Args:
        root: Repository root.

    Returns:
        :class:`RebaseResult` with ``aborted=True`` and an empty ``upstream``.

    Raises:
        ``typer.Exit``: If no rebase is in progress.
    """
    state = read_rebase_state(root)
    if state is None:
        typer.echo("❌ No rebase in progress. Nothing to abort.")
        raise typer.Exit(code=ExitCode.USER_ERROR)

    branch_name = state.original_branch
    original_head = state.original_head
    ref_file = root / ".muse" / pathlib.Path(f"refs/heads/{branch_name}")

    # Ensure the ref's parent directory exists before rewriting it.
    ref_file.parent.mkdir(parents=True, exist_ok=True)
    ref_file.write_text(original_head)

    clear_rebase_state(root)

    short_head = original_head[:8]
    typer.echo(
        f"✅ Rebase aborted. Branch {branch_name!r} restored to {short_head}."
    )
    logger.info("✅ muse rebase --abort: %r restored to %s", branch_name, short_head)

    return RebaseResult(
        branch=branch_name,
        upstream="",
        upstream_commit_id=state.upstream_commit,
        base_commit_id=state.base_commit,
        replayed=(),
        conflict_paths=(),
        aborted=True,
        noop=False,
        autosquash_applied=False,
    )