
[JAX] Fix MNIST L2 jax test instability #2933

Open

tdophung wants to merge 5 commits into NVIDIA:main from tdophung:tdophung/fix-mnist-CI-instability

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Apr 27, 2026

Description

The L2 JAX unit test for MNIST sometimes fails in CI for non-error reasons, such as noisy loss near convergence. This PR fixes that by loosening the failure thresholds and checking the minimum loss over the last 10% of steps, which filters out noisy high-loss readings near convergence. It also applies XLA flags for deterministic results.
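
For reference, a minimal Python sketch of how the deterministic-ops flag could be applied; this is an illustration only, since the PR itself exports the flag in qa/L2_jax_unittest/test.sh before the test runs:

```python
import os

# XLA reads XLA_FLAGS when the backend initializes, so the flag must be
# set before JAX is first imported. Appending preserves any existing flags.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops"
).strip()

import jax  # noqa: E402  (deliberately imported after the env var is set)
```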

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Loosen loss failure threshold in MNIST test
  • Add xla gpu deterministic flag for this test
  • Fix typo for test accuracy check (9.8% -> 98%)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

tdophung and others added 3 commits April 27, 2026 10:32
…o avoid failing by noise near convergence

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

/te-ci jax L2

@tdophung tdophung changed the title from "Fix MNIST L2 jax test instability" to "[JAX] Fix MNIST L2 jax test instability" Apr 27, 2026
@tdophung
Collaborator Author

/te-ci jax L2

@tdophung tdophung marked this pull request as ready for review April 28, 2026 16:24
@tdophung tdophung requested a review from phu0ngng April 28, 2026 16:24
@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR addresses MNIST CI test flakiness by collecting per-epoch metrics and asserting on the best value within the last ~10% of epochs (minimum 2), loosening loss thresholds slightly, and applying --xla_gpu_deterministic_ops before the mnist tests. It also fixes a meaningful correctness bug where the test accuracy threshold was 0.098 (9.8%) instead of 0.98 (98%), making the accuracy check effectively a no-op against any converged model.

Confidence Score: 4/5

Safe to merge; changes are test-only and reduce CI flakiness with no functional impact on training code.

No P0/P1 issues. The tail-window approach is sound and the threshold tightening (accuracy typo fix) makes tests more meaningful. Two minor P2s noted.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/jax/mnist/test_single_gpu_mnist.py | Refactored verify() to collect per-epoch metric lists and assert on the best metric over a tail window; fixed a critical accuracy threshold typo (0.098 → 0.98); loosened loss thresholds slightly. |
| qa/L2_jax_unittest/test.sh | Moved the XLA deterministic-ops flag export to before the MNIST test so both the MNIST and encoder tests run deterministically; updated the comment to match. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[train_and_evaluate runs N epochs] --> B[Collect per-epoch: train_losses, train_accuracies, test_losses, test_accuracies]
    B --> C[Return 4 lists]
    C --> D[verify called with lists]
    D --> E["tail = max(2, ceil(epochs × 0.1))"]
    E --> F["tail = min(tail, epochs)"]
    F --> G["Slice last 'tail' entries from each list"]
    G --> H["best_train_loss = min(train_losses[-tail:])"]
    G --> I["best_train_accuracy = max(train_accuracies[-tail:])"]
    G --> J["best_test_loss = min(test_losses[-tail:])"]
    G --> K["best_test_accuracy = max(test_accuracies[-tail:])"]
    H --> L{best_train_loss < 0.06?}
    I --> M{best_train_accuracy > 0.98?}
    J --> N{best_test_loss < 0.05?}
    K --> O{best_test_accuracy > 0.98?}
    L -- No --> FAIL[AssertionError with descriptive message]
    M -- No --> FAIL
    N -- No --> FAIL
    O -- No --> FAIL
    L -- Yes --> PASS
    M -- Yes --> PASS
    N -- Yes --> PASS
    O -- Yes --> PASS[All assertions pass]
```
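
To make the flow concrete, here is a minimal Python sketch of the tail-window check, reconstructed from the flowchart and the diff excerpts below; the real implementation lives in examples/jax/mnist/test_single_gpu_mnist.py and may differ in details such as message formatting:

```python
import math

def verify(actual):
    """Assert on the best per-epoch metric within a tail window.

    ``actual`` is ``[train_losses, train_accuracies, test_losses,
    test_accuracies]``, each a per-epoch list of metrics.
    """
    train_losses, train_accuracies, test_losses, test_accuracies = actual
    epochs = len(train_losses)
    # Last ~10% of epochs, but never fewer than 2 nor more than all of them.
    tail = min(max(2, math.ceil(epochs * 0.1)), epochs)
    best_train_loss = min(train_losses[-tail:])
    best_train_accuracy = max(train_accuracies[-tail:])
    best_test_loss = min(test_losses[-tail:])
    best_test_accuracy = max(test_accuracies[-tail:])
    assert best_train_loss < 0.06, f"train loss {best_train_loss} >= 0.06"
    assert best_train_accuracy > 0.98, f"train accuracy {best_train_accuracy} <= 0.98"
    assert best_test_loss < 0.05, f"test loss {best_test_loss} >= 0.05"
    assert best_test_accuracy > 0.98, f"test accuracy {best_test_accuracy} <= 0.98"
```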

Reviews (1): Last reviewed commit: "Merge branch 'main' into tdophung/fix-mn..."

Comment on lines +347 to +348
```python
tail = max(2, math.ceil(epochs * 0.1))
tail = min(tail, epochs)
```
Contributor


P2 tail window is 40% of epochs, not ~10%, for the default epoch count

With --epochs 5 (the value used in setUpClass), math.ceil(5 × 0.1) = 1, so tail = max(2, 1) = 2, which covers the last 2 of 5 epochs (40%). The docstring says "last ~10% of epochs", which is misleading for the typical test configuration. This isn't a bug — 2 epochs is perfectly reasonable — but the comment could note that the minimum of 2 dominates for small epoch counts, so callers aren't surprised when they see the window covering a large fraction of training.
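
For concreteness, the arithmetic plays out like this for a few illustrative epoch counts:

```python
import math

for epochs in (5, 20, 100):
    tail = min(max(2, math.ceil(epochs * 0.1)), epochs)
    print(epochs, tail)  # 5 -> 2 (40%), 20 -> 2 (10%), 100 -> 10 (10%)
```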

Comment on lines 336 to +345
```diff
 def verify(actual):
-    """Check If loss and accuracy match target"""
-    desired_traing_loss = 0.055
+    """Check that loss and accuracy match target.
+
+    ``actual`` is ``[train_losses, train_accuracies, test_losses, test_accuracies]``,
+    i.e. per-epoch lists of metrics. To avoid flakiness from stochastic noise in
+    the final epoch near convergence (especially under FP8), the check considers
+    a tail window of the last ~10% of epochs (at least 2) and asserts on the
+    best metric within that window.
+    """
+    train_losses, train_accuracies, test_losses, test_accuracies = actual
```
Contributor


P2 verify() will raise an unhelpful TypeError if passed None

train_and_evaluate returns None when args.dry_run is True. If a future test (or a direct call) passes None to verify(), the tuple-unpack on line 345 raises TypeError: cannot unpack non-iterable NoneType object with no context. Adding an early guard keeps failure messages readable:

```python
if actual is None:
    return  # dry_run path; nothing to verify
```

