Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __init__(
overwrite=False,
somd1_compatibility=False,
pert_file=None,
auto_fix_minimise=True,
save_crash_report=False,
save_energy_components=False,
page_size=None,
Expand Down Expand Up @@ -496,6 +497,10 @@ def __init__(
The path to a SOMD1 perturbation file to apply to the reference system.
When set, this will automatically set 'somd1_compatibility' to True.

auto_fix_minimise: bool
Whether to attempt to automatically recover from simulation instabilities
by minimising and restarting. Defaults to True.

save_crash_report: bool
Whether to save a crash report if the simulation crashes.

Expand Down Expand Up @@ -599,6 +604,7 @@ def __init__(
self.taylor_power = taylor_power
self.somd1_compatibility = somd1_compatibility
self.pert_file = pert_file
self.auto_fix_minimise = auto_fix_minimise
self.save_crash_report = save_crash_report
self.save_energy_components = save_energy_components
self.timeout = timeout
Expand Down Expand Up @@ -2383,6 +2389,16 @@ def pert_file(self, pert_file):

self._pert_file = pert_file

@property
def auto_fix_minimise(self):
    """Whether to attempt automatic recovery from simulation
    instabilities by minimising and restarting. (bool)"""
    return self._auto_fix_minimise

@auto_fix_minimise.setter
def auto_fix_minimise(self, value):
    """Validate and store the flag.

    Raises
    ------
    ValueError
        If the supplied value is not a bool.
    """
    if isinstance(value, bool):
        self._auto_fix_minimise = value
    else:
        raise ValueError("'auto_fix_minimise' must be of type 'bool'")

@property
def save_crash_report(self):
return self._save_crash_report
Expand Down
61 changes: 35 additions & 26 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,26 +1820,33 @@ def _checkpoint(
# Get the lambda value.
lam = self._lambda_values[index]

# -1 is the sentinel for a post-equilibration checkpoint. No
# energies are collected during equilibration, so skip all
# parquet-related work in this case.
is_post_equilibration = block == -1

# Get the energy trajectory.
df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT")
if not is_post_equilibration:
df = system.energy_trajectory(to_alchemlyb=True, energy_unit="kT")

# Set the lambda values at which energies were sampled.
if lambda_energy is None:
lambda_energy = self._lambda_values

# Create the metadata.
metadata = {
"attrs": df.attrs,
"somd2 version": __version__,
"sire version": f"{_sire_version}+{_sire_revisionid}",
"lambda": f"{lam:.5f}",
"speed": speed,
"temperature": str(self._config.temperature.value()),
}

# Add the lambda gradient if available.
if lambda_grad is not None:
metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad]
if not is_post_equilibration:
metadata = {
"attrs": df.attrs,
"somd2 version": __version__,
"sire version": f"{_sire_version}+{_sire_revisionid}",
"lambda": f"{lam:.5f}",
"speed": speed,
"temperature": str(self._config.temperature.value()),
}

# Add the lambda gradient if available.
if lambda_grad is not None:
metadata["lambda_grad"] = [f"{v:.5f}" for v in lambda_grad]

if is_final_block:
# Save the end-state GCMC topologies for trajectory analysis and visualisation.
Expand Down Expand Up @@ -1930,7 +1937,7 @@ def _checkpoint(

else:
# Update the starting block if necessary.
if block == 0:
if block <= 0:
block = self._start_block

# Save the current trajectory chunk to file.
Expand Down Expand Up @@ -1958,18 +1965,20 @@ def _checkpoint(
# Stream the checkpoint to file.
_sr.stream.save(system, self._filenames[index]["checkpoint"])

# Create the parquet file name.
filename = self._filenames[index]["energy_traj"]

# Create the parquet file.
if block == self._start_block:
_dataframe_to_parquet(df, metadata=metadata, filename=filename)
# Append to the parquet file.
else:
_parquet_append(
filename,
df.iloc[-self._energy_per_block :],
)
# Skip parquet creation for post-equilibration checkpoints.
if not is_post_equilibration:
# Create the parquet file name.
filename = self._filenames[index]["energy_traj"]

# Create the parquet file.
if block == self._start_block:
_dataframe_to_parquet(df, metadata=metadata, filename=filename)
# Append to the parquet file.
else:
_parquet_append(
filename,
df.iloc[-self._energy_per_block :],
)

except Exception as e:
return index, e
Expand Down
77 changes: 70 additions & 7 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,39 @@ def run(self):
_logger.error("Equilibration cancelled. Exiting.")
_sys.exit(1)

# Write a checkpoint immediately after equilibration so that a restart
# after an early production crash doesn't need to re-equilibrate.
if self._is_equilibration and not self._is_restart:
lock = _FileLock(self._lock_file)
with lock.acquire(timeout=self._config.timeout.to("seconds")):
for j in range(num_checkpoint_batches):
replicas = replica_list[
j * num_checkpoint_workers : (j + 1) * num_checkpoint_workers
]
with ThreadPoolExecutor(
max_workers=num_checkpoint_workers
) as executor:
try:
for index, error in executor.map(
self._checkpoint,
replicas,
repeat(self._lambda_values),
repeat(-1),
repeat(cycles),
):
if error is not None:
msg = (
f"Post-equilibration checkpoint failed for {_lam_sym} = "
f"{self._lambda_values[index]:.5f}:\n{error}"
)
_logger.error(msg)
raise error
except KeyboardInterrupt:
_logger.error(
"Post-equilibration checkpoint cancelled. Exiting."
)
_sys.exit(1)

# Current block number.
block = self._start_block

Expand Down Expand Up @@ -1149,6 +1182,13 @@ def run(self):
)
self._dynamics_cache.mix_states()

# Snapshot the pre-run state for crash recovery.
if self._config.auto_fix_minimise:
for i, state in enumerate(self._dynamics_cache.get_states()):
self._dynamics_cache._dynamics[
i
]._d._pre_run_state = self._dynamics_cache._openmm_states[state]

# This is a checkpoint cycle.
if is_checkpoint:
# Update the block number.
Expand Down Expand Up @@ -1278,6 +1318,12 @@ def _run_block(
# Get the dynamics object (and GCMC sampler).
dynamics, gcmc_sampler = self._dynamics_cache.get(index)

# Track whether any MC move changed the context positions so we
# can update _pre_run_state once at the end. Only needed when
# crash recovery is enabled.
needs_pre_run_snapshot = False
auto_fix_minimise = self._config.auto_fix_minimise

# Perform the GCMC move before dynamics so that the energies
# computed during dynamics are consistent with the state used
# for replica exchange mixing.
Expand All @@ -1289,6 +1335,9 @@ def _run_block(
finally:
gcmc_sampler.pop()

if auto_fix_minimise:
needs_pre_run_snapshot = True

# Write ghost residues immediately after the GCMC move so the
# ghost state and frame (saved during dynamics) are consistent.
if write_gcmc_ghosts:
Expand All @@ -1297,7 +1346,16 @@ def _run_block(
# Perform a terminal flip move before dynamics if requested.
if self._terminal_flip_samplers is not None and is_terminal_flip:
_logger.info(f"Performing terminal flip move at {_lam_sym} = {lam:.5f}")
self._terminal_flip_samplers[index].move(dynamics.context())
if self._terminal_flip_samplers[index].move(dynamics.context()):
if auto_fix_minimise:
needs_pre_run_snapshot = True

# Snapshot the context state for crash recovery if any MC move
# changed positions.
if needs_pre_run_snapshot:
dynamics._d._pre_run_state = dynamics.context().getState(
getPositions=True, getVelocities=True
)

_logger.info(f"Running dynamics at {_lam_sym} = {lam:.5f}")

Expand All @@ -1313,7 +1371,7 @@ def _run_block(
lambda_windows=lambdas,
rest2_scale_factors=self._rest2_scale_factors,
save_velocities=self._config.save_velocities,
auto_fix_minimise=True,
auto_fix_minimise=self._config.auto_fix_minimise,
num_energy_neighbours=self._config.num_energy_neighbours,
null_energy=self._config.null_energy,
save_crash_report=self._config.save_crash_report,
Expand Down Expand Up @@ -1544,7 +1602,7 @@ def _equilibrate(self, index):
energy_frequency=0,
frame_frequency=0,
save_velocities=False,
auto_fix_minimise=True,
auto_fix_minimise=self._config.auto_fix_minimise,
save_crash_report=self._config.save_crash_report,
)

Expand Down Expand Up @@ -1728,10 +1786,15 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False):
# dynamics object.
dynamics._d._sire_mols.delete_all_frames()

_logger.info(
f"Finished block {block + 1} of {self._start_block + num_blocks} "
f"for {_lam_sym} = {lam:.5f}"
)
if block == -1:
_logger.info(
f"Writing post-equilibration checkpoint for {_lam_sym} = {lam:.5f}"
)
else:
_logger.info(
f"Finished block {block + 1} of {self._start_block + num_blocks} "
f"for {_lam_sym} = {lam:.5f}"
)

# Log the number of waters within the GCMC sampling volume.
if gcmc_sampler is not None:
Expand Down
Loading
Loading