From 492b2a5b7db4942abf5fa7d7caeead4bd7614d71 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 13 Apr 2026 10:57:33 +0100 Subject: [PATCH 1/4] Set dynamics _pre_run_state for crash recovery from mix_states and MC moves. --- src/somd2/runner/_repex.py | 21 +++++++++++++++- src/somd2/runner/_runner.py | 48 +++++++++++++++++++++++++++++-------- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index b019030..e8c5a3d 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -526,6 +526,10 @@ def mix_states(self): old_state = self._old_states[i] self._num_swaps[old_state, state] += 1 + # Snapshot the pre-run state for crash recovery. + for i, state in enumerate(self._states): + self._dynamics[i]._d._pre_run_state = self._openmm_states[state] + # Store the current states. self._old_states = self._states.copy() @@ -1278,6 +1282,10 @@ def _run_block( # Get the dynamics object (and GCMC sampler). dynamics, gcmc_sampler = self._dynamics_cache.get(index) + # Track whether any MC move changed the context positions so we + # can update _pre_run_state once at the end. + needs_pre_run_snapshot = False + # Perform the GCMC move before dynamics so that the energies # computed during dynamics are consistent with the state used # for replica exchange mixing. @@ -1289,6 +1297,8 @@ def _run_block( finally: gcmc_sampler.pop() + needs_pre_run_snapshot = True + # Write ghost residues immediately after the GCMC move so the # ghost state and frame (saved during dynamics) are consistent. if write_gcmc_ghosts: @@ -1297,7 +1307,16 @@ def _run_block( # Perform a terminal flip move before dynamics if requested. if self._terminal_flip_samplers is not None and is_terminal_flip: _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") - self._terminal_flip_samplers[index].move(dynamics.context()) + if self._terminal_flip_samplers[index].move(dynamics.context()): + needs_pre_run_snapshot = True + + # Snapshot the context state for crash recovery if any MC move + # changed positions. This overwrites the snapshot set in + # mix_states() so _rebuild_and_minimise() has a consistent state. + if needs_pre_run_snapshot: + dynamics._d._pre_run_state = dynamics.context().getState( + getPositions=True, getVelocities=True + ) _logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}") diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 54345e4..b735608 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -770,6 +770,9 @@ def generate_lam_vals(lambda_base, increment=0.001): finally: gcmc_sampler.pop() + # GCMC always changes positions. + needs_pre_run_snapshot = True + # Perform a terminal flip move at the specified frequency. if ( terminal_flip_sampler is not None @@ -779,11 +782,23 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() + flip_accepted = terminal_flip_sampler.move( + dynamics.context() + ) + if flip_accepted: + needs_pre_run_snapshot = True + if self._config.randomise_velocities: + dynamics.randomise_velocities() + + # Snapshot the context state once for crash recovery + # if any MC move changed positions. + if needs_pre_run_snapshot: + dynamics._d._pre_run_state = ( + dynamics.context().getState( + getPositions=True, getVelocities=True + ) + ) + needs_pre_run_snapshot = False # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming @@ -1090,6 +1105,9 @@ def generate_lam_vals(lambda_base, increment=0.001): finally: gcmc_sampler.pop() + # GCMC always changes positions. + needs_pre_run_snapshot = True + # Perform a terminal flip move at the specified frequency. if ( terminal_flip_sampler is not None @@ -1099,11 +1117,21 @@ def generate_lam_vals(lambda_base, increment=0.001): f"Performing terminal flip move at " f"{_lam_sym} = {lambda_value:.5f}" ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() + flip_accepted = terminal_flip_sampler.move( + dynamics.context() + ) + if flip_accepted: + needs_pre_run_snapshot = True + if self._config.randomise_velocities: + dynamics.randomise_velocities() + + # Snapshot the context state once for crash recovery + # if any MC move changed positions. + if needs_pre_run_snapshot: + dynamics._d._pre_run_state = dynamics.context().getState( + getPositions=True, getVelocities=True + ) + needs_pre_run_snapshot = False # Write ghost residues immediately after the GCMC # move if a frame will be saved in the upcoming From 597994653c5f9a9bc5ae3d3c55ad340cfe83a66e Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 13 Apr 2026 11:52:53 +0100 Subject: [PATCH 2/4] Expose auto_fix_minimise dynamics run option. --- src/somd2/config/_config.py | 16 ++++++++++++++++ src/somd2/runner/_repex.py | 4 ++-- src/somd2/runner/_runner.py | 16 ++++++++-------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 3ab4449..e2969fa 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -162,6 +162,7 @@ def __init__( overwrite=False, somd1_compatibility=False, pert_file=None, + auto_fix_minimise=True, save_crash_report=False, save_energy_components=False, page_size=None, @@ -496,6 +497,10 @@ def __init__( The path to a SOMD1 perturbation file to apply to the reference system. When set, this will automatically set 'somd1_compatibility' to True. + auto_fix_minimise: bool + Whether to attempt to automatically recover from simulation instabilities + by minimising and restarting. Defaults to True. + save_crash_report: bool Whether to save a crash report if the simulation crashes. @@ -599,6 +604,7 @@ def __init__( self.taylor_power = taylor_power self.somd1_compatibility = somd1_compatibility self.pert_file = pert_file + self.auto_fix_minimise = auto_fix_minimise self.save_crash_report = save_crash_report self.save_energy_components = save_energy_components self.timeout = timeout @@ -2383,6 +2389,16 @@ def pert_file(self, pert_file): self._pert_file = pert_file + @property + def auto_fix_minimise(self): + return self._auto_fix_minimise + + @auto_fix_minimise.setter + def auto_fix_minimise(self, auto_fix_minimise): + if not isinstance(auto_fix_minimise, bool): + raise ValueError("'auto_fix_minimise' must be of type 'bool'") + self._auto_fix_minimise = auto_fix_minimise + @property def save_crash_report(self): return self._save_crash_report diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index e8c5a3d..be189ea 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -1332,7 +1332,7 @@ def _run_block( lambda_windows=lambdas, rest2_scale_factors=self._rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=self._config.num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -1563,7 +1563,7 @@ def _equilibrate(self, index): energy_frequency=0, frame_frequency=0, save_velocities=False, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, save_crash_report=self._config.save_crash_report, ) diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index b735608..cc3a8c4 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -578,7 +578,7 @@ def generate_lam_vals(lambda_base, increment=0.001): energy_frequency=0, frame_frequency=0, save_velocities=False, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, save_crash_report=self._config.save_crash_report, ) @@ -819,7 +819,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -868,7 +868,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -882,7 +882,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -1024,7 +1024,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -1151,7 +1151,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -1186,7 +1186,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, @@ -1200,7 +1200,7 @@ def generate_lam_vals(lambda_base, increment=0.001): lambda_windows=lambda_array, rest2_scale_factors=rest2_scale_factors, save_velocities=self._config.save_velocities, - auto_fix_minimise=True, + auto_fix_minimise=self._config.auto_fix_minimise, num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, From 3243d7b4764bdc74deb4d5828e5fb8237758aa35 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 13 Apr 2026 12:28:31 +0100 Subject: [PATCH 3/4] Only set pre-run state when auto_fix_minimise=True. --- src/somd2/runner/_repex.py | 24 +++++++++++++++--------- src/somd2/runner/_runner.py | 5 +++-- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index be189ea..4c6aa2e 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -526,10 +526,6 @@ def mix_states(self): old_state = self._old_states[i] self._num_swaps[old_state, state] += 1 - # Snapshot the pre-run state for crash recovery. - for i, state in enumerate(self._states): - self._dynamics[i]._d._pre_run_state = self._openmm_states[state] - # Store the current states. self._old_states = self._states.copy() @@ -1153,6 +1149,13 @@ def run(self): ) self._dynamics_cache.mix_states() + # Snapshot the pre-run state for crash recovery. + if self._config.auto_fix_minimise: + for i, state in enumerate(self._dynamics_cache.get_states()): + self._dynamics_cache._dynamics[ + i + ]._d._pre_run_state = self._dynamics_cache._openmm_states[state] + # This is a checkpoint cycle. if is_checkpoint: # Update the block number. @@ -1283,8 +1286,10 @@ def _run_block( dynamics, gcmc_sampler = self._dynamics_cache.get(index) # Track whether any MC move changed the context positions so we - # can update _pre_run_state once at the end. + # can update _pre_run_state once at the end. Only needed when + # crash recovery is enabled. needs_pre_run_snapshot = False + auto_fix_minimise = self._config.auto_fix_minimise # Perform the GCMC move before dynamics so that the energies # computed during dynamics are consistent with the state used @@ -1297,7 +1302,8 @@ def _run_block( finally: gcmc_sampler.pop() - needs_pre_run_snapshot = True + if auto_fix_minimise: + needs_pre_run_snapshot = True # Write ghost residues immediately after the GCMC move so the # ghost state and frame (saved during dynamics) are consistent. @@ -1308,11 +1314,11 @@ def _run_block( if self._terminal_flip_samplers is not None and is_terminal_flip: _logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}") if self._terminal_flip_samplers[index].move(dynamics.context()): - needs_pre_run_snapshot = True + if auto_fix_minimise: + needs_pre_run_snapshot = True # Snapshot the context state for crash recovery if any MC move - # changed positions. This overwrites the snapshot set in - # mix_states() so _rebuild_and_minimise() has a consistent state. + # changed positions. if needs_pre_run_snapshot: dynamics._d._pre_run_state = dynamics.context().getState( getPositions=True, getVelocities=True diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index cc3a8c4..188f8ce 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -771,7 +771,7 @@ def generate_lam_vals(lambda_base, increment=0.001): gcmc_sampler.pop() # GCMC always changes positions. - needs_pre_run_snapshot = True + needs_pre_run_snapshot = self._config.auto_fix_minimise # Perform a terminal flip move at the specified frequency. if ( @@ -786,7 +786,8 @@ def generate_lam_vals(lambda_base, increment=0.001): dynamics.context() ) if flip_accepted: - needs_pre_run_snapshot = True + if self._config.auto_fix_minimise: + needs_pre_run_snapshot = True if self._config.randomise_velocities: dynamics.randomise_velocities() From 727269cf56b070cc54a9f937f2d6f95963b8ace0 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 13 Apr 2026 14:03:39 +0100 Subject: [PATCH 4/4] Add support for post-equilibration checkpoint. --- src/somd2/runner/_base.py | 61 +++++++++++++++++++++---------------- src/somd2/runner/_repex.py | 46 +++++++++++++++++++++++++--- src/somd2/runner/_runner.py | 25 +++++++++++++++ 3 files changed, 102 insertions(+), 30 deletions(-) diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index f4c2ed7..3e05798 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -1820,26 +1820,33 @@ def _checkpoint( # Get the lambda value. lam = self._lambda_values[index] + # -1 is the sentinel for a post-equilibration checkpoint. No + # energies are collected during equilibration, so skip all + # parquet-related work in this case. + is_post_equilibration = block == -1 + # Get the energy trajectory. - df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT") + if not is_post_equilibration: + df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT") # Set the lambda values at which energies were sampled. if lambda_energy is None: lambda_energy = self._lambda_values # Create the metadata. - metadata = { - "attrs": df.attrs, - "somd2 version": __version__, - "sire version": f"{_sire_version}+{_sire_revisionid}", - "lambda": f"{lam:.5f}", - "speed": speed, - "temperature": str(self._config.temperature.value()), - } - - # Add the lambda gradient if available. - if lambda_grad is not None: - metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad] + if not is_post_equilibration: + metadata = { + "attrs": df.attrs, + "somd2 version": __version__, + "sire version": f"{_sire_version}+{_sire_revisionid}", + "lambda": f"{lam:.5f}", + "speed": speed, + "temperature": str(self._config.temperature.value()), + } + + # Add the lambda gradient if available. + if lambda_grad is not None: + metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad] if is_final_block: # Save the end-state GCMC topologies for trajectory analysis and visualisation. @@ -1930,7 +1937,7 @@ def _checkpoint( else: # Update the starting block if necessary. - if block == 0: + if block <= 0: block = self._start_block # Save the current trajectory chunk to file. @@ -1958,18 +1965,20 @@ def _checkpoint( # Stream the checkpoint to file. _sr.stream.save(system, self._filenames[index]["checkpoint"]) - # Create the parquet file name. - filename = self._filenames[index]["energy_traj"] - - # Create the parquet file. - if block == self._start_block: - _dataframe_to_parquet(df, metadata=metadata, filename=filename) - # Append to the parquet file. - else: - _parquet_append( - filename, - df.iloc[-self._energy_per_block :], - ) + # Skip parquet creation for post-equilibration checkpoints. + if not is_post_equilibration: + # Create the parquet file name. + filename = self._filenames[index]["energy_traj"] + + # Create the parquet file. + if block == self._start_block: + _dataframe_to_parquet(df, metadata=metadata, filename=filename) + # Append to the parquet file. + else: + _parquet_append( + filename, + df.iloc[-self._energy_per_block :], + ) except Exception as e: return index, e diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 4c6aa2e..19e474d 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -979,6 +979,39 @@ def run(self): _logger.error("Equilibration cancelled. Exiting.") _sys.exit(1) + # Write a checkpoint immediately after equilibration so that a restart + # after an early production crash doesn't need to re-equilibrate. + if self._is_equilibration and not self._is_restart: + lock = _FileLock(self._lock_file) + with lock.acquire(timeout=self._config.timeout.to("seconds")): + for j in range(num_checkpoint_batches): + replicas = replica_list[ + j * num_checkpoint_workers : (j + 1) * num_checkpoint_workers + ] + with ThreadPoolExecutor( + max_workers=num_checkpoint_workers + ) as executor: + try: + for index, error in executor.map( + self._checkpoint, + replicas, + repeat(self._lambda_values), + repeat(-1), + repeat(cycles), + ): + if error is not None: + msg = ( + f"Post-equilibration checkpoint failed for {_lam_sym} = " + f"{self._lambda_values[index]:.5f}:\n{error}" + ) + _logger.error(msg) + raise error + except KeyboardInterrupt: + _logger.error( + "Post-equilibration checkpoint cancelled. Exiting." + ) + _sys.exit(1) + # Current block number. block = self._start_block @@ -1753,10 +1786,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): # dynamics object. dynamics._d._sire_mols.delete_all_frames() - _logger.info( - f"Finished block {block + 1} of {self._start_block + num_blocks} " - f"for {_lam_sym} = {lam:.5f}" - ) + if block == -1: + _logger.info( + f"Writing post-equilibration checkpoint for {_lam_sym} = {lam:.5f}" + ) + else: + _logger.info( + f"Finished block {block + 1} of {self._start_block + num_blocks} " + f"for {_lam_sym} = {lam:.5f}" + ) # Log the number of waters within the GCMC sampling volume. if gcmc_sampler is not None: diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 188f8ce..8dabc36 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -719,6 +719,31 @@ def generate_lam_vals(lambda_base, increment=0.001): # Store the checkpoint time in nanoseconds. checkpoint_interval = checkpoint_frequency.to("ns") + # Write a checkpoint immediately after equilibration so that a restart + # after an early production crash doesn't need to re-equilibrate. + if is_equilibrated: + lock = _FileLock(self._lock_file) + with lock.acquire(timeout=self._config.timeout.to("seconds")): + _, error = self._checkpoint( + system, + index, + block=-1, + speed=0.0, + lambda_energy=lambda_energy, + lambda_grad=lambda_grad, + ) + if error is not None: + msg = ( + f"Post-equilibration checkpoint failed for {_lam_sym} = " + f"{lambda_value:.5f}:\n{error}" + ) + _logger.error(msg) + raise error + _logger.info( + f"Writing post-equilibration checkpoint " + f"for {_lam_sym} = {lambda_value:.5f}" + ) + # Store the start time. start = _timer()