"""
Simulation runner for executing MD simulations with OpenMM.
This module handles running equilibration and production phases
with configurable parameters, reporters, and checkpoint management.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import openmm
from openmm import XmlSerializer
from openmm import unit as omm_unit
from openmm.app import CheckpointReporter, DCDReporter, PDBFile, Simulation, StateDataReporter
if TYPE_CHECKING:
from polyzymd.config.schema import (
EquilibrationStageConfig,
SimulationConfig,
SimulationPhaseConfig,
SimulationPhasesConfig,
)
from polyzymd.core.atom_groups import AtomGroupResolver
from polyzymd.core.parameters import SimulationParameters
# Logger for this module; records carry the module's import path as their name.
LOGGER = logging.getLogger(name=__name__)
# Phase types: allowed labels for a simulation phase.
PhaseType = Literal["equilibration", "production"]
[docs]
class SimulationRunner:
"""Runner for executing OpenMM molecular dynamics simulations.
This class manages:
- Equilibration and production phase execution
- Checkpoint saving and state data reporting
- Energy minimization
- Unique force group assignment for energy decomposition
Example:
>>> runner = SimulationRunner(
... topology=omm_topology,
... system=omm_system,
... positions=omm_positions,
... working_dir="output/",
... )
>>> runner.minimize()
>>> runner.run_equilibration(temperature=300, duration_ns=0.5)
>>> runner.run_production(temperature=300, duration_ns=100)
"""
[docs]
def __init__(
    self,
    topology: Any,
    system: openmm.System,
    positions: Any,
    working_dir: Union[str, Path],
    platform: str = "CUDA",
) -> None:
    """Store the inputs, create the working directory, and tag force groups.

    Args:
        topology: OpenMM Topology.
        system: OpenMM System. Each of its forces is given its own force
            group (via ``_impose_unique_force_groups``) for energy
            decomposition.
        positions: Initial positions with units; also the starting
            "current" positions for the first phase.
        working_dir: Working directory for output files; created (with
            parents) if it does not exist.
        platform: Compute platform name (CUDA, OpenCL, CPU). Looked up each
            time a Simulation is built.
    """
    # Inputs, kept as given.
    self._topology = topology
    self._system = system
    self._positions = positions
    self._platform_name = platform
    self._working_dir = Path(working_dir)

    # State carried from one phase to the next.
    self._simulation: Optional[Simulation] = None
    self._current_positions = positions
    self._current_velocities = None  # set by equilibration phases, reused by later stages
    self._current_box_vectors = None  # refreshed after minimization / each phase
    self._history: Dict[str, Any] = {}

    self._working_dir.mkdir(parents=True, exist_ok=True)
    self._impose_unique_force_groups()
@property
def simulation(self) -> Optional[Simulation]:
    """The OpenMM Simulation most recently built by a phase method.

    ``None`` until a phase (e.g. equilibration) creates one; ``minimize``
    uses its own temporary Simulation and does not set this.
    """
    active = self._simulation
    return active
@property
def working_dir(self) -> Path:
    """Directory under which every phase writes its output sub-directory."""
    directory = self._working_dir
    return directory
@property
def history(self) -> Dict[str, Any]:
    """Results recorded per phase, keyed by phase name (e.g. ``"equilibration"``).

    The runner's own dict is returned, not a copy.
    """
    recorded = self._history
    return recorded
def _impose_unique_force_groups(self) -> None:
    """Give every force in the System its own force group.

    Separate groups allow per-force energy decomposition later. The work is
    delegated to ``polyzymd.utils.impose_unique_force_groups``, imported
    lazily here.
    """
    from polyzymd.utils import impose_unique_force_groups as assign_groups

    assign_groups(self._system)
    LOGGER.debug("Assigned unique force groups")
def _get_platform(self) -> openmm.Platform:
    """Get the compute platform, falling back to CPU if it is unavailable.

    Returns:
        OpenMM Platform object.

    Raises:
        openmm.OpenMMException: If neither the requested platform nor the
            CPU platform can be found.
    """
    try:
        platform = openmm.Platform.getPlatformByName(self._platform_name)
    except openmm.OpenMMException as exc:
        # Only "no such platform" is an expected failure here. Anything else
        # (e.g. a non-string name from a bad config) must propagate instead of
        # being hidden behind a silent CPU fallback.
        LOGGER.warning(
            f"Platform {self._platform_name} not available ({exc}), falling back to CPU"
        )
        return openmm.Platform.getPlatformByName("CPU")
    LOGGER.info(f"Using {self._platform_name} platform")
    return platform
def _create_integrator(
    self,
    temperature: float,
    friction: float = 1.0,
    timestep: float = 2.0,
    thermostat: str = "LangevinMiddle",
) -> openmm.Integrator:
    """Build a Langevin-type integrator from plain-number settings.

    Args:
        temperature: Heat-bath temperature in Kelvin.
        friction: Friction coefficient in 1/ps.
        timestep: Time step in femtoseconds.
        thermostat: ``"LangevinMiddle"`` or ``"Langevin"``. Any other value
            logs a warning and falls back to ``"LangevinMiddle"``.

    Returns:
        OpenMM Integrator.
    """
    bath_temperature = temperature * omm_unit.kelvin
    collision_rate = friction / omm_unit.picosecond
    step_size = timestep * omm_unit.femtosecond

    integrator_by_name = {
        "LangevinMiddle": openmm.LangevinMiddleIntegrator,
        "Langevin": openmm.LangevinIntegrator,
    }
    integrator_cls = integrator_by_name.get(thermostat)
    if integrator_cls is None:
        LOGGER.warning(f"Unknown thermostat {thermostat}, using LangevinMiddle")
        integrator_cls = openmm.LangevinMiddleIntegrator
    return integrator_cls(bath_temperature, collision_rate, step_size)
def _add_barostat(
    self,
    pressure: float = 1.0,
    temperature: float = 300.0,
    frequency: int = 25,
) -> None:
    """Append an ``openmm.MonteCarloBarostat`` force to the runner's System.

    Args:
        pressure: Target pressure in atmospheres.
        temperature: Temperature in Kelvin passed to the barostat.
        frequency: Number of steps between Monte Carlo volume moves.
    """
    target_pressure = pressure * omm_unit.atmosphere
    target_temperature = temperature * omm_unit.kelvin
    self._system.addForce(
        openmm.MonteCarloBarostat(target_pressure, target_temperature, frequency)
    )
    LOGGER.info(f"Added MC barostat: {pressure} atm, {temperature} K")
def _remove_barostat(self) -> None:
    """Remove every Monte Carlo barostat from the System.

    Besides the isotropic ``MonteCarloBarostat`` that ``_add_barostat``
    inserts, the anisotropic, membrane and flexible variants are removed too.
    This ensures an NVT stage never silently keeps pressure coupling that
    arrived with the System, and an NPT stage never ends up with two
    barostats. Variants missing from the installed OpenMM are skipped.
    """
    barostat_types = tuple(
        cls
        for cls in (
            getattr(openmm, name, None)
            for name in (
                "MonteCarloBarostat",
                "MonteCarloAnisotropicBarostat",
                "MonteCarloMembraneBarostat",
                "MonteCarloFlexibleBarostat",
            )
        )
        if cls is not None
    )
    barostat_indices = [
        index
        for index in range(self._system.getNumForces())
        if isinstance(self._system.getForce(index), barostat_types)
    ]
    # Remove from the highest index down so the remaining collected indices stay valid.
    for index in reversed(barostat_indices):
        self._system.removeForce(index)
    if barostat_indices:
        LOGGER.debug(f"Removed {len(barostat_indices)} barostat(s)")
[docs]
def minimize(
    self,
    max_iterations: int = 1000,
    tolerance: float = 10.0,
) -> float:
    """Run energy minimization.

    A temporary Simulation (Verlet integrator, never stepped) is used. Its
    final positions and periodic box vectors become the runner's current
    state; stored velocities are left untouched.

    Args:
        max_iterations: Maximum iterations (0 = until convergence).
        tolerance: Energy tolerance in kJ/mol/nm.

    Returns:
        Final potential energy in kJ/mol.
    """
    LOGGER.info("Running energy minimization")

    # Create temporary simulation for minimization
    integrator = openmm.VerletIntegrator(1.0 * omm_unit.femtosecond)
    platform = self._get_platform()
    simulation = Simulation(self._topology, self._system, integrator, platform)
    # Apply a box carried over from an earlier phase (if any) BEFORE the
    # positions, as run_equilibration_stage does. Otherwise calling minimize()
    # after an NPT stage would pair the NPT-relaxed positions with the
    # System's default box.
    if self._current_box_vectors is not None:
        simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    simulation.context.setPositions(self._current_positions)

    simulation.minimizeEnergy(
        tolerance=tolerance * omm_unit.kilojoule_per_mole / omm_unit.nanometer,
        maxIterations=max_iterations,
    )

    # Get final state (including box vectors for proper handoff to equilibration)
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    energy = state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    self._current_positions = state.getPositions()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    LOGGER.info(f"Minimization complete: E = {energy:.2f} kJ/mol")
    return energy
[docs]
def run_equilibration(
    self,
    temperature: float,
    duration_ns: Optional[float] = None,
    num_samples: int = 10,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    output_prefix: str = "equilibration",
    config: Optional["SimulationPhasesConfig"] = None,
) -> Dict[str, Any]:
    """Run the equilibration phase, choosing the mode from the arguments.

    Dispatch routes:

    1. ``config`` is None (legacy): ``duration_ns`` is required, and the
       explicit keyword arguments drive a simple NVT run.
    2. ``config`` without staged equilibration: duration, sample count and
       (if set) time step come from ``config.equilibration``. ``friction``
       and ``output_prefix`` still come from the arguments.
    3. ``config`` with staged equilibration: every stage is logged up front,
       an atom-group resolver is built from the topology's components
       (chain IDs), and ``run_staged_equilibration`` runs the protocol with
       per-stage position restraints and temperature ramping.
       ``num_samples``, ``timestep_fs``, ``friction`` and ``output_prefix``
       are not used on this route.

    Args:
        temperature: Temperature in Kelvin (the target temperature for staged runs)
        duration_ns: Duration in nanoseconds (required for legacy mode)
        num_samples: Number of trajectory frames to save
        timestep_fs: Time step in femtoseconds
        friction: Friction coefficient in 1/ps
        output_prefix: Prefix for output files
        config: SimulationPhasesConfig for config-based dispatch

    Returns:
        Dictionary with equilibration results

    Raises:
        ValueError: If neither config nor duration_ns is provided
    """
    if config is None:
        # Legacy parameter-based mode
        if duration_ns is None:
            raise ValueError(
                "duration_ns is required when config is not provided. "
                "Either pass duration_ns or pass a SimulationPhasesConfig."
            )
        return self._run_simple_equilibration(
            temperature=temperature,
            duration_ns=duration_ns,
            num_samples=num_samples,
            timestep_fs=timestep_fs,
            friction=friction,
            output_prefix=output_prefix,
        )

    if not config.uses_staged_equilibration:
        # Simple equilibration driven by the config's equilibration section
        eq_settings = config.equilibration
        return self._run_simple_equilibration(
            temperature=temperature,
            duration_ns=eq_settings.duration,
            num_samples=eq_settings.samples,
            timestep_fs=eq_settings.time_step or timestep_fs,
            friction=friction,
            output_prefix=output_prefix,
        )

    # Multi-stage equilibration with position restraints
    from polyzymd.core.atom_groups import AtomGroupResolver, SystemComponentInfo

    stages = config.equilibration_stages
    # Log the whole protocol before running anything, for reproducibility
    LOGGER.info(
        f"Starting multi-stage equilibration with "
        f"{len(stages)} stages:"
    )
    for idx, stage in enumerate(stages):
        restraint_parts = [
            f"{r.group}@{r.force_constant:.0f}" for r in stage.position_restraints
        ]
        restraint_info = ", ".join(restraint_parts) or "none"
        temp_info = (
            f"{stage.temperature_start}K -> {stage.temperature_end}K"
            if stage.is_temperature_ramping
            else f"{stage.temperature}K"
        )
        LOGGER.info(
            f" Stage {idx}: {stage.name} - {stage.duration} ns, "
            f"{stage.ensemble.value}, {temp_info}, restraints: [{restraint_info}]"
        )

    resolver = AtomGroupResolver(
        self._topology, SystemComponentInfo.from_topology(self._topology)
    )
    return self.run_staged_equilibration(
        stages=stages,
        atom_group_resolver=resolver,
        target_temperature=temperature,
    )
def _run_simple_equilibration(
    self,
    temperature: float,
    duration_ns: float,
    num_samples: int = 10,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    output_prefix: str = "equilibration",
) -> Dict[str, Any]:
    """Run simple NVT equilibration (internal implementation).

    Removes any barostat, builds a fresh Simulation, and assigns
    Maxwell-Boltzmann velocities at ``temperature``. It then runs the whole
    phase in a single ``step`` call and stores the final positions,
    velocities and box vectors on the runner.

    Args:
        temperature: Temperature in Kelvin.
        duration_ns: Duration in nanoseconds.
        num_samples: Number of trajectory frames to save (values below 1
            are treated as 1).
        timestep_fs: Time step in femtoseconds.
        friction: Friction coefficient in 1/ps.
        output_prefix: Prefix for output files; also the name of the
            output sub-directory under the working directory.

    Returns:
        Dictionary with phase results, also stored as
        ``history["equilibration"]``.
    """
    LOGGER.info(f"Starting equilibration: {duration_ns} ns at {temperature} K (NVT)")

    # Remove any barostat for NVT
    self._remove_barostat()

    # Create output directory
    phase_dir = self._working_dir / output_prefix
    phase_dir.mkdir(parents=True, exist_ok=True)

    # Calculate steps; max(1, num_samples) avoids ZeroDivisionError for num_samples == 0
    total_steps = int(duration_ns * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // max(1, num_samples))

    # Create integrator and simulation
    integrator = self._create_integrator(
        temperature=temperature,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()
    self._simulation = Simulation(self._topology, self._system, integrator, platform)
    # Box vectors before positions (as run_equilibration_stage does), so a box
    # carried over from an earlier phase is not replaced by the System default.
    if self._current_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    self._simulation.context.setPositions(self._current_positions)
    self._simulation.context.setVelocitiesToTemperature(temperature * omm_unit.kelvin)

    # Add reporters
    traj_path = phase_dir / f"{output_prefix}_trajectory.dcd"
    state_path = phase_dir / f"{output_prefix}_state_data.csv"
    pdb_path = phase_dir / f"{output_prefix}_topology.pdb"
    self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
    self._simulation.reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )

    # Save topology (with the starting coordinates)
    with open(pdb_path, "w") as f:
        PDBFile.writeFile(
            self._topology,
            self._current_positions,
            f,
        )

    # Run simulation
    LOGGER.info(f"Running {total_steps} steps...")
    self._simulation.step(total_steps)

    # Get final state (including box vectors for potential NPT follow-up)
    state = self._simulation.context.getState(
        getPositions=True, getVelocities=True, getEnergy=True
    )
    self._current_positions = state.getPositions()
    self._current_velocities = state.getVelocities()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    # Save checkpoint
    checkpoint_path = phase_dir / f"{output_prefix}_checkpoint.chk"
    self._simulation.saveCheckpoint(str(checkpoint_path))

    results = {
        "phase": "equilibration",
        "ensemble": "NVT",
        "temperature_K": temperature,
        "duration_ns": duration_ns,
        "total_steps": total_steps,
        "final_energy_kJ_mol": state.getPotentialEnergy().value_in_unit(
            omm_unit.kilojoule_per_mole
        ),
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    self._history["equilibration"] = results
    LOGGER.info("Equilibration complete")
    return results
[docs]
def run_equilibration_stage(
    self,
    stage: "EquilibrationStageConfig",
    reference_positions: Any,
    atom_group_resolver: "AtomGroupResolver",
    stage_index: int,
    default_timestep: float = 2.0,
    default_friction: float = 1.0,
    resume_from_step: int = 0,
    resume_temperature: float | None = None,
) -> Dict[str, Any]:
    """Run a single equilibration stage with optional position restraints.

    This method runs one stage of a multi-stage equilibration protocol.
    It supports:
    - Position restraints on predefined atom groups
    - Temperature ramping (simulated annealing)
    - NVT or NPT ensembles
    - Graceful interruption and mid-stage resume

    Progress bookkeeping:
    - An ``EQ_INTERRUPTED`` marker is written as soon as the stage starts.
      It is deleted only after the final checkpoint is saved. The periodic
      checkpoint shares its file name with the final one, so without the
      marker a stage killed without cleanup would later look complete to
      ``_find_completed_eq_stages``.
    - The marker's step count is refreshed whenever the periodic
      ``CheckpointReporter`` has just written. On a graceful interrupt, the
      checkpoint is saved explicitly before the marker, so the recorded step
      count matches the checkpointed state.
    - On resume, the step counter continues from ``resume_from_step``, and
      the trajectory and state-data files are appended to, not truncated.

    Position-restraint forces added for this stage are always removed from
    the System again, including when the stage is interrupted or fails.

    Args:
        stage: EquilibrationStageConfig with stage settings
        reference_positions: Positions to restrain atoms to (typically post-minimization)
        atom_group_resolver: Resolver for predefined atom group names
        stage_index: Index of this stage (for output naming)
        default_timestep: Default time step in fs if not specified in stage
        default_friction: Friction coefficient in 1/ps, used when the stage
            does not set ``thermostat_timescale``
        resume_from_step: Step to resume from within this stage (0 = start fresh).
            When resuming, the runner's current positions, velocities and box
            vectors must already hold the checkpointed state (e.g. loaded via
            ``_load_eq_stage_state``); a new Simulation is built from them.
        resume_temperature: Temperature recorded when the stage was interrupted.
            Used only for logging and the initial progress marker; the ramp
            position itself is recovered from ``resume_from_step``.

    Returns:
        Dictionary with stage results

    Raises:
        GracefulExit: If an interrupt signal is seen between simulation chunks.
    """
    import math

    from polyzymd.config.schema import Ensemble
    from polyzymd.core.position_restraints import (
        add_position_restraints_to_system,
        remove_position_restraints_from_system,
    )
    from polyzymd.simulation.signals import (
        GracefulExit,
        get_interrupt_signal,
        install_handlers,
        is_interrupted,
    )

    stage_name = f"equilibration_{stage_index}_{stage.name}"
    LOGGER.info(f"Starting equilibration stage: {stage.name} ({stage.duration} ns)")

    # Output directory and bookkeeping files for this stage
    phase_dir = self._working_dir / stage_name
    phase_dir.mkdir(exist_ok=True)
    marker_path = phase_dir / "EQ_INTERRUPTED"
    checkpoint_path = phase_dir / f"{stage_name}_checkpoint.chk"
    traj_path = phase_dir / f"{stage_name}_trajectory.dcd"
    state_path = phase_dir / f"{stage_name}_state_data.csv"
    pdb_path = phase_dir / f"{stage_name}_topology.pdb"
    resuming = resume_from_step > 0

    # Stage parameters (with defaults)
    timestep_fs = stage.time_step if stage.time_step is not None else default_timestep
    if stage.thermostat_timescale:
        friction = 1.0 / stage.thermostat_timescale  # friction (1/ps) = 1 / timescale (ps)
    else:
        friction = default_friction
    start_temp = stage.get_start_temperature()

    # Steps and reporting interval (samples <= 0 would otherwise divide by zero)
    total_steps = int(stage.duration * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // max(1, stage.samples))

    def _write_marker(steps_done: int, current_temp: float) -> None:
        # key=value lines, parsed back by _find_interrupted_eq_stage
        marker_path.write_text(
            f"stage_index={stage_index}\n"
            f"stage_name={stage.name}\n"
            f"steps_completed={steps_done}\n"
            f"total_steps={total_steps}\n"
            f"current_temperature={current_temp}\n"
            f"is_temperature_ramping={stage.is_temperature_ramping}\n"
        )

    def _interrupt(steps_done: int, current_temp: float, detail: str) -> None:
        # Persist checkpoint + marker for a mid-stage resume, then abort the stage.
        LOGGER.warning(
            f"Interrupt during equilibration stage {stage_index} "
            f"at step {steps_done}/{total_steps}{detail}"
        )
        # Save explicitly: the periodic CheckpointReporter only writes on
        # report_interval boundaries, which ramp chunks need not end on, so
        # its latest file may predate steps_done.
        self._simulation.saveCheckpoint(str(checkpoint_path))
        _write_marker(steps_done, current_temp)
        LOGGER.info(
            f"Saved EQ_INTERRUPTED marker: stage {stage_index}, "
            f"step {steps_done}/{total_steps}, temp {current_temp} K"
        )
        raise GracefulExit(signal_number=get_interrupt_signal(), steps_completed=steps_done)

    def _after_chunk(steps_done: int, current_temp: float, detail: str) -> None:
        # Called after every simulation.step(): honour interrupts and keep the
        # marker in sync with the periodic checkpoint.
        if is_interrupted():
            _interrupt(steps_done, current_temp, detail)
        if steps_done % report_interval == 0:
            # The Simulation's step counter equals steps_done, so the
            # CheckpointReporter has just written a checkpoint for exactly this step.
            _write_marker(steps_done, current_temp)

    def _next_chunk(steps_done: int, steps_left: int) -> int:
        # Run up to the next report_interval boundary (where the checkpoint
        # reporter fires), but never past the steps left in this phase.
        return min(report_interval - steps_done % report_interval, steps_left)

    # Mark the stage as in progress before any checkpoint can reach the disk.
    if resuming and resume_temperature is not None:
        _write_marker(resume_from_step, resume_temperature)
    else:
        _write_marker(resume_from_step, start_temp)

    # Handle ensemble: always drop existing barostats, re-add one for NPT
    self._remove_barostat()
    if stage.ensemble == Ensemble.NPT:
        pressure = 1.0  # Default pressure for NPT stages
        barostat_freq = stage.barostat_frequency if stage.barostat_frequency else 25
        self._add_barostat(
            pressure=pressure,
            temperature=start_temp,
            frequency=barostat_freq,
        )

    restraint_force_indices: List[int] = []
    try:
        # Add position restraints (tracked so the finally-block can remove them)
        for restraint_config in stage.position_restraints:
            atom_indices = atom_group_resolver.resolve(restraint_config.group)
            if not atom_indices:
                LOGGER.warning(
                    f"No atoms found for group '{restraint_config.group}' - skipping restraint"
                )
                continue
            force_idx = add_position_restraints_to_system(
                system=self._system,
                atom_indices=atom_indices,
                positions=reference_positions,
                force_constant=restraint_config.force_constant,
            )
            if force_idx >= 0:
                restraint_force_indices.append(force_idx)
                LOGGER.info(
                    f"Added position restraints to {len(atom_indices)} atoms "
                    f"in group '{restraint_config.group}' "
                    f"(k={restraint_config.force_constant:.1f} kJ/mol/nm^2)"
                )

        # Create integrator (starting temperature) and simulation
        integrator = self._create_integrator(
            temperature=start_temp,
            friction=friction,
            timestep=timestep_fs,
        )
        platform = self._get_platform()
        self._simulation = Simulation(self._topology, self._system, integrator, platform)

        # Set box vectors BEFORE positions - critical for NPT stage transitions
        # where box dimensions may have changed from previous stage
        if self._current_box_vectors is not None:
            self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
        self._simulation.context.setPositions(self._current_positions)
        if resuming:
            # Continue the step counter so reporter/checkpoint intervals and the
            # "Step" column line up with the interrupted run.
            self._simulation.currentStep = resume_from_step

        # Velocities: fresh Maxwell-Boltzmann only when starting stage 0 from
        # scratch (or when none are available). Later stages inherit them for
        # physical continuity (GROMACS convention: gen_vel=yes only for
        # stage 0), and a resumed stage - stage 0 included - keeps the
        # checkpointed velocities.
        if self._current_velocities is None or (stage_index == 0 and not resuming):
            self._simulation.context.setVelocitiesToTemperature(start_temp * omm_unit.kelvin)
            LOGGER.info(f"Stage {stage_index}: initialized velocities at {start_temp} K")
        elif resuming:
            self._simulation.context.setVelocities(self._current_velocities)
            LOGGER.info(f"Stage {stage_index}: restored velocities from checkpoint")
        else:
            self._simulation.context.setVelocities(self._current_velocities)
            LOGGER.info(f"Stage {stage_index}: inherited velocities from previous stage")

        # Log initial energy
        _state = self._simulation.context.getState(getEnergy=True)
        _energy = _state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
        LOGGER.info(f"Stage {stage_index} ({stage_name}): initial PE = {_energy:.2f} kJ/mol")

        # Set up reporters; on resume extend the interrupted run's files
        # instead of truncating them.
        self._simulation.reporters.append(
            DCDReporter(
                str(traj_path),
                report_interval,
                append=resuming and traj_path.exists(),
            )
        )
        self._simulation.reporters.append(
            StateDataReporter(
                str(state_path),
                report_interval,
                step=True,
                time=True,
                potentialEnergy=True,
                kineticEnergy=True,
                totalEnergy=True,
                temperature=True,
                volume=True,
                density=True,
                speed=True,
                append=resuming and state_path.exists(),
            )
        )

        # Save the stage's initial structure (keep the original one on resume)
        if not (resuming and pdb_path.exists()):
            with open(pdb_path, "w") as f:
                PDBFile.writeFile(
                    self._topology,
                    self._current_positions,
                    f,
                )

        # Periodic checkpoint reporter so state survives hard kills
        self._simulation.reporters.append(
            CheckpointReporter(str(checkpoint_path), report_interval)
        )

        # Install signal handlers for graceful shutdown
        install_handlers()

        # Track steps completed (starting from resume point)
        steps_done = resume_from_step

        if stage.is_temperature_ramping:
            LOGGER.info(
                f"Temperature ramping: {stage.temperature_start} K -> {stage.temperature_end} K "
                f"(increment={stage.temperature_increment} K every {stage.temperature_interval} fs)"
            )
            # At least one step per update, so a ramp chunk can never be empty.
            steps_per_update = max(1, int(stage.temperature_interval / timestep_fs))
            # Ramp chunks run at T_start, T_start + inc, ... strictly below
            # T_end. Counting them once (with a tiny tolerance for float
            # round-off, e.g. 300 / 0.1) keeps the loop and the step
            # bookkeeping in agreement. A non-positive range gives no ramp chunks.
            temp_range = stage.temperature_end - stage.temperature_start
            num_updates = max(0, math.ceil(temp_range / stage.temperature_increment - 1e-9))
            steps_for_ramping = num_updates * steps_per_update
            remaining_steps_at_final = total_steps - steps_for_ramping
            if remaining_steps_at_final < 0:
                LOGGER.warning(
                    f"Temperature ramp needs {steps_for_ramping} steps but the stage "
                    f"has only {total_steps}; the ramp will run past the stage duration"
                )

            if resuming and resume_temperature is not None:
                LOGGER.info(
                    f"Resuming temperature ramp from step {resume_from_step}, "
                    f"saved temp {resume_temperature} K "
                    f"(fast-forwarding from {stage.temperature_start} K)"
                )

            # Each temperature update is a chunk we can interrupt between.
            # Chunks already completed in a previous run are skipped.
            for update in range(num_updates):
                current_temp = stage.temperature_start + update * stage.temperature_increment
                chunk_start = update * steps_per_update
                chunk_end = chunk_start + steps_per_update
                if chunk_end <= resume_from_step:
                    continue

                integrator.setTemperature(current_temp * omm_unit.kelvin)
                if stage.ensemble == Ensemble.NPT:
                    self._simulation.context.setParameter(
                        openmm.MonteCarloBarostat.Temperature(),
                        current_temp * omm_unit.kelvin,
                    )
                # If resuming mid-chunk, only run the remainder
                steps_this_chunk = chunk_end - max(chunk_start, resume_from_step)
                self._simulation.step(steps_this_chunk)
                steps_done += steps_this_chunk
                _after_chunk(steps_done, current_temp, f" (ramping, T={current_temp:.1f} K)")

            # Final temperature - run remaining steps in chunks
            integrator.setTemperature(stage.temperature_end * omm_unit.kelvin)
            if stage.ensemble == Ensemble.NPT:
                self._simulation.context.setParameter(
                    openmm.MonteCarloBarostat.Temperature(),
                    stage.temperature_end * omm_unit.kelvin,
                )
            current_temp = stage.temperature_end
            steps_at_final_done = max(0, steps_done - steps_for_ramping)
            steps_at_final_remaining = max(0, remaining_steps_at_final - steps_at_final_done)
            if steps_at_final_remaining > 0:
                LOGGER.info(
                    f"Running {steps_at_final_remaining} steps at final "
                    f"temperature {stage.temperature_end} K"
                )
            while steps_at_final_remaining > 0:
                this_chunk = _next_chunk(steps_done, steps_at_final_remaining)
                self._simulation.step(this_chunk)
                steps_done += this_chunk
                steps_at_final_remaining -= this_chunk
                _after_chunk(steps_done, current_temp, " (final temp)")
        else:
            # Constant temperature - run in chunks with signal checking
            current_temp = stage.temperature
            steps_remaining = total_steps - resume_from_step
            if resuming:
                LOGGER.info(
                    f"Resuming constant-temp stage from step {resume_from_step}, "
                    f"{steps_remaining} steps remaining"
                )
            else:
                LOGGER.info(f"Running {total_steps} steps at {stage.temperature} K")
            while steps_remaining > 0:
                this_chunk = _next_chunk(steps_done, steps_remaining)
                self._simulation.step(this_chunk)
                steps_done += this_chunk
                steps_remaining -= this_chunk
                _after_chunk(steps_done, current_temp, "")

        # Get final state (including box vectors for NPT stages)
        state = self._simulation.context.getState(
            getPositions=True, getVelocities=True, getEnergy=True
        )
        self._current_positions = state.getPositions()
        self._current_velocities = state.getVelocities()
        self._current_box_vectors = state.getPeriodicBoxVectors()

        # Log final energy for diagnostics
        final_energy = state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
        LOGGER.info(f"Stage {stage_index} ({stage_name}): final PE = {final_energy:.2f} kJ/mol")

        # Final checkpoint first, then drop the marker: the stage only counts
        # as complete once both have happened.
        self._simulation.saveCheckpoint(str(checkpoint_path))
        if marker_path.exists():
            marker_path.unlink()
            LOGGER.info("Removed EQ_INTERRUPTED marker — stage completed successfully")
    finally:
        # Never leave this stage's restraints in the System (success, interrupt or error)
        if restraint_force_indices:
            LOGGER.info(
                f"Removing {len(restraint_force_indices)} position restraint force(s) for next stage"
            )
            remove_position_restraints_from_system(self._system, restraint_force_indices)

    # Build results
    results = {
        "stage_index": stage_index,
        "stage_name": stage.name,
        "ensemble": stage.ensemble.value,
        "duration_ns": stage.duration,
        "total_steps": total_steps,
        "temperature_start_K": stage.get_start_temperature(),
        "temperature_end_K": stage.get_final_temperature(),
        "is_temperature_ramping": stage.is_temperature_ramping,
        "position_restraints": [
            {"group": r.group, "force_constant": r.force_constant}
            for r in stage.position_restraints
        ],
        "final_energy_kJ_mol": final_energy,
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    LOGGER.info(f"Equilibration stage '{stage.name}' complete")
    return results
def _find_completed_eq_stages(
    self,
    stages: List["EquilibrationStageConfig"],
) -> List[int]:
    """Return indices of the leading run of equilibration stages finished on disk.

    A stage counts as finished when its directory
    ``equilibration_<i>_<name>/`` holds ``equilibration_<i>_<name>_checkpoint.chk``
    and has no ``EQ_INTERRUPTED`` marker. The scan stops at the first stage
    that is not finished, because stages cannot be skipped.

    Parameters
    ----------
    stages : list of EquilibrationStageConfig
        The full list of stages from the config.

    Returns
    -------
    list of int
        Indices of completed stages (contiguous from 0).
    """
    done: List[int] = []
    for index, stage in enumerate(stages):
        dir_name = f"equilibration_{index}_{stage.name}"
        stage_dir = self._working_dir / dir_name
        has_checkpoint = (stage_dir / f"{dir_name}_checkpoint.chk").exists()
        was_interrupted = (stage_dir / "EQ_INTERRUPTED").exists()
        if not has_checkpoint or was_interrupted:
            return done  # first gap ends the contiguous run
        done.append(index)
    return done
def _find_interrupted_eq_stage(
self,
stages: List["EquilibrationStageConfig"],
completed_indices: List[int],
) -> Dict[str, Any] | None:
"""Check if the next unfinished equilibration stage was interrupted mid-run.
Looks for an ``EQ_INTERRUPTED`` marker in the stage directory that
follows the last completed stage. If found, parses resume metadata.
Parameters
----------
stages : list of EquilibrationStageConfig
The full list of stages from the config.
completed_indices : list of int
Indices of fully completed stages (from ``_find_completed_eq_stages``).
Returns
-------
dict or None
Dictionary with ``stage_index``, ``steps_completed``, ``total_steps``,
and ``current_temperature`` if an interrupted stage is found;
``None`` otherwise.
"""
next_idx = len(completed_indices)
if next_idx >= len(stages):
return None
stage = stages[next_idx]
stage_name = f"equilibration_{next_idx}_{stage.name}"
stage_dir = self._working_dir / stage_name
marker_path = stage_dir / "EQ_INTERRUPTED"
if not marker_path.exists():
return None
# Parse the marker
info: Dict[str, Any] = {"stage_index": next_idx}
try:
text = marker_path.read_text()
for line in text.strip().splitlines():
key, _, value = line.partition("=")
key = key.strip()
value = value.strip()
if key == "steps_completed":
info["steps_completed"] = int(value)
elif key == "total_steps":
info["total_steps"] = int(value)
elif key == "current_temperature":
info["current_temperature"] = float(value)
elif key == "is_temperature_ramping":
info["is_temperature_ramping"] = value.lower() == "true"
except (ValueError, OSError) as exc:
LOGGER.warning(f"Could not parse EQ_INTERRUPTED marker {marker_path}: {exc}")
return None
# Verify checkpoint exists (needed for resume)
chk = stage_dir / f"{stage_name}_checkpoint.chk"
if not chk.exists():
LOGGER.warning(
f"EQ_INTERRUPTED marker found for stage {next_idx} but no checkpoint — "
f"will restart stage from beginning"
)
return None
LOGGER.info(
f"Found interrupted equilibration stage {next_idx} ({stage.name}): "
f"{info.get('steps_completed', 0)}/{info.get('total_steps', '?')} steps, "
f"T={info.get('current_temperature', '?')} K"
)
return info
def _load_eq_stage_state(
    self,
    stage_index: int,
    stage_name: str,
) -> None:
    """Restore positions, velocities and box vectors from an equilibration checkpoint.

    The binary checkpoint is read into a throwaway ``Simulation``. Its
    state is then copied onto the runner, so the next stage (or production)
    starts from it. The throwaway simulation is discarded afterwards.

    Parameters
    ----------
    stage_index : int
        Index of the stage whose checkpoint is loaded.
    stage_name : str
        Name of that stage (used in directory/file naming).

    Raises
    ------
    FileNotFoundError
        If ``equilibration_<index>_<name>/equilibration_<index>_<name>_checkpoint.chk``
        does not exist.
    """
    stage_dir_name = f"equilibration_{stage_index}_{stage_name}"
    checkpoint = self._working_dir / stage_dir_name / f"{stage_dir_name}_checkpoint.chk"
    if not checkpoint.exists():
        raise FileNotFoundError(f"Equilibration checkpoint not found: {checkpoint}")

    # Only geometric state (positions, velocities, box) is extracted, so a
    # plain Verlet integrator is enough; the next stage builds its own
    # Langevin integrator.
    scratch_integrator = openmm.VerletIntegrator(1.0 * omm_unit.femtosecond)
    scratch = Simulation(
        self._topology, self._system, scratch_integrator, self._get_platform()
    )
    scratch.loadCheckpoint(str(checkpoint))
    snapshot = scratch.context.getState(getPositions=True, getVelocities=True)

    self._current_positions = snapshot.getPositions()
    self._current_velocities = snapshot.getVelocities()
    self._current_box_vectors = snapshot.getPeriodicBoxVectors()

    # The scratch simulation (with its dummy integrator) is no longer needed
    del scratch

    LOGGER.info(
        f"Loaded state from equilibration stage {stage_index} ({stage_name}) checkpoint"
    )
[docs]
def run_staged_equilibration(
    self,
    stages: List["EquilibrationStageConfig"],
    atom_group_resolver: "AtomGroupResolver",
    target_temperature: float,
) -> Dict[str, Any]:
    """Execute a multi-stage equilibration protocol, resuming where possible.

    Each stage may use its own temperature schedule (constant or ramp),
    position-restraint selection and ensemble (NVT/NPT). Coordinates carry
    over from one stage to the next.

    Resume behaviour:
      * stages reported by ``_find_completed_eq_stages`` are skipped, and the
        state of the last one is loaded before continuing;
      * a stage reported by ``_find_interrupted_eq_stage`` is resumed from
        its own checkpoint at the recorded step/temperature;
      * any other stage whose output directory already exists is treated as
        a failed partial run: the directory is deleted and the stage re-run.

    Args:
        stages: Ordered list of EquilibrationStageConfig objects.
        atom_group_resolver: Resolver for predefined atom group names used by
            stage restraints.
        target_temperature: Final target temperature. Not read by this method.

    Returns:
        Summary dictionary with ``type``, ``num_stages``, per-stage results in
        ``stages`` (skipped stages get a minimal record with ``skipped=True``),
        ``total_duration_ns`` and, when at least one stage actually ran,
        ``final_energy_kJ_mol`` / ``final_temperature_K`` from the last one.
    """
    import shutil

    LOGGER.info(f"Starting multi-stage equilibration with {len(stages)} stages")

    # Restraints in every stage are anchored to the post-minimization
    # coordinates, so snapshot them before _load_eq_stage_state() replaces
    # self._current_positions.
    restraint_reference = self._current_positions

    done = self._find_completed_eq_stages(stages)
    if done:
        resume_idx = done[-1]
        resume_stage = stages[resume_idx]
        LOGGER.info(
            f"Found {len(done)} completed equilibration stage(s) "
            f"on disk — resuming after stage {resume_idx} ({resume_stage.name})"
        )
        self._load_eq_stage_state(resume_idx, resume_stage.name)

    # A stage that was stopped mid-run (EQ_INTERRUPTED marker) can resume
    # from its own checkpoint instead of starting over.
    interrupted = self._find_interrupted_eq_stage(stages, done)

    summary: Dict[str, Any] = {
        "type": "staged_equilibration",
        "num_stages": len(stages),
        "stages": [],
        "total_duration_ns": 0.0,
    }
    stage_records = summary["stages"]

    for idx, stage in enumerate(stages):
        if idx in done:
            LOGGER.info(f"Skipping completed equilibration stage {idx}: {stage.name}")
            record = {
                "stage_index": idx,
                "stage_name": stage.name,
                "skipped": True,
                "duration_ns": stage.duration,
            }
        elif interrupted is not None and interrupted["stage_index"] == idx:
            resume_step = interrupted.get("steps_completed", 0)
            resume_temp = interrupted.get("current_temperature")
            LOGGER.info(
                f"Resuming interrupted equilibration stage {idx} ({stage.name}) "
                f"from step {resume_step}"
            )
            self._load_eq_stage_state(idx, stage.name)
            record = self.run_equilibration_stage(
                stage=stage,
                reference_positions=restraint_reference,
                atom_group_resolver=atom_group_resolver,
                stage_index=idx,
                resume_from_step=resume_step,
                resume_temperature=resume_temp,
            )
        else:
            # A leftover directory with neither a checkpoint nor an
            # interruption marker is a failed partial run: start clean.
            stale_dir = self._working_dir / f"equilibration_{idx}_{stage.name}"
            if stale_dir.exists():
                LOGGER.warning(
                    f"Stage {idx} ({stage.name}) directory exists without checkpoint "
                    f"— removing partial output and re-running"
                )
                shutil.rmtree(stale_dir)
            record = self.run_equilibration_stage(
                stage=stage,
                reference_positions=restraint_reference,
                atom_group_resolver=atom_group_resolver,
                stage_index=idx,
            )

        stage_records.append(record)
        summary["total_duration_ns"] += stage.duration

    # Final energy/temperature come from the most recent stage that really ran.
    last_ran = next((rec for rec in reversed(stage_records) if not rec.get("skipped")), None)
    if last_ran:
        summary["final_energy_kJ_mol"] = last_ran["final_energy_kJ_mol"]
        summary["final_temperature_K"] = last_ran["temperature_end_K"]

    self._history["equilibration"] = summary

    skipped = len(done)
    ran = len(stages) - skipped
    LOGGER.info(
        f"Multi-stage equilibration complete: {len(stages)} stages "
        f"({skipped} skipped, {ran} ran), "
        f"{summary['total_duration_ns']:.3f} ns total"
    )
    return summary
[docs]
def run_production(
    self,
    temperature: float,
    duration_ns: float,
    num_samples: int = 2500,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    pressure: float = 1.0,
    barostat_frequency: int = 25,
    output_prefix: str = "production",
    segment_index: int = 0,
    report_interval: int | None = None,
    checkpoint_interval_s: float = 60.0,
) -> Dict[str, Any]:
    """Run one NPT production segment with checkpointing and graceful interruption.

    Outputs go to ``<working_dir>/<output_prefix>_<segment_index>/``:
    DCD trajectory, state-data CSV, topology PDB, serialized System XML,
    parameters JSON, a binary checkpoint (written periodically and at the end)
    and, on successful completion, the final State XML.

    Args:
        temperature: Temperature in Kelvin.
        duration_ns: Duration in nanoseconds.
        num_samples: Number of trajectory frames to save. Ignored when
            ``report_interval`` is provided.
        timestep_fs: Time step in femtoseconds.
        friction: Friction coefficient in 1/ps.
        pressure: Pressure in atmospheres.
        barostat_frequency: Barostat update frequency.
        output_prefix: Prefix for output files.
        segment_index: Segment index for multi-segment production.
        report_interval: Fixed reporter interval in steps. When provided,
            this overrides the per-segment ``total_steps // num_samples``
            calculation to keep frame spacing uniform across segments.
        checkpoint_interval_s: Wall-time interval in seconds between
            portable restart checkpoints. Also controls how frequently
            the loop checks for SLURM preemption signals. Set to 0 to
            disable wall-time checkpoints (reverts to legacy behaviour).

    Returns:
        Dictionary with phase results (ensemble, conditions, step count,
        final potential energy, trajectory and checkpoint paths).

    Raises:
        GracefulExit: When an interrupt signal is detected during the run;
            interrupted state and progress are saved before raising.
        Exception: Any unexpected error from the run loop is re-raised after
            a best-effort save of interrupted state and progress.
    """
    LOGGER.info(
        f"Starting production: {duration_ns} ns at {temperature} K, {pressure} atm (NPT)"
    )

    # NPT: replace any existing barostat with one for this segment's conditions.
    self._remove_barostat()
    self._add_barostat(
        pressure=pressure,
        temperature=temperature,
        frequency=barostat_frequency,
    )

    phase_name = f"{output_prefix}_{segment_index}"
    phase_dir = self._working_dir / phase_name
    phase_dir.mkdir(parents=True, exist_ok=True)

    total_steps = int(duration_ns * 1e6 / timestep_fs)
    # Prefer the fixed global interval (uniform frame spacing across segments);
    # otherwise fall back to the per-segment calculation (legacy callers).
    if report_interval is None:
        report_interval = max(1, total_steps // num_samples)

    integrator = self._create_integrator(
        temperature=temperature,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()
    self._simulation = Simulation(self._topology, self._system, integrator, platform)
    context = self._simulation.context

    # Box vectors must be set BEFORE positions for correct periodic handling.
    if self._current_box_vectors is not None:
        context.setPeriodicBoxVectors(*self._current_box_vectors)
    context.setPositions(self._current_positions)

    initial_pe = (
        context.getState(getEnergy=True)
        .getPotentialEnergy()
        .value_in_unit(omm_unit.kilojoule_per_mole)
    )
    LOGGER.info(f"Production segment {segment_index}: initial PE = {initial_pe:.2f} kJ/mol")

    # Write the System XML up front so checkpoint-based recovery
    # (loadCheckpoint needs a matching System) works even after a hard kill.
    # It is rewritten at segment completion.
    system_xml_path = phase_dir / f"{phase_name}_system.xml"
    with open(system_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(self._system))
    LOGGER.info(f"Saved initial system to {system_xml_path}")

    # Segment 0 takes velocities from equilibration when available (physical
    # continuity), else draws them at the target temperature. Continuation
    # segments are expected to have velocities restored by ContinuationManager
    # via loadState().
    if segment_index == 0:
        if self._current_velocities is not None:
            context.setVelocities(self._current_velocities)
            LOGGER.info("Using velocities preserved from equilibration")
        else:
            context.setVelocitiesToTemperature(temperature * omm_unit.kelvin)
            LOGGER.info("Initialized velocities from Maxwell-Boltzmann distribution")

    traj_path = phase_dir / f"{phase_name}_trajectory.dcd"
    state_path = phase_dir / f"{phase_name}_state_data.csv"
    pdb_path = phase_dir / f"{phase_name}_topology.pdb"
    checkpoint_path = phase_dir / f"{phase_name}_checkpoint.chk"

    reporters = self._simulation.reporters
    reporters.append(DCDReporter(str(traj_path), report_interval))
    reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )
    # Periodic binary checkpoint so a hard kill (SIGKILL / OOM / node failure)
    # does not lose the whole segment; the same file receives the final
    # checkpoint on completion.
    reporters.append(CheckpointReporter(str(checkpoint_path), report_interval))

    with open(pdb_path, "w") as f:
        PDBFile.writeFile(self._topology, self._current_positions, f)

    # Parameters JSON is written before the run loop because continuation
    # across segments requires it even if this segment is interrupted.
    params_path = phase_dir / f"{phase_name}_parameters.json"
    with open(params_path, "w") as f:
        json.dump(
            self._production_parameters_dict(
                temperature=temperature,
                friction=friction,
                pressure=pressure,
                barostat_frequency=barostat_frequency,
                timestep_fs=timestep_fs,
                duration_ns=duration_ns,
                num_samples=num_samples,
                report_interval=report_interval,
            ),
            f,
            indent=2,
        )
    LOGGER.info(f"Saved parameters to {params_path}")

    # Signal handlers for graceful shutdown (SIGUSR1 / SIGTERM).
    from polyzymd.simulation.signals import (
        GracefulExit,
        get_interrupt_signal,
        install_handlers,
        is_interrupted,
        save_interrupted_state,
        save_restart_checkpoint,
    )

    install_handlers()

    # RUNNING record lets check-progress tell live runs from interrupted ones.
    self._write_segment_started(segment_index, total_steps)

    LOGGER.info(f"Running {total_steps} steps...")
    steps_done = 0

    import time

    from polyzymd.simulation.progress import PROGRESS_UPDATE_INTERVAL_SECONDS

    loop_start = time.monotonic()
    last_progress_write = loop_start
    last_checkpoint_write = loop_start

    # Step in sub-chunks so interrupts are noticed promptly. Start with
    # report_interval; after the first checkpoint interval, recalibrate from
    # measured throughput to ~checkpoint_interval_s / 4 seconds per sub-chunk.
    sub_chunk = min(report_interval, total_steps)
    adapted = False

    try:
        while steps_done < total_steps:
            this_chunk = min(sub_chunk, total_steps - steps_done)
            self._simulation.step(this_chunk)
            steps_done += this_chunk
            now = time.monotonic()

            if (
                not adapted
                and checkpoint_interval_s > 0
                and (now - loop_start) >= checkpoint_interval_s
            ):
                elapsed = now - loop_start
                steps_per_sec = steps_done / elapsed if elapsed > 0 else 1.0
                new_sub_chunk = self._aligned_sub_chunk(
                    report_interval, steps_per_sec, checkpoint_interval_s
                )
                if new_sub_chunk != sub_chunk:
                    LOGGER.info(
                        f"Adaptive sub-chunk: {sub_chunk} -> {new_sub_chunk} steps "
                        f"(~{new_sub_chunk / steps_per_sec:.1f}s at "
                        f"{steps_per_sec:.0f} steps/s)"
                    )
                sub_chunk = new_sub_chunk
                adapted = True

            # Keep the RUNNING record's step count fresh for check-progress.
            if now - last_progress_write >= PROGRESS_UPDATE_INTERVAL_SECONDS:
                self._update_progress_running(
                    segment_index=segment_index,
                    steps_done=steps_done,
                    timestep_fs=timestep_fs,
                )
                last_progress_write = now

            # Wall-time restart checkpoint (skipped when about to finish).
            if (
                checkpoint_interval_s > 0
                and (now - last_checkpoint_write) >= checkpoint_interval_s
                and steps_done < total_steps
            ):
                save_restart_checkpoint(
                    simulation=self._simulation,
                    output_dir=phase_dir,
                )
                last_checkpoint_write = now

            if is_interrupted():
                LOGGER.warning(f"Interrupt detected at step {steps_done}/{total_steps}")
                save_interrupted_state(
                    simulation=self._simulation,
                    output_dir=phase_dir,
                    segment_index=segment_index,
                    steps_completed=steps_done,
                    total_steps=total_steps,
                )
                self._update_progress_interrupted(
                    segment_index=segment_index,
                    steps_done=steps_done,
                    total_steps=total_steps,
                    duration_ns=duration_ns,
                    timestep_fs=timestep_fs,
                )
                raise GracefulExit(
                    signal_number=get_interrupt_signal(), steps_completed=steps_done
                )
    except GracefulExit:
        raise  # Re-raise so caller can handle exit code
    except Exception:
        # Unexpected crash: best-effort save of resumable state AND progress,
        # otherwise progress.json would keep claiming this segment is RUNNING.
        try:
            save_interrupted_state(
                simulation=self._simulation,
                output_dir=phase_dir,
                segment_index=segment_index,
                steps_completed=steps_done,
                total_steps=total_steps,
            )
        except Exception:
            LOGGER.error("Failed to save interrupted state after crash")
        try:
            self._update_progress_interrupted(
                segment_index=segment_index,
                steps_done=steps_done,
                total_steps=total_steps,
                duration_ns=duration_ns,
                timestep_fs=timestep_fs,
            )
        except Exception:
            LOGGER.error("Failed to update progress file after crash")
        raise

    # Final state (no enforcePeriodicBox to preserve molecular continuity).
    state = self._simulation.context.getState(
        getPositions=True,
        getVelocities=True,
        getEnergy=True,
        getForces=True,
        getParameters=True,
    )
    # Refresh all cached state together: the barostat changes the box during
    # NPT, so keeping pre-production box vectors alongside post-production
    # positions would hand later phases an inconsistent configuration.
    self._current_positions = state.getPositions()
    self._current_velocities = state.getVelocities()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    self._simulation.saveCheckpoint(str(checkpoint_path))

    # State and System XML are needed for continuation across segments.
    state_xml_path = phase_dir / f"{phase_name}_state.xml"
    with open(state_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(state))
    LOGGER.info(f"Saved state to {state_xml_path}")

    with open(system_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(self._system))
    LOGGER.info(f"Saved system to {system_xml_path}")

    self._update_progress_completed(
        segment_index=segment_index,
        total_steps=total_steps,
        num_samples=num_samples,
        duration_ns=duration_ns,
        timestep_fs=timestep_fs,
    )

    results = {
        "phase": "production",
        "segment": segment_index,
        "ensemble": "NPT",
        "temperature_K": temperature,
        "pressure_atm": pressure,
        "duration_ns": duration_ns,
        "total_steps": total_steps,
        "final_energy_kJ_mol": state.getPotentialEnergy().value_in_unit(
            omm_unit.kilojoule_per_mole
        ),
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    self._history[phase_name] = results
    LOGGER.info(f"Production segment {segment_index} complete")
    return results

@staticmethod
def _aligned_sub_chunk(
    report_interval: int,
    steps_per_sec: float,
    checkpoint_interval_s: float,
) -> int:
    """Pick a run-loop sub-chunk lasting about ``checkpoint_interval_s / 4`` seconds.

    The time-based target (at least 10 steps) is capped at ``report_interval``.
    Below the cap, the largest divisor of ``report_interval`` not exceeding
    the target is used, so sub-chunk boundaries line up with reporter steps.
    A floor of 10 steps is kept; when that floor applies, the result may not
    divide ``report_interval``.
    """
    target = max(10, int(steps_per_sec * (checkpoint_interval_s / 4.0)))
    if target >= report_interval:
        return report_interval
    # Scan downward; d == 1 always divides, so next() cannot be exhausted.
    divisor = next(d for d in range(target, 0, -1) if report_interval % d == 0)
    return max(10, divisor)

@staticmethod
def _production_parameters_dict(
    temperature: float,
    friction: float,
    pressure: float,
    barostat_frequency: int,
    timestep_fs: float,
    duration_ns: float,
    num_samples: int,
    report_interval: int,
) -> Dict[str, Any]:
    """Build the serialized SimulationParameters structure for a production segment.

    The nested ``__class__`` / ``__values__`` layout is what the
    continuation code reads back from ``<phase>_parameters.json``.
    """

    def quantity(value: Any, unit: str) -> Dict[str, Any]:
        # Serialized form of a unit-bearing value.
        return {"__class__": "Quantity", "__values__": {"value": value, "unit": unit}}

    return {
        "__class__": "SimulationParameters",
        "__values__": {
            "thermo_params": {
                "__class__": "ThermoParameters",
                "__values__": {
                    "temperature": quantity(temperature, "kelvin"),
                    "thermostat_params": {
                        "__class__": "ThermostatParameters",
                        "__values__": {
                            "temperature": quantity(temperature, "kelvin"),
                            "timescale": quantity(friction, "/picosecond"),
                        },
                    },
                    "barostat_params": {
                        "__class__": "BarostatParameters",
                        "__values__": {
                            "pressure": quantity(pressure, "atmosphere"),
                            "temperature": quantity(temperature, "kelvin"),
                            "frequency": barostat_frequency,
                        },
                    },
                },
            },
            "integ_params": {
                "__class__": "IntegratorParameters",
                "__values__": {
                    "time_step": quantity(timestep_fs, "femtosecond"),
                    "total_time": quantity(duration_ns, "nanosecond"),
                    "num_samples": num_samples,
                },
            },
            "reporter_params": {
                "__class__": "ReporterParameters",
                "__values__": {
                    "report_interval": report_interval,
                    "report_trajectory": True,
                    "report_state_data": True,
                },
            },
        },
    }
def _write_segment_started(
    self,
    segment_index: int,
    total_steps: int,
) -> None:
    """Write a RUNNING segment record to progress.json at segment start.

    This marks the segment as actively executing so that
    ``check-progress`` can distinguish a running simulation from
    one that was interrupted. The record is later updated to
    COMPLETED or INTERRUPTED by the corresponding handler. The overall
    simulation status is also set to RUNNING. Does nothing (beyond a
    warning) when no progress file exists.

    Parameters
    ----------
    segment_index : int
        Production segment index (0 for initial production).
    total_steps : int
        Total steps planned for this segment.
    """
    # Only the names actually used here are imported.
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping segment-started write")
        return

    record = SegmentRecord(
        index=segment_index,
        steps_completed=0,
        steps_requested=total_steps,
        samples_written=0,
        status=SegmentStatus.RUNNING,
    )
    _update_or_append_segment(progress, record)
    progress.status = SimulationStatus.RUNNING
    save_progress(self._working_dir, progress)
    LOGGER.info(f"Marked segment {segment_index} as RUNNING in progress file")
def _update_progress_running(
    self,
    segment_index: int,
    steps_done: int,
    timestep_fs: float,
) -> None:
    """Periodically update the RUNNING segment's step count.

    Called during the simulation loop to keep ``progress.json``
    up to date so that ``check-progress`` can show real-time
    remaining nanoseconds. Silently does nothing when there is no
    progress file or no RUNNING record for ``segment_index``.

    Parameters
    ----------
    segment_index : int
        Production segment index.
    steps_done : int
        Steps completed so far in this segment.
    timestep_fs : float
        Integration timestep in femtoseconds.
    """
    from polyzymd.simulation.progress import (
        SegmentStatus,
        SimulationStatus,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        return

    running = next(
        (
            seg
            for seg in progress.segments
            if seg.index == segment_index and seg.status == SegmentStatus.RUNNING
        ),
        None,
    )
    if running is None:
        # _write_segment_started should have created it; be safe anyway.
        return

    running.steps_completed = steps_done
    running.duration_ns = (steps_done * timestep_fs) / 1e6
    progress.status = SimulationStatus.RUNNING
    save_progress(self._working_dir, progress)
    LOGGER.debug(f"Updated running progress: segment {segment_index}, {steps_done} steps")
def _update_progress_completed(
    self,
    segment_index: int,
    total_steps: int,
    num_samples: int,
    duration_ns: float,
    timestep_fs: float,
) -> None:
    """Record a production segment as COMPLETED in the progress file.

    The segment record gets a ``finished_at`` timestamp. The overall
    simulation status becomes COMPLETED when ``progress.is_complete`` is
    true and RUNNING otherwise. Logs a warning and returns when no progress
    file exists.

    Parameters
    ----------
    segment_index : int
        Production segment index (0 for initial production).
    total_steps : int
        Steps completed in this segment.
    num_samples : int
        Samples written in this segment.
    duration_ns : float
        Simulation time of this segment in nanoseconds.
    timestep_fs : float
        Integration timestep in femtoseconds (not read by this method).
    """
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _now_iso,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping progress update")
        return

    finished = SegmentRecord(
        index=segment_index,
        steps_completed=total_steps,
        steps_requested=total_steps,
        samples_written=num_samples,
        status=SegmentStatus.COMPLETED,
        duration_ns=duration_ns,
    )
    finished.finished_at = _now_iso()
    _update_or_append_segment(progress, finished)

    progress.status = (
        SimulationStatus.COMPLETED if progress.is_complete else SimulationStatus.RUNNING
    )
    save_progress(self._working_dir, progress)

    steps_so_far = progress.total_steps_completed
    steps_wanted = progress.total_steps_requested
    fraction = progress.fraction_complete()
    LOGGER.info(f"Progress updated: {steps_so_far}/{steps_wanted} steps ({fraction:.1%})")
def _update_progress_interrupted(
    self,
    segment_index: int,
    steps_done: int,
    total_steps: int,
    duration_ns: float,
    timestep_fs: float,
) -> None:
    """Record a production segment as INTERRUPTED in the progress file.

    The stored ``duration_ns`` is the simulated time actually reached
    (``steps_done * timestep_fs``), not the planned duration, and
    ``samples_written`` is recorded as 0. The overall simulation status
    becomes INTERRUPTED. Logs a warning and returns when no progress file
    exists.

    Parameters
    ----------
    segment_index : int
        Production segment index.
    steps_done : int
        Steps completed before interruption.
    total_steps : int
        Steps that were planned for this segment.
    duration_ns : float
        Planned simulation time of this segment in nanoseconds (not read
        by this method).
    timestep_fs : float
        Integration timestep in femtoseconds.
    """
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping progress update")
        return

    reached_ns = steps_done * timestep_fs / 1e6
    _update_or_append_segment(
        progress,
        SegmentRecord(
            index=segment_index,
            steps_completed=steps_done,
            steps_requested=total_steps,
            samples_written=0,  # trajectory may be partial; no sample count claimed
            status=SegmentStatus.INTERRUPTED,
            duration_ns=reached_ns,
        ),
    )
    progress.status = SimulationStatus.INTERRUPTED
    save_progress(self._working_dir, progress)

    LOGGER.info(
        f"Progress updated (interrupted): {steps_done}/{total_steps} steps "
        f"in segment {segment_index}"
    )
[docs]
def save_history(self, path: Optional[Union[str, Path]] = None) -> None:
    """Dump the accumulated per-phase results to a JSON file.

    Args:
        path: Destination file. When omitted, ``simulation_history.json``
            in the working directory is used.
    """
    target = self._working_dir / "simulation_history.json" if path is None else Path(path)
    with open(target, "w") as handle:
        json.dump(self._history, handle, indent=2)
    LOGGER.info(f"Saved simulation history to {target}")
[docs]
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
    """Restore the active simulation from a binary checkpoint.

    After loading, the runner's cached positions, velocities and periodic
    box vectors are refreshed from the restored context.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If no simulation has been created yet.
    """
    chk = Path(checkpoint_path)
    if not chk.exists():
        raise FileNotFoundError(f"Checkpoint not found: {chk}")

    sim = self._simulation
    if sim is None:
        raise RuntimeError(
            "No active simulation. Create a simulation first with "
            "run_equilibration or run_production."
        )

    sim.loadCheckpoint(str(chk))

    restored = sim.context.getState(getPositions=True, getVelocities=True)
    self._current_positions = restored.getPositions()
    self._current_velocities = restored.getVelocities()
    self._current_box_vectors = restored.getPeriodicBoxVectors()

    LOGGER.info(f"Loaded checkpoint from {chk}")