cgcardona / muse public
test_commit_drift_safety.py python
310 lines 11.2 KB
12901c5a Initial extraction from tellurstori/maestro cgcardona <gabriel@tellurstori.com> 4d ago
1 """Tests for drift-aware commit safety (Phase 8).
2
3 Verifies:
4 - Clean commits succeed (200).
5 - Dirty notes block commit (409 + WORKING_TREE_DIRTY).
6 - Dirty controllers block commit (409).
7 - force=True bypasses drift check.
8 - Commit route boundary seal (AST).
9 """
10 from __future__ import annotations
11
12 import ast
13 import uuid
14 from dataclasses import asdict
15 from pathlib import Path
16
17 import pytest
18
19 from maestro.contracts.json_types import CCEventDict, NoteDict
20 from maestro.services.muse_drift import (
21 CommitConflictPayload,
22 DriftReport,
23 DriftSeverity,
24 RegionDriftSummary,
25 compute_drift_report,
26 )
27
28
29 # ── Helpers ───────────────────────────────────────────────────────────────
30
31
32 def _note(pitch: int, start: float) -> NoteDict:
33
34 return {"pitch": pitch, "start_beat": start, "duration_beats": 1.0, "velocity": 100, "channel": 0}
35
36
37 def _cc(cc_num: int, beat: float, value: int) -> CCEventDict:
38
39 return {"cc": cc_num, "beat": beat, "value": value}
40
41
42 # ---------------------------------------------------------------------------
43 # 5.1 — Clean commit allowed
44 # ---------------------------------------------------------------------------
45
46
class TestCleanCommitAllowed:
    """A working tree that matches HEAD must not demand any user action."""

    def test_clean_drift_does_not_require_action(self) -> None:
        """Identical notes give a clean report with CLEAN severity."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0)]},
            track_regions={"r1": "t1"},
        )

        assert drift.is_clean is True
        assert drift.requires_user_action() is False
        assert drift.severity == DriftSeverity.CLEAN

    def test_clean_drift_with_controllers(self) -> None:
        """Identical CC lanes on both sides keep the report action-free."""
        # One list object is shared by both sides, exactly as the HEAD and
        # working CC lanes would be identical.
        shared_cc = [_cc(64, 0.0, 127)]
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0)]},
            track_regions={"r1": "t1"},
            head_cc={"r1": shared_cc},
            working_cc={"r1": shared_cc},
        )

        assert drift.requires_user_action() is False
75
76
77 # ---------------------------------------------------------------------------
78 # 5.2 — Dirty notes blocked (409)
79 # ---------------------------------------------------------------------------
80
81
class TestDirtyNotesBlocked:
    """An extra note in the working tree makes it dirty and must block a plain commit."""

    @staticmethod
    def _dirty_report() -> DriftReport:
        """Build the shared dirty fixture.

        HEAD has a single note in region r1. The working tree adds a second
        note (pitch 72 at beat 2.0).
        """
        return compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0), _note(72, 2.0)]},
            track_regions={"r1": "t1"},
        )

    def test_dirty_drift_requires_action(self) -> None:
        """The report is not clean, has DIRTY severity and asks for user action."""
        report = self._dirty_report()
        assert report.is_clean is False
        assert report.requires_user_action() is True
        assert report.severity == DriftSeverity.DIRTY

    def test_conflict_payload_from_dirty_report(self) -> None:
        """The conflict payload reports one change in region r1."""
        conflict = CommitConflictPayload.from_drift_report(self._dirty_report())
        assert conflict.severity == "dirty"
        assert conflict.total_changes == 1
        assert "r1" in conflict.changed_regions

    def test_conflict_payload_serializable(self) -> None:
        """asdict() yields the payload fields, without an 'error' key of its own."""
        conflict = CommitConflictPayload.from_drift_report(self._dirty_report())
        payload = asdict(conflict)
        assert payload["severity"] == "dirty"
        assert payload["total_changes"] == 1
        assert "error" not in payload

    def test_conflict_payload_error_field(self) -> None:
        """The 409 detail must include 'error': 'WORKING_TREE_DIRTY'.

        The detail wraps the conflict payload next to the error code, so the
        combined dict has to survive a JSON round trip intact.
        """
        import json

        conflict = CommitConflictPayload.from_drift_report(self._dirty_report())
        drift_dict = asdict(conflict)
        detail: dict[str, object] = {
            "error": "WORKING_TREE_DIRTY",
            "drift": drift_dict,
        }
        # Asserting on `detail` directly would only re-check a literal written
        # above. Serializing it exercises the real payload content instead.
        decoded = json.loads(json.dumps(detail))
        assert decoded["error"] == "WORKING_TREE_DIRTY"
        assert decoded["drift"]["severity"] == "dirty"
        assert decoded["drift"]["total_changes"] == 1
144
145
146 # ---------------------------------------------------------------------------
147 # 5.3 — Dirty controllers blocked (409)
148 # ---------------------------------------------------------------------------
149
150
class TestDirtyControllersBlocked:
    """Controller-only edits count as drift even when the notes are untouched."""

    def test_cc_change_requires_action(self) -> None:
        """Adding one CC event to an empty CC lane is a single blocking change."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0)]},
            track_regions={"r1": "t1"},
            head_cc={"r1": []},
            working_cc={"r1": [_cc(64, 0.0, 127)]},
        )

        assert drift.requires_user_action() is True
        payload = CommitConflictPayload.from_drift_report(drift)
        assert payload.total_changes == 1

    def test_pb_change_requires_action(self) -> None:
        """Changing a pitch-bend value at the same beat requires user action."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0)]},
            track_regions={"r1": "t1"},
            head_pb={"r1": [{"beat": 1.0, "value": 4096}]},
            working_pb={"r1": [{"beat": 1.0, "value": 8192}]},
        )

        assert drift.requires_user_action() is True
180
181
182 # ---------------------------------------------------------------------------
183 # 5.4 — Force commit allowed
184 # ---------------------------------------------------------------------------
185
186
class TestForceCommitAllowed:
    """The request model exposes `force`, which defaults to off."""

    def test_force_field_exists_on_request_model(self) -> None:
        """An explicit force=True is preserved on the parsed request."""
        from maestro.models.requests import CommitVariationRequest

        request = CommitVariationRequest(
            project_id="p1",
            base_state_id="s1",
            variation_id="v1",
            accepted_phrase_ids=["ph1"],
            force=True,
        )

        assert request.force is True

    def test_force_default_is_false(self) -> None:
        """Omitting force leaves it False."""
        from maestro.models.requests import CommitVariationRequest

        request = CommitVariationRequest(
            project_id="p1",
            base_state_id="s1",
            variation_id="v1",
            accepted_phrase_ids=["ph1"],
        )

        assert request.force is False

    def test_requires_user_action_still_true_with_force(self) -> None:
        """Force doesn't change the drift report — only the commit route checks it."""
        # The drift computation never receives `force`, so a dirty tree still
        # reports that user action is required.
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0), _note(72, 2.0)]},
            track_regions={"r1": "t1"},
        )

        assert drift.requires_user_action() is True
223
224
225 # ---------------------------------------------------------------------------
226 # CommitConflictPayload unit tests
227 # ---------------------------------------------------------------------------
228
229
class TestCommitConflictPayload:
    """What the conflict payload carries, and what it leaves out."""

    def test_fingerprint_delta_only_dirty_regions(self) -> None:
        """fingerprint_delta lists the changed region r2 but not the unchanged r1."""
        head_notes = {
            "r1": [_note(60, 0.0)],
            "r2": [_note(72, 0.0)],
        }
        working_notes = {
            "r1": [_note(60, 0.0)],
            "r2": [_note(72, 0.0), _note(76, 2.0)],
        }
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes=head_notes,
            working_snapshot_notes=working_notes,
            track_regions={"r1": "t1", "r2": "t2"},
        )

        delta = CommitConflictPayload.from_drift_report(drift).fingerprint_delta
        assert "r2" in delta
        assert "r1" not in delta

    def test_payload_excludes_sample_changes(self) -> None:
        """The serialized payload omits sample_changes and region_summaries."""
        drift = compute_drift_report(
            project_id="p1",
            head_variation_id="v1",
            head_snapshot_notes={"r1": [_note(60, 0.0)]},
            working_snapshot_notes={"r1": [_note(60, 0.0), _note(72, 2.0)]},
            track_regions={"r1": "t1"},
        )

        as_dict = asdict(CommitConflictPayload.from_drift_report(drift))
        assert "sample_changes" not in as_dict
        assert "region_summaries" not in as_dict
264
265
266 # ---------------------------------------------------------------------------
267 # 5.5 — Boundary seal
268 # ---------------------------------------------------------------------------
269
270
271 class TestCommitRouteBoundary:
272
273 def test_no_drift_internal_imports(self) -> None:
274
275 """Commit route may only import compute_drift_report and CommitConflictPayload from drift."""
276 filepath = Path(__file__).resolve().parent.parent / "maestro" / "api" / "routes" / "variation" / "commit.py"
277 tree = ast.parse(filepath.read_text())
278 forbidden_names = {"_fingerprint", "_combined_fingerprint", "RegionDriftSummary", "DriftSeverity"}
279 for node in ast.walk(tree):
280 if isinstance(node, (ast.Import, ast.ImportFrom)):
281 for alias in node.names:
282 assert alias.name not in forbidden_names, (
283 f"commit.py imports drift internal: {alias.name}"
284 )
285
286 def test_commit_route_imports_only_public_drift_api(self) -> None:
287
288 """Only compute_drift_report and CommitConflictPayload are used from muse_drift."""
289 filepath = Path(__file__).resolve().parent.parent / "maestro" / "api" / "routes" / "variation" / "commit.py"
290 tree = ast.parse(filepath.read_text())
291 drift_imports: list[str] = []
292 for node in ast.walk(tree):
293 if isinstance(node, ast.ImportFrom) and node.module and "muse_drift" in node.module:
294 for alias in node.names:
295 drift_imports.append(alias.name)
296 allowed = {"compute_drift_report", "CommitConflictPayload"}
297 for name in drift_imports:
298 assert name in allowed, f"commit.py imports non-public drift symbol: {name}"
299
300 def test_commit_route_does_not_import_state_store_internals(self) -> None:
301
302 """Commit route uses get_or_create_store (allowed) but not StateStore class directly."""
303 filepath = Path(__file__).resolve().parent.parent / "maestro" / "api" / "routes" / "variation" / "commit.py"
304 tree = ast.parse(filepath.read_text())
305 for node in ast.walk(tree):
306 if isinstance(node, (ast.Import, ast.ImportFrom)):
307 for alias in node.names:
308 assert alias.name != "StateStore", (
309 "commit.py imports StateStore class directly"
310 )