test_commit_drift_safety.py
python
| 1 | """Tests for drift-aware commit safety (Phase 8). |
| 2 | |
| 3 | Verifies: |
| 4 | - Clean commits succeed (200). |
| 5 | - Dirty notes block commit (409 + WORKING_TREE_DIRTY). |
| 6 | - Dirty controllers block commit (409). |
| 7 | - force=True bypasses drift check. |
| 8 | - Commit route boundary seal (AST). |
| 9 | """ |
| 10 | from __future__ import annotations |
| 11 | |
| 12 | import ast |
| 13 | import uuid |
| 14 | from dataclasses import asdict |
| 15 | from pathlib import Path |
| 16 | |
| 17 | import pytest |
| 18 | |
| 19 | from maestro.contracts.json_types import CCEventDict, NoteDict |
| 20 | from maestro.services.muse_drift import ( |
| 21 | CommitConflictPayload, |
| 22 | DriftReport, |
| 23 | DriftSeverity, |
| 24 | RegionDriftSummary, |
| 25 | compute_drift_report, |
| 26 | ) |
| 27 | |
| 28 | |
| 29 | # ── Helpers ─────────────────────────────────────────────────────────────── |
| 30 | |
| 31 | |
| 32 | def _note(pitch: int, start: float) -> NoteDict: |
| 33 | |
| 34 | return {"pitch": pitch, "start_beat": start, "duration_beats": 1.0, "velocity": 100, "channel": 0} |
| 35 | |
| 36 | |
| 37 | def _cc(cc_num: int, beat: float, value: int) -> CCEventDict: |
| 38 | |
| 39 | return {"cc": cc_num, "beat": beat, "value": value} |
| 40 | |
| 41 | |
| 42 | # --------------------------------------------------------------------------- |
| 43 | # 5.1 — Clean commit allowed |
| 44 | # --------------------------------------------------------------------------- |
| 45 | |
| 46 | |
class TestCleanCommitAllowed:
    """A working tree identical to HEAD yields a report that blocks nothing."""

    def test_clean_drift_does_not_require_action(self) -> None:
        """Identical head/working notes -> clean, no user action, CLEAN severity."""
        head_notes = {"r1": [_note(60, 0.0)]}
        working_notes = {"r1": [_note(60, 0.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1"},
        )

        assert report.is_clean is True
        assert report.requires_user_action() is False
        assert report.severity == DriftSeverity.CLEAN

    def test_clean_drift_with_controllers(self) -> None:
        """The same controller events on both sides must not count as drift."""
        controller_events = [_cc(64, 0.0, 127)]
        head_notes = {"r1": [_note(60, 0.0)]}
        working_notes = {"r1": [_note(60, 0.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1"},
            head_cc={"r1": controller_events},
            working_cc={"r1": controller_events},
        )

        assert report.requires_user_action() is False
| 75 | |
| 76 | |
| 77 | # --------------------------------------------------------------------------- |
| 78 | # 5.2 — Dirty notes blocked (409) |
| 79 | # --------------------------------------------------------------------------- |
| 80 | |
| 81 | |
class TestDirtyNotesBlocked:
    """An extra note in the working tree makes the report DIRTY and blocks commit."""

    @staticmethod
    def _dirty_report() -> DriftReport:
        """Build a report where region r1 gained one note (pitch 72 at beat 2.0) over HEAD."""
        return compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0), _note(72, 2.0)]},
            track_regions={"r1": "t1"},
        )

    def test_dirty_drift_requires_action(self) -> None:
        """A single added note flips the report to DIRTY and demands user action."""
        report = self._dirty_report()
        assert report.is_clean is False
        assert report.requires_user_action() is True
        assert report.severity == DriftSeverity.DIRTY

    def test_conflict_payload_from_dirty_report(self) -> None:
        """The conflict payload reports severity, change count and the touched region."""
        conflict = CommitConflictPayload.from_drift_report(self._dirty_report())
        assert conflict.severity == "dirty"
        assert conflict.total_changes == 1
        assert "r1" in conflict.changed_regions

    def test_conflict_payload_serializable(self) -> None:
        """asdict() of the payload exposes its fields without the envelope's error code."""
        payload = asdict(CommitConflictPayload.from_drift_report(self._dirty_report()))
        assert payload["severity"] == "dirty"
        assert payload["total_changes"] == 1
        # The "error" code lives on the 409 envelope, not on the payload itself.
        assert "error" not in payload

    def test_conflict_payload_error_field(self) -> None:
        """The 409 detail must include 'error': 'WORKING_TREE_DIRTY'.

        Builds the ``{"error": ..., "drift": asdict(payload)}`` envelope and
        checks it survives a JSON encode/decode round trip with the error code
        and drift fields intact. The commit route itself is not exercised here.
        """
        import json

        conflict = CommitConflictPayload.from_drift_report(self._dirty_report())
        detail: dict[str, object] = {
            "error": "WORKING_TREE_DIRTY",
            "drift": asdict(conflict),
        }
        # A non-JSON-encodable payload field would raise TypeError here.
        decoded = json.loads(json.dumps(detail))
        assert decoded["error"] == "WORKING_TREE_DIRTY"
        assert decoded["drift"]["severity"] == "dirty"
        assert decoded["drift"]["total_changes"] == 1
        assert "r1" in decoded["drift"]["changed_regions"]
| 144 | |
| 145 | |
| 146 | # --------------------------------------------------------------------------- |
| 147 | # 5.3 — Dirty controllers blocked (409) |
| 148 | # --------------------------------------------------------------------------- |
| 149 | |
| 150 | |
class TestDirtyControllersBlocked:
    """Controller-only edits (CC, pitch bend) count as drift and require action."""

    def test_cc_change_requires_action(self) -> None:
        """Adding a CC event where HEAD had none is a single change requiring action."""
        notes_at_head = {"r1": [_note(60, 0.0)]}
        notes_in_tree = {"r1": [_note(60, 0.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=notes_at_head,
            working_snapshot_notes=notes_in_tree,
            track_regions={"r1": "t1"},
            head_cc={"r1": []},
            working_cc={"r1": [_cc(64, 0.0, 127)]},
        )

        assert report.requires_user_action() is True
        assert CommitConflictPayload.from_drift_report(report).total_changes == 1

    def test_pb_change_requires_action(self) -> None:
        """Changing a pitch-bend value at the same beat requires action."""
        notes_at_head = {"r1": [_note(60, 0.0)]}
        notes_in_tree = {"r1": [_note(60, 0.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=notes_at_head,
            working_snapshot_notes=notes_in_tree,
            track_regions={"r1": "t1"},
            head_pb={"r1": [{"beat": 1.0, "value": 4096}]},
            working_pb={"r1": [{"beat": 1.0, "value": 8192}]},
        )

        assert report.requires_user_action() is True
| 180 | |
| 181 | |
| 182 | # --------------------------------------------------------------------------- |
| 183 | # 5.4 — Force commit allowed |
| 184 | # --------------------------------------------------------------------------- |
| 185 | |
| 186 | |
class TestForceCommitAllowed:
    """``force`` is a request-level opt-in; the drift report itself ignores it."""

    def test_force_field_exists_on_request_model(self) -> None:
        """CommitVariationRequest accepts force=True and stores it."""
        from maestro.models.requests import CommitVariationRequest

        forced_request = CommitVariationRequest(
            project_id="p1",
            base_state_id="s1",
            variation_id="v1",
            accepted_phrase_ids=["ph1"],
            force=True,
        )

        assert forced_request.force is True

    def test_force_default_is_false(self) -> None:
        """Omitting force yields force=False."""
        from maestro.models.requests import CommitVariationRequest

        plain_request = CommitVariationRequest(
            project_id="p1",
            base_state_id="s1",
            variation_id="v1",
            accepted_phrase_ids=["ph1"],
        )

        assert plain_request.force is False

    def test_requires_user_action_still_true_with_force(self) -> None:
        """Force doesn't change the drift report — only the commit route checks it."""
        head_notes = {"r1": [_note(60, 0.0)]}
        working_notes = {"r1": [_note(60, 0.0), _note(72, 2.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1"},
        )

        assert report.requires_user_action() is True
| 223 | |
| 224 | |
| 225 | # --------------------------------------------------------------------------- |
| 226 | # CommitConflictPayload unit tests |
| 227 | # --------------------------------------------------------------------------- |
| 228 | |
| 229 | |
class TestCommitConflictPayload:
    """Shape of the conflict payload derived from a drift report."""

    def test_fingerprint_delta_only_dirty_regions(self) -> None:
        """Only the region that drifted (r2) appears in fingerprint_delta."""
        head = {
            "r1": [_note(60, 0.0)],
            "r2": [_note(72, 0.0)],
        }
        working = {
            "r1": [_note(60, 0.0)],
            "r2": [_note(72, 0.0), _note(76, 2.0)],
        }
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head,
            working_snapshot_notes=working,
            track_regions={"r1": "t1", "r2": "t2"},
        )

        delta = CommitConflictPayload.from_drift_report(report).fingerprint_delta
        assert "r2" in delta
        assert "r1" not in delta

    def test_payload_excludes_sample_changes(self) -> None:
        """The payload carries neither sample_changes nor region_summaries keys."""
        head_notes = {"r1": [_note(60, 0.0)]}
        working_notes = {"r1": [_note(60, 0.0), _note(72, 2.0)]}
        report = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1"},
        )

        fields = asdict(CommitConflictPayload.from_drift_report(report))
        assert "sample_changes" not in fields
        assert "region_summaries" not in fields
| 264 | |
| 265 | |
| 266 | # --------------------------------------------------------------------------- |
| 267 | # 5.5 — Boundary seal |
| 268 | # --------------------------------------------------------------------------- |
| 269 | |
| 270 | |
class TestCommitRouteBoundary:
    """AST-level seal on maestro/api/routes/variation/commit.py.

    The commit route may consume only the public drift API
    (``compute_drift_report`` and ``CommitConflictPayload``) and must not
    import the ``StateStore`` class directly.
    """

    _DRIFT_MODULE = "muse_drift"
    _ALLOWED_DRIFT_NAMES = frozenset({"compute_drift_report", "CommitConflictPayload"})

    @staticmethod
    def _commit_route_tree() -> ast.Module:
        """Parse the commit route source, located relative to this test file."""
        path = (
            Path(__file__).resolve().parent.parent
            / "maestro" / "api" / "routes" / "variation" / "commit.py"
        )
        # Decode explicitly: read_text() without an encoding uses the locale default.
        return ast.parse(path.read_text(encoding="utf-8"), filename=str(path))

    def test_no_drift_internal_imports(self) -> None:
        """Commit route may only import compute_drift_report and CommitConflictPayload from drift."""
        forbidden_names = {"_fingerprint", "_combined_fingerprint", "RegionDriftSummary", "DriftSeverity"}
        for node in ast.walk(self._commit_route_tree()):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    assert alias.name not in forbidden_names, (
                        f"commit.py imports drift internal: {alias.name}"
                    )

    def test_commit_route_imports_only_public_drift_api(self) -> None:
        """Only compute_drift_report and CommitConflictPayload are used from muse_drift.

        Importing the drift module object itself (``import ...muse_drift`` or
        ``from maestro.services import muse_drift``) is rejected too: it would
        expose every internal via attribute access and bypass a name-only check.
        """
        for node in ast.walk(self._commit_route_tree()):
            if isinstance(node, ast.ImportFrom):
                from_drift = node.module is not None and self._DRIFT_MODULE in node.module
                for alias in node.names:
                    if from_drift:
                        assert alias.name in self._ALLOWED_DRIFT_NAMES, (
                            f"commit.py imports non-public drift symbol: {alias.name}"
                        )
                    else:
                        # e.g. `from maestro.services import muse_drift` or `from . import muse_drift`
                        assert alias.name != self._DRIFT_MODULE, (
                            "commit.py imports the muse_drift module wholesale"
                        )
            elif isinstance(node, ast.Import):
                for alias in node.names:
                    assert self._DRIFT_MODULE not in alias.name.split("."), (
                        f"commit.py imports the drift module directly: {alias.name}"
                    )

    def test_commit_route_does_not_import_state_store_internals(self) -> None:
        """Commit route uses get_or_create_store (allowed) but not StateStore class directly."""
        for node in ast.walk(self._commit_route_tree()):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    assert alias.name != "StateStore", (
                        "commit.py imports StateStore class directly"
                    )