fix: detach optimizer-state buffers in L-BFGS / BFGS step #559
Conversation
L-BFGS's `torch.cat` extension of `s_history`/`y_history` and BFGS's in-place Hessian update both build an autograd chain when the model requires `positions.requires_grad=True` (e.g. conservative force fields computing forces via `torch.autograd.grad`). Each step adds one link, pinning every previous buffer plus its saved-for-backward intermediates and leaking tens of MB per step. The buffers are pure numerical state; no backward pass is ever taken through them.
---
Thanks for the contribution! The core test error seems unrelated, but we'll need to diagnose it fully before we can merge this.
---
#556: I wonder if you could also test this. Some users have suggested that a 10x larger default step size would improve L-BFGS; can you test that for your case, and if the results are positive, let's bump the default?
---
Ran the step-size sweep on the same 1614-structure heterogeneous molecular dataset (ORB-v3 conservative, `force_tol=5e-2`, `max_memory_padding=0.95`, single 24 GB GPU) with both L-BFGS and BFGS.

[per-step-size results table missing from the extracted page]
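The sweep amounts to a loop of this shape (a sketch with hypothetical names: `relax_all` stands in for the actual torch-sim runner call, and `base` for the current default step size, which may differ):

```python
from typing import Callable

def sweep_step_sizes(relax_all: Callable[[float], dict], base: float) -> None:
    # Compare the current default against the proposed 10x default.
    for scale in (1.0, 10.0):
        stats = relax_all(base * scale)  # hypothetical: relax full dataset
        print(f"step_size={base * scale:g}: {stats}")
```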
## Summary
L-BFGS and BFGS leak memory linearly in the number of optimizer steps when used with a model that requires `positions.requires_grad=True` (e.g. conservative force fields computing forces via `torch.autograd.grad`, such as ORB-v3 `*_conservative_*`). Three `.detach()` calls fix it.

## Root cause
Both optimizers update an incremental per-step buffer derived from grad-tracked positions/forces:
- `state.s_history = torch.cat([state.s_history, s_new.unsqueeze(1)], dim=1)`: `s_new` has a `grad_fn` from positions, so the cat result has a `grad_fn` whose saved tensors include the previous `state.s_history`. Iterated, this adds one link per step, so every past history tensor plus its saved-for-backward intermediates stays alive.
- `state.hessian[idx] = H - term1 - term2`: same pattern via `dp`/`df`.

Neither buffer is ever differentiated through; see the sketch of the detached updates after the FIRE note below.
FIRE is unaffected (overwrites state each step, no incremental buffer).
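A sketch of what the detached updates look like, using the buffer names from the bullets above (the exact expressions in the diff may differ; the BFGS terms here follow the textbook rank-2 update rather than the code's `term1`/`term2`):

```python
import torch

def append_history(s_history: torch.Tensor, s_new: torch.Tensor) -> torch.Tensor:
    # Detach before storing: the history is pure numerical state, never
    # differentiated through, so it must not extend the autograd graph.
    return torch.cat([s_history, s_new.detach().unsqueeze(1)], dim=1)

def bfgs_update(hessian: torch.Tensor, dp: torch.Tensor, df: torch.Tensor) -> torch.Tensor:
    # Detach dp/df so the stored Hessian carries no grad_fn either.
    dp, df = dp.detach(), df.detach()
    dg = hessian @ dp
    return hessian + torch.outer(df, df) / (df @ dp) - torch.outer(dg, dg) / (dp @ dg)

# Grad-tracked displacement, as produced under positions.requires_grad=True:
dp = torch.randn(9, requires_grad=True) * 0.1
assert append_history(torch.zeros(9, 0), dp).grad_fn is None   # chain broken
assert bfgs_update(torch.eye(9), dp, 2.0 * dp).grad_fn is None
```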
## Evidence
Reproducer: `InFlightAutoBatcher` + ORB-v3 conservative + L-BFGS on 1614 molecular structures (28–185 atoms), `padding=0.95`, `steps_between_swaps=1`, 24 GB GPU.

Without fix: batch composition is stable (`n_atoms` ~6000, `n_edges` at cap), yet `memory_allocated` grows linearly at +22.6 MB/swap → OOM at ~10 min. `torch.cuda.memory._snapshot` at swap=150 shows 124 unfreed `concatenate_states` blocks (~5 MB each, one per past step) plus thousands of `_per_system_vdot`/`lbfgs_step` intermediates pinned by the chain.

With fix, same config: the per-swap growth disappears.
Final energies match within numerical noise (< 0.04 eV mean).
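The per-swap growth is easy to see with a loop of this shape (a sketch; `run_swap` is a hypothetical stand-in for one optimizer-step-plus-swap iteration of the batcher):

```python
import torch

def track_memory(run_swap, n_swaps: int = 200) -> None:
    # Print allocator growth relative to the first measurement; with the
    # leak this climbs roughly linearly (+22.6 MB/swap in the run above).
    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()
    for swap in range(n_swaps):
        run_swap()  # hypothetical: one optimizer step + batcher swap
        delta_mb = (torch.cuda.memory_allocated() - baseline) / 2**20
        print(f"swap {swap}: +{delta_mb:.1f} MB")
    # torch.cuda.memory._snapshot() can then attribute the unfreed blocks.
```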
## Tests
`tests/test_optimizers.py` (65/65), `tests/test_optimizer_states.py`, and `tests/test_runners.py` (35/35) all pass. `tests/test_optimizers_vs_ase.py` skips (requires `mace`; unrelated). Ruff check and format are clean.