[JAX] Fix MNIST L2 jax test instability#2933
Conversation
…o avoid failing by noise near convergence Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci jax L2

/te-ci jax L2
Greptile Summary

This PR addresses MNIST CI test flakiness by collecting per-epoch metrics and asserting on the best value within the last ~10% of epochs (minimum 2), loosening loss thresholds slightly, and applying XLA flags for deterministic results.

Confidence Score: 4/5

Safe to merge; changes are test-only and reduce CI flakiness with no functional impact on training code. No P0/P1 issues. The tail-window approach is sound and the threshold tightening (accuracy typo fix) makes the tests more meaningful. Two minor P2s noted. No files require special attention.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[train_and_evaluate runs N epochs] --> B[Collect per-epoch: train_losses, train_accuracies, test_losses, test_accuracies]
    B --> C[Return 4 lists]
    C --> D[verify called with lists]
    D --> E["tail = max(2, ceil(epochs × 0.1))"]
    E --> F["tail = min(tail, epochs)"]
    F --> G["Slice last 'tail' entries from each list"]
    G --> H["best_train_loss = min(train_losses[-tail:])"]
    G --> I["best_train_accuracy = max(train_accuracies[-tail:])"]
    G --> J["best_test_loss = min(test_losses[-tail:])"]
    G --> K["best_test_accuracy = max(test_accuracies[-tail:])"]
    H --> L{best_train_loss < 0.06?}
    I --> M{best_train_accuracy > 0.98?}
    J --> N{best_test_loss < 0.05?}
    K --> O{best_test_accuracy > 0.98?}
    L -- No --> FAIL[AssertionError with descriptive message]
    M -- No --> FAIL
    N -- No --> FAIL
    O -- No --> FAIL
    L -- Yes --> PASS[All assertions pass]
    M -- Yes --> PASS
    N -- Yes --> PASS
    O -- Yes --> PASS
```
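The tail-window selection described in the flowchart can be sketched in Python. This is a paraphrase of the approach, not the PR's exact test code; the function name `best_tail_metrics` is illustrative:

```python
import math

def best_tail_metrics(train_losses, train_accuracies, test_losses, test_accuracies):
    """Return the best metric from the last ~10% of epochs (minimum 2)."""
    epochs = len(train_losses)
    # Tail window: at least 2 epochs, ~10% of the run, never more than the run itself.
    tail = max(2, math.ceil(epochs * 0.1))
    tail = min(tail, epochs)
    return (
        min(train_losses[-tail:]),      # best (lowest) train loss in the window
        max(train_accuracies[-tail:]),  # best (highest) train accuracy
        min(test_losses[-tail:]),       # best (lowest) test loss
        max(test_accuracies[-tail:]),   # best (highest) test accuracy
    )
```

Asserting on the best value in this window, rather than only the final epoch, filters out a single noisy epoch near convergence.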
Reviews (1): Last reviewed commit: "Merge branch 'main' into tdophung/fix-mn..."
```python
tail = max(2, math.ceil(epochs * 0.1))
tail = min(tail, epochs)
```
tail window is 40% of epochs, not ~10%, for the default epoch count
With `--epochs 5` (the value used in `setUpClass`), `math.ceil(5 × 0.1) = 1`, so `tail = max(2, 1) = 2`, which covers the last 2 of 5 epochs (40%). The docstring says "last ~10% of epochs", which is misleading for the typical test configuration. This isn't a bug (2 epochs is perfectly reasonable), but the comment could note that the minimum of 2 dominates for small epoch counts, so callers aren't surprised when they see the window covering a large fraction of training.
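To make the reviewer's point concrete, here is how the window size behaves for different epoch counts. This is a quick illustration of the PR's formula, not code from the PR itself:

```python
import math

def tail_window(epochs):
    # Same formula as the PR: last ~10% of epochs, minimum 2, capped at epochs.
    return min(max(2, math.ceil(epochs * 0.1)), epochs)

# For small epoch counts, the minimum of 2 dominates:
#   epochs=5  -> tail=2 (40% of training)
#   epochs=20 -> tail=2 (10%)
#   epochs=50 -> tail=5 (10%)
```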
```diff
 def verify(actual):
-    """Check If loss and accuracy match target"""
-    desired_traing_loss = 0.055
+    """Check that loss and accuracy match target.
+
+    ``actual`` is ``[train_losses, train_accuracies, test_losses, test_accuracies]``,
+    i.e. per-epoch lists of metrics. To avoid flakiness from stochastic noise in
+    the final epoch near convergence (especially under FP8), the check considers
+    a tail window of the last ~10% of epochs (at least 2) and asserts on the
+    best metric within that window.
+    """
+    train_losses, train_accuracies, test_losses, test_accuracies = actual
```
`verify()` will raise an unhelpful `TypeError` if passed `None`

`train_and_evaluate` returns `None` when `args.dry_run` is `True`. If a future test (or a direct call) passes `None` to `verify()`, the tuple-unpack on line 345 raises `TypeError: cannot unpack non-iterable NoneType object` with no context. Adding an early guard keeps failure messages readable:

```python
if actual is None:
    return  # dry_run path; nothing to verify
```
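A minimal sketch of how the suggested guard would sit at the top of `verify`. The assertion body is abbreviated here; this is an illustration of the reviewer's suggestion, not the PR's exact code:

```python
def verify(actual):
    """Check that loss and accuracy match target."""
    if actual is None:
        # train_and_evaluate returns None on the dry_run path; nothing to verify.
        return
    # Without the guard above, unpacking None raises an uninformative
    # "TypeError: cannot unpack non-iterable NoneType object".
    train_losses, train_accuracies, test_losses, test_accuracies = actual
    best_train_loss = min(train_losses[-2:])
    assert best_train_loss < 0.06, f"train loss too high: {best_train_loss}"
```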
Description
The L2 JAX unit test for MNIST sometimes fails in CI for non-error reasons, such as noisy loss near convergence. This PR fixes that by loosening the thresholds and checking the minimum loss in the last ~10% of epochs, filtering out noisy high-loss values near convergence. It also applies XLA flags for deterministic results.
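The excerpt does not list which XLA flags the PR applies. As an illustration only (an assumption, not necessarily this PR's flags), determinism on GPU is commonly requested by setting the `XLA_FLAGS` environment variable before JAX initializes XLA:

```python
import os

# Assumption: illustrative flag only. The PR text mentions XLA flags for
# deterministic results but does not name them in this excerpt;
# --xla_gpu_deterministic_ops is one commonly used XLA determinism flag.
# It must be set before JAX/XLA is initialized to take effect.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops=true"
).strip()
```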
Fixes # (issue)
Type of change
Changes
Checklist: