fix: detach optimizer-state buffers in L-BFGS / BFGS step #559

Merged

CompRhys merged 3 commits into TorchSim:main from niklashoelter:fix/optimizer-history-detach on May 13, 2026
Conversation

@niklashoelter
Contributor

Summary

L-BFGS and BFGS leak memory linearly in the number of optimizer steps when used with a model that requires positions.requires_grad=True (e.g. conservative force fields computing forces via torch.autograd.grad, such as ORB-v3 *_conservative_*). Three .detach() calls fix it.

Root cause

Both optimizers update an incremental per-step buffer derived from grad-tracked positions/forces:

  • L-BFGS: state.s_history = torch.cat([state.s_history, s_new.unsqueeze(1)], dim=1). s_new has a grad_fn from positions, so the cat result has a grad_fn whose saved tensors include the previous state.s_history. Iterated, this adds one link per step, so every past history tensor plus its saved-for-backward intermediates stays alive.
  • BFGS state.hessian[idx] = H - term1 - term2 — same pattern via dp/df.

Neither buffer is ever differentiated through.

FIRE is unaffected (overwrites state each step, no incremental buffer).
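To make the failure mode concrete, here is a minimal self-contained sketch of the pattern (illustration only, not TorchSim code); the commented-out last line shows the one-call fix:

```python
import torch

pos = torch.randn(10, 3, requires_grad=True)  # stands in for grad-tracked positions
history = torch.zeros(10, 0)                  # stands in for state.s_history

for _ in range(100):
    energy = (pos**2).sum()
    # Conservative-force-field style: forces via autograd, with create_graph=True
    # so the force tensor itself carries a grad_fn.
    (force,) = torch.autograd.grad(energy, pos, create_graph=True)
    s_new = force.norm(dim=-1)
    # Leaky: the cat result's grad_fn saves the previous `history`, so every
    # past buffer and its saved-for-backward intermediates stay alive.
    history = torch.cat([history, s_new.unsqueeze(1)], dim=1)
    # Fix: history = torch.cat([history, s_new.detach().unsqueeze(1)], dim=1)
```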

Evidence

Reproducer: InFlightAutoBatcher + ORB-v3 conservative + L-BFGS on 1614 molecular structures (28–185 atoms), padding=0.95, steps_between_swaps=1, 24 GB GPU.

Without fix: batch composition stays stable (n_atoms ~6000, n_edges at cap), yet memory_allocated grows linearly at +22.6 MB per swap, hitting OOM after ~10 min. torch.cuda.memory._snapshot at swap=150 shows 124 unfreed concatenate_states blocks (~5 MB each, one per past step) plus thousands of _per_system_vdot / lbfgs_step intermediates pinned by the chain.
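For anyone reproducing the diagnosis, a snapshot like that can be captured with PyTorch's memory-history API (a private, version-dependent interface; the output filename here is illustrative):

```python
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run the batched relaxation up to the swap of interest ...

torch.cuda.memory._dump_snapshot("swap_150_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
# Load the pickle at https://pytorch.org/memory_viz to see which stack traces
# own the unfreed blocks (here: concatenate_states / lbfgs_step).
```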

With fix, same config:

| Optimizer | Before | After |
| --- | --- | --- |
| L-BFGS | OOM at ~10 min | completes in 797 s, allocation flat at ≈ 340 MB |
| BFGS | OOM at ~10 min | completes in 625 s |
| FIRE | works unchanged | 581 s, no regression |

Final energies match within numerical noise (< 0.04 eV mean).

Tests

tests/test_optimizers.py (65/65), tests/test_optimizer_states.py, tests/test_runners.py (35/35) all pass. tests/test_optimizers_vs_ase.py skips (requires mace, unrelated). Ruff check + format clean.

L-BFGS's torch.cat extension of s_history/y_history and BFGS's in-place
hessian update both build an autograd chain when the model requires
positions.requires_grad=True (e.g. conservative force fields computing
forces via torch.autograd.grad). Each step adds one link, pinning every
previous buffer plus its saved-for-backward intermediates and leaking
tens of MB per step. The buffers are pure numerical state; no backward
is ever taken through them.
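For readers skimming the diff, the three detach points presumably reduce to the following sketch, reconstructed from the expressions quoted above (y_new / y_history are assumed by symmetry with s_new / s_history; this is not verbatim TorchSim source):

```python
# L-BFGS: detach the per-step increments before extending the histories
state.s_history = torch.cat([state.s_history, s_new.detach().unsqueeze(1)], dim=1)
state.y_history = torch.cat([state.y_history, y_new.detach().unsqueeze(1)], dim=1)

# BFGS: detach the rank-two update before writing it into the stored Hessian
state.hessian[idx] = (H - term1 - term2).detach()
```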
@niklashoelter niklashoelter marked this pull request as ready for review May 12, 2026 21:21
@CompRhys
Member

Thanks for the contribution! The core test error seems unrelated, but we'll need to diagnose it fully before we can merge this.

@CompRhys
Member

#556 I wonder if you could also test this: some users have suggested that a 10× larger default step size would improve the LBFGS. Can you test that for your case, and if it's positive let's bump the default?

@niklashoelter
Contributor Author

Ran the step-size sweep on the same 1614-structure heterogeneous molecular dataset (ORB-v3 conservative, force_tol=5e-2, max_memory_padding=0.95, single 24 GB GPU), with the .detach() fix from this PR applied:

L-BFGS step_size:

| step_size | Wall time | Total step calls | Mean energy (eV) |
| --- | --- | --- | --- |
| 0.1 (current default) | 785 s | 1245 | −38228.894 |
| 0.25 | 587 s | 918 | −38228.902 |
| 0.5 | 497 s | 887 | −38228.908 |
| 1.0 | 438 s | 745 | −38228.909 |
| 1.5 | 423 s | 743 | −38228.903 |

step_size=1.0 is 1.8× faster than the current default and uses 40 % fewer total iterations, with energies matching to ≲ 0.02 eV. Bumping the default looks like a clear win on this workload. (Values ≥ 2.0 diverge — the optimizer overshoots and a long tail of structures oscillates without converging.)

BFGS max_step: swept 0.1 → 2.5; insensitive (wall time and total iterations are flat to within noise). max_step acts as a hard cap on per-iteration displacement, not a multiplier, and the natural BFGS step is already below 0.1 Å for these molecules, so raising the cap has no effect.
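A small illustration of that cap-versus-multiplier distinction (the helper below is hypothetical, not TorchSim's API):

```python
import torch

def cap_step(step: torch.Tensor, max_step: float) -> torch.Tensor:
    """Hard-cap per-atom displacement norms at max_step (illustration only)."""
    norms = step.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return step * (max_step / norms).clamp(max=1.0)  # only shrinks steps longer than the cap

# If the natural quasi-Newton step is already below 0.1 Å per atom,
# cap_step(step, 0.1) == cap_step(step, 2.5) == step: raising the cap is a no-op.
```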

@CompRhys CompRhys merged commit a78df6f into TorchSim:main May 13, 2026
58 of 64 checks passed

Development

Successfully merging this pull request may close these issues.

Better initial L-BFGS stepsize as default leads to slow convergence in many cases.
