# Source code for polyzymd.simulation.runner

"""
Simulation runner for executing MD simulations with OpenMM.

This module handles running equilibration and production phases
with configurable parameters, reporters, and checkpoint management.
"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

# OpenMM and the project types below are imported only for static type
# checking. At runtime the OpenMM names are bound to ``None`` placeholders
# which ``_ensure_openmm_loaded()`` replaces on first use.
if TYPE_CHECKING:
    import openmm
    from openmm import XmlSerializer
    from openmm import unit as omm_unit
    from openmm.app import CheckpointReporter, DCDReporter, PDBFile, Simulation, StateDataReporter

    from polyzymd.config.schema import (
        EquilibrationStageConfig,
        SimulationConfig,
        SimulationPhasesConfig,
    )
    from polyzymd.core.atom_groups import AtomGroupResolver
    from polyzymd.core.parameters import SimulationParameters
else:
    # Runtime placeholders; ``openmm is None`` is the "not loaded yet" sentinel
    # checked by ``_ensure_openmm_loaded()``.
    openmm = None
    XmlSerializer = None
    omm_unit = None
    CheckpointReporter = None
    DCDReporter = None
    PDBFile = None
    Simulation = None
    StateDataReporter = None

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

# Phase types
PhaseType = Literal["equilibration", "production"]


def _ensure_openmm_loaded() -> None:
    """Import OpenMM on first use and publish its symbols as module globals.

    Returns
    -------
    None
        Nothing is returned. When ``openmm`` is already bound the call is a
        no-op; otherwise the module-level placeholders are rebound to the
        real OpenMM objects.
    """
    global CheckpointReporter, DCDReporter, PDBFile, Simulation, StateDataReporter
    global XmlSerializer, omm_unit, openmm

    if openmm is None:
        import openmm as loaded_openmm
        from openmm import XmlSerializer as loaded_serializer
        from openmm import unit as loaded_unit
        from openmm.app import (
            CheckpointReporter as loaded_checkpoint_reporter,
            DCDReporter as loaded_dcd_reporter,
            PDBFile as loaded_pdb_file,
            Simulation as loaded_simulation,
            StateDataReporter as loaded_state_reporter,
        )

        # Rebind the globals only after every import above has succeeded, so a
        # failed import leaves ``openmm`` as None and the next call retries.
        openmm = loaded_openmm
        XmlSerializer, omm_unit = loaded_serializer, loaded_unit
        CheckpointReporter, DCDReporter = loaded_checkpoint_reporter, loaded_dcd_reporter
        PDBFile, Simulation = loaded_pdb_file, loaded_simulation
        StateDataReporter = loaded_state_reporter


class SimulationRunner:
    """Runner for executing OpenMM molecular dynamics simulations.

    This class manages:
    - Equilibration and production phase execution
    - Checkpoint saving and state data reporting
    - Energy minimization
    - Unique force group assignment for energy decomposition

    Example:
        >>> runner = SimulationRunner(
        ...     topology=omm_topology,
        ...     system=omm_system,
        ...     positions=omm_positions,
        ...     working_dir="output/",
        ... )
        >>> runner.minimize()
        >>> runner.run_equilibration(temperature=300, config=sim_config.simulation_phases)
        >>> runner.run_production(
        ...     temperature=300,
        ...     duration_ns=100,
        ...     report_interval=5000,
        ...     checkpoint_interval_s=900,
        ... )

    ``report_interval`` and ``checkpoint_interval_s`` are required
    keyword-only arguments of ``run_production``; the values above are
    illustrative only.
    """
def __init__(
    self,
    topology: Any,
    system: openmm.System,
    positions: Any,
    working_dir: Union[str, Path],
    platform: str = "CUDA",
) -> None:
    """Store the simulation inputs and prepare the output directory.

    Args:
        topology: OpenMM Topology.
        system: OpenMM System.
        positions: Initial positions with units.
        working_dir: Working directory for output files.
        platform: Compute platform (CUDA, OpenCL, CPU).
    """
    _ensure_openmm_loaded()

    # Inputs, kept exactly as provided by the caller.
    self._topology, self._system = topology, system
    self._positions = positions
    self._working_dir = Path(working_dir)
    self._platform_name = platform

    # No Simulation exists until a stage or production run creates one.
    self._simulation: Optional[Simulation] = None

    # Evolving state handed from one phase to the next.
    self._current_positions = positions
    self._current_velocities = None  # Carried between equilibration stages
    self._current_box_vectors = None  # Updated during NPT stages
    self._history: Dict[str, Any] = {}

    # The working directory (and any missing parents) must exist before use.
    self._working_dir.mkdir(parents=True, exist_ok=True)

    # One force group per force, for energy decomposition.
    self._impose_unique_force_groups()
@property
def simulation(self) -> Optional[Simulation]:
    """Get the current OpenMM Simulation object."""
    return self._simulation


@property
def working_dir(self) -> Path:
    """Get the working directory path."""
    return self._working_dir


@property
def history(self) -> Dict[str, Any]:
    """Get the simulation history."""
    return self._history


def _impose_unique_force_groups(self) -> None:
    """Assign unique force groups to each force for energy decomposition."""
    # Imported here rather than at module level, as in the original code.
    from polyzymd.utils import impose_unique_force_groups

    impose_unique_force_groups(self._system)
    LOGGER.debug("Assigned unique force groups")


def _get_platform(self) -> openmm.Platform:
    """Get the compute platform, falling back to CPU when unavailable.

    Returns:
        OpenMM Platform object for ``self._platform_name``, or the CPU
        platform when OpenMM raises ``OpenMMException`` for that name.
    """
    _ensure_openmm_loaded()
    try:
        platform = openmm.Platform.getPlatformByName(self._platform_name)
    except openmm.OpenMMException as exc:
        LOGGER.warning(
            "Platform %s is not available (%s); falling back to CPU. "
            "Install/configure the requested OpenMM platform or set platform='CPU'.",
            self._platform_name,
            exc,
        )
        platform = openmm.Platform.getPlatformByName("CPU")
    else:
        LOGGER.info(f"Using {self._platform_name} platform")
    return platform


def _create_integrator(
    self,
    temperature: float,
    friction: float = 1.0,
    timestep: float = 2.0,
    thermostat: str = "LangevinMiddle",
) -> openmm.Integrator:
    """Create an integrator for the simulation.

    Args:
        temperature: Temperature in Kelvin.
        friction: Friction coefficient in 1/ps.
        timestep: Time step in femtoseconds.
        thermostat: Thermostat type ("LangevinMiddle" or "Langevin"). Any
            other value logs a warning and uses LangevinMiddle.

    Returns:
        OpenMM Integrator.
    """
    temp = temperature * omm_unit.kelvin
    fric = friction / omm_unit.picosecond
    dt = timestep * omm_unit.femtosecond

    if thermostat == "Langevin":
        return openmm.LangevinIntegrator(temp, fric, dt)
    if thermostat != "LangevinMiddle":
        LOGGER.warning(f"Unknown thermostat {thermostat}, using LangevinMiddle")
    return openmm.LangevinMiddleIntegrator(temp, fric, dt)


def _add_barostat(
    self,
    pressure: float = 1.0,
    temperature: float = 300.0,
    frequency: int = 25,
) -> None:
    """Add a Monte Carlo barostat to the system.

    Args:
        pressure: Pressure in atmospheres.
        temperature: Temperature in Kelvin.
        frequency: Update frequency in steps.
    """
    barostat = openmm.MonteCarloBarostat(
        pressure * omm_unit.atmosphere,
        temperature * omm_unit.kelvin,
        frequency,
    )
    self._system.addForce(barostat)
    LOGGER.info(f"Added MC barostat: {pressure} atm, {temperature} K")


def _remove_barostat(self) -> None:
    """Remove any barostat from the system.

    Every Monte Carlo barostat variant that the installed OpenMM exposes is
    removed, not only the isotropic one, so that a following
    ``_add_barostat`` never leaves two pressure couplings in the system.
    """
    # Variants missing from older OpenMM builds are simply skipped.
    barostat_types = tuple(
        cls
        for cls in (
            getattr(openmm, name, None)
            for name in (
                "MonteCarloBarostat",
                "MonteCarloAnisotropicBarostat",
                "MonteCarloMembraneBarostat",
                "MonteCarloFlexibleBarostat",
            )
        )
        if cls is not None
    )

    barostat_indices = [
        index
        for index in range(self._system.getNumForces())
        if isinstance(self._system.getForce(index), barostat_types)
    ]

    # Remove from the highest index down so earlier indices stay valid.
    for index in reversed(barostat_indices):
        self._system.removeForce(index)

    # Only report a removal that actually happened.
    if barostat_indices:
        LOGGER.debug("Removed barostat")
def minimize(
    self,
    max_iterations: int = 1000,
    tolerance: float = 10.0,
) -> float:
    """Minimize the potential energy of the current coordinates.

    The minimized positions and the periodic box vectors are stored on the
    runner so the next phase starts from them.

    Args:
        max_iterations: Maximum iterations (0 = until convergence).
        tolerance: Energy tolerance in kJ/mol/nm.

    Returns:
        Final potential energy in kJ/mol.
    """
    LOGGER.info("Running energy minimization")

    # A Simulation local to this call; it is not stored on self._simulation.
    minimizer = Simulation(
        self._topology,
        self._system,
        openmm.VerletIntegrator(1.0 * omm_unit.femtosecond),
        self._get_platform(),
    )
    minimizer.context.setPositions(self._current_positions)

    force_tolerance = tolerance * omm_unit.kilojoule_per_mole / omm_unit.nanometer
    minimizer.minimizeEnergy(tolerance=force_tolerance, maxIterations=max_iterations)

    # Capture the relaxed state, including box vectors for the handoff to
    # equilibration.
    relaxed = minimizer.context.getState(getEnergy=True, getPositions=True)
    self._current_positions = relaxed.getPositions()
    self._current_box_vectors = relaxed.getPeriodicBoxVectors()

    energy = relaxed.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Minimization complete: E = {energy:.2f} kJ/mol")
    return energy
def run_equilibration(
    self,
    temperature: float,
    config: "SimulationPhasesConfig",
) -> Dict[str, Any]:
    """Log the configured stages, then run the staged equilibration.

    An atom-group resolver is built from the topology and the work is
    delegated to ``run_staged_equilibration``.

    Args:
        temperature: Temperature in Kelvin
        config: SimulationPhasesConfig containing equilibration stages

    Returns:
        Dictionary with equilibration results
    """
    from polyzymd.core.atom_groups import AtomGroupResolver, SystemComponentInfo

    LOGGER.info(
        f"Starting staged equilibration with {len(config.equilibration_stages)} stage(s):"
    )

    # One summary line per stage: duration, ensemble, temperature, restraints.
    for number, stage_cfg in enumerate(config.equilibration_stages):
        restraint_labels = [
            f"{r.group}@{r.force_constant:.0f}" for r in stage_cfg.position_restraints
        ]
        restraint_info = ", ".join(restraint_labels) or "none"
        temp_info = (
            f"{stage_cfg.temperature_start}K -> {stage_cfg.temperature_end}K"
            if stage_cfg.is_temperature_ramping
            else f"{stage_cfg.temperature}K"
        )
        LOGGER.info(
            f" Stage {number}: {stage_cfg.name} - {stage_cfg.duration} ns, "
            f"{stage_cfg.ensemble.value}, {temp_info}, restraints: [{restraint_info}]"
        )

    resolver = AtomGroupResolver(
        self._topology,
        SystemComponentInfo.from_topology(self._topology),
    )
    return self.run_staged_equilibration(
        stages=config.equilibration_stages,
        atom_group_resolver=resolver,
        target_temperature=temperature,
    )
def run_equilibration_stage(
    self,
    stage: "EquilibrationStageConfig",
    reference_positions: Any,
    atom_group_resolver: "AtomGroupResolver",
    stage_index: int,
    default_timestep: float = 2.0,
    default_friction: float = 1.0,
    resume_from_step: int = 0,
    resume_temperature: float | None = None,
) -> Dict[str, Any]:
    """Run a single equilibration stage with optional position restraints.

    This method runs one stage of a multi-stage equilibration protocol.
    It supports:
    - Position restraints on predefined atom groups
    - Temperature ramping (simulated annealing)
    - NVT or NPT ensembles
    - Graceful interruption and mid-stage resume

    Args:
        stage: EquilibrationStageConfig with stage settings
        reference_positions: Positions to restrain atoms to (typically
            post-minimization)
        atom_group_resolver: Resolver for predefined atom group names
        stage_index: Index of this stage (for output naming)
        default_timestep: Default time step in fs if not specified in stage
        default_friction: Friction coefficient in 1/ps used when the stage
            does not define ``thermostat_timescale``
        resume_from_step: Step to resume from within this stage (0 = start
            fresh). When resuming, the checkpointed state must already have
            been loaded onto the runner.
        resume_temperature: Current temperature when resuming a
            temperature-ramping stage. Only used for logging when
            ``resume_from_step > 0`` and the stage uses temperature ramping.

    Returns:
        Dictionary with stage results

    Raises:
        GracefulExit: When an interrupt signal is observed between chunks.
            A checkpoint and an ``EQ_INTERRUPTED`` marker are written first.
    """
    from polyzymd.config.schema import Ensemble
    from polyzymd.core.position_restraints import (
        add_position_restraints_to_system,
        remove_position_restraints_from_system,
    )

    stage_name = f"equilibration_{stage_index}_{stage.name}"
    LOGGER.info(f"Starting equilibration stage: {stage.name} ({stage.duration} ns)")

    # Create output directory for this stage
    phase_dir = self._working_dir / stage_name
    phase_dir.mkdir(exist_ok=True)

    # Get stage parameters (with defaults)
    timestep_fs = stage.time_step if stage.time_step is not None else default_timestep
    if stage.thermostat_timescale:
        # friction (1/ps) = 1 / timescale (ps)
        friction = 1.0 / stage.thermostat_timescale
    else:
        # Honour the documented fallback instead of a hard-coded 1.0.
        friction = default_friction

    # Calculate steps and reporting interval
    total_steps = int(stage.duration * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // stage.samples)

    start_temp = stage.get_start_temperature()

    # Handle ensemble - always start from a barostat-free system, then add
    # one for NPT stages.
    self._remove_barostat()
    if stage.ensemble == Ensemble.NPT:
        pressure = 1.0  # Default pressure for NPT stages
        barostat_freq = stage.barostat_frequency if stage.barostat_frequency else 25
        self._add_barostat(
            pressure=pressure,
            temperature=start_temp,
            frequency=barostat_freq,
        )

    # Add position restraints
    restraint_force_indices = []
    for restraint_config in stage.position_restraints:
        atom_indices = atom_group_resolver.resolve(restraint_config.group)
        if atom_indices:
            force_idx = add_position_restraints_to_system(
                system=self._system,
                atom_indices=atom_indices,
                positions=reference_positions,
                force_constant=restraint_config.force_constant,
            )
            if force_idx >= 0:
                restraint_force_indices.append(force_idx)
                LOGGER.info(
                    f"Added position restraints to {len(atom_indices)} atoms "
                    f"in group '{restraint_config.group}' "
                    f"(k={restraint_config.force_constant:.1f} kJ/mol/nm^2)"
                )
        else:
            LOGGER.warning(
                f"No atoms found for group '{restraint_config.group}' - skipping restraint"
            )

    # Create integrator with starting temperature
    integrator = self._create_integrator(
        temperature=start_temp,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()

    # Create simulation
    self._simulation = Simulation(self._topology, self._system, integrator, platform)

    # Set box vectors BEFORE positions - critical for NPT stage transitions
    # where box dimensions may have changed from previous stage
    if self._current_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    self._simulation.context.setPositions(self._current_positions)

    # Velocity initialization: only generate fresh Maxwell-Boltzmann
    # velocities for a fresh start of the first stage. Subsequent stages
    # inherit velocities from the previous stage, and a resumed stage
    # (including stage 0) keeps the velocities restored from its checkpoint
    # rather than re-randomising them.
    is_resume = resume_from_step > 0
    if self._current_velocities is None or (stage_index == 0 and not is_resume):
        self._simulation.context.setVelocitiesToTemperature(start_temp * omm_unit.kelvin)
        LOGGER.info(f"Stage {stage_index}: initialized velocities at {start_temp} K")
    else:
        self._simulation.context.setVelocities(self._current_velocities)
        if is_resume:
            LOGGER.info(f"Stage {stage_index}: restored velocities from checkpoint")
        else:
            LOGGER.info(f"Stage {stage_index}: inherited velocities from previous stage")

    # Log initial energy
    _state = self._simulation.context.getState(getEnergy=True)
    _energy = _state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Stage {stage_index} ({stage_name}): initial PE = {_energy:.2f} kJ/mol")

    # Set up reporters
    traj_path = phase_dir / f"{stage_name}_trajectory.dcd"
    state_path = phase_dir / f"{stage_name}_state_data.csv"
    pdb_path = phase_dir / f"{stage_name}_topology.pdb"

    self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
    self._simulation.reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )

    # Save initial topology
    with open(pdb_path, "w") as f:
        PDBFile.writeFile(
            self._topology,
            self._current_positions,
            f,
        )

    # Periodic checkpoint reporter so state survives hard kills
    eq_chk_path = phase_dir / f"{stage_name}_checkpoint.chk"
    self._simulation.reporters.append(CheckpointReporter(str(eq_chk_path), report_interval))

    # Install signal handlers for graceful shutdown
    from polyzymd.simulation.signals import (
        GracefulExit,
        get_interrupt_signal,
        install_handlers,
        is_interrupted,
    )

    install_handlers()

    # Helper: save checkpoint + EQ_INTERRUPTED marker for mid-stage resume
    def _save_eq_interrupted(steps_done: int, current_temp: float) -> None:
        # The periodic reporter only fires every report_interval steps, while
        # chunks (especially ramp chunks) can end elsewhere. Write a
        # checkpoint now so the state on disk matches the step count
        # recorded in the marker.
        self._simulation.saveCheckpoint(str(eq_chk_path))
        marker_path = phase_dir / "EQ_INTERRUPTED"
        marker_path.write_text(
            f"stage_index={stage_index}\n"
            f"stage_name={stage.name}\n"
            f"steps_completed={steps_done}\n"
            f"total_steps={total_steps}\n"
            f"current_temperature={current_temp}\n"
            f"is_temperature_ramping={stage.is_temperature_ramping}\n"
        )
        LOGGER.info(
            f"Saved EQ_INTERRUPTED marker: stage {stage_index}, "
            f"step {steps_done}/{total_steps}, temp {current_temp} K"
        )

    # Track steps completed (starting from resume point)
    steps_done = resume_from_step

    # Run simulation with temperature ramping if needed
    if stage.is_temperature_ramping:
        LOGGER.info(
            f"Temperature ramping: {stage.temperature_start} K -> {stage.temperature_end} K "
            f"(increment={stage.temperature_increment} K every {stage.temperature_interval} fs)"
        )

        steps_per_update = int(stage.temperature_interval / timestep_fs)

        # Calculate total temperature updates needed
        temp_range = stage.temperature_end - stage.temperature_start
        num_updates = int(temp_range / stage.temperature_increment)
        steps_for_ramping = num_updates * steps_per_update
        remaining_steps_at_final = total_steps - steps_for_ramping

        # Determine starting temperature — always begin from
        # temperature_start and let the fast-forward loop advance
        # current_temp by skipping already-completed chunks. On
        # resume, resume_temperature is used only for logging.
        current_temp = stage.temperature_start
        if resume_from_step > 0 and resume_temperature is not None:
            LOGGER.info(
                f"Resuming temperature ramp from step {resume_from_step}, "
                f"saved temp {resume_temperature} K (fast-forwarding from {current_temp} K)"
            )

        # Temperature ramping phase — each update is a chunk we can
        # interrupt between. Skip chunks already completed on resume.
        ramp_step_count = 0
        while current_temp < stage.temperature_end:
            chunk_end = ramp_step_count + steps_per_update

            if chunk_end <= resume_from_step:
                # Already completed this chunk in a previous run
                ramp_step_count = chunk_end
                current_temp += stage.temperature_increment
                continue

            integrator.setTemperature(current_temp * omm_unit.kelvin)
            if stage.ensemble == Ensemble.NPT:
                self._simulation.context.setParameter(
                    openmm.MonteCarloBarostat.Temperature(),
                    current_temp * omm_unit.kelvin,
                )

            # If resuming mid-chunk, only run the remainder
            steps_already = max(0, resume_from_step - ramp_step_count)
            steps_this_chunk = steps_per_update - steps_already
            self._simulation.step(steps_this_chunk)
            steps_done += steps_this_chunk
            ramp_step_count = chunk_end

            if is_interrupted():
                LOGGER.warning(
                    f"Interrupt during equilibration stage {stage_index} "
                    f"at step {steps_done}/{total_steps} (ramping, T={current_temp:.1f} K)"
                )
                _save_eq_interrupted(steps_done, current_temp)
                raise GracefulExit(
                    signal_number=get_interrupt_signal(), steps_completed=steps_done
                )

            current_temp += stage.temperature_increment

        # Final temperature - run remaining steps in chunks
        integrator.setTemperature(stage.temperature_end * omm_unit.kelvin)
        if stage.ensemble == Ensemble.NPT:
            self._simulation.context.setParameter(
                openmm.MonteCarloBarostat.Temperature(),
                stage.temperature_end * omm_unit.kelvin,
            )
        current_temp = stage.temperature_end

        steps_at_final_done = steps_done - steps_for_ramping
        steps_at_final_remaining = max(0, remaining_steps_at_final - steps_at_final_done)

        if steps_at_final_remaining > 0:
            LOGGER.info(
                f"Running {steps_at_final_remaining} steps at final "
                f"temperature {stage.temperature_end} K"
            )
            chunk_size = min(report_interval, steps_at_final_remaining)
            while steps_at_final_remaining > 0:
                this_chunk = min(chunk_size, steps_at_final_remaining)
                self._simulation.step(this_chunk)
                steps_done += this_chunk
                steps_at_final_remaining -= this_chunk

                if is_interrupted():
                    LOGGER.warning(
                        f"Interrupt during equilibration stage {stage_index} "
                        f"at step {steps_done}/{total_steps} (final temp)"
                    )
                    _save_eq_interrupted(steps_done, current_temp)
                    raise GracefulExit(
                        signal_number=get_interrupt_signal(),
                        steps_completed=steps_done,
                    )
    else:
        # Constant temperature - run in chunks with signal checking
        current_temp = stage.temperature
        steps_remaining = total_steps - resume_from_step

        if resume_from_step > 0:
            LOGGER.info(
                f"Resuming constant-temp stage from step {resume_from_step}, "
                f"{steps_remaining} steps remaining"
            )
        else:
            LOGGER.info(f"Running {total_steps} steps at {stage.temperature} K")

        chunk_size = min(report_interval, steps_remaining)
        while steps_remaining > 0:
            this_chunk = min(chunk_size, steps_remaining)
            self._simulation.step(this_chunk)
            steps_done += this_chunk
            steps_remaining -= this_chunk

            if is_interrupted():
                LOGGER.warning(
                    f"Interrupt during equilibration stage {stage_index} "
                    f"at step {steps_done}/{total_steps}"
                )
                _save_eq_interrupted(steps_done, current_temp)
                raise GracefulExit(
                    signal_number=get_interrupt_signal(), steps_completed=steps_done
                )

    # Get final state (including box vectors for NPT stages)
    state = self._simulation.context.getState(
        getPositions=True, getVelocities=True, getEnergy=True
    )
    self._current_positions = state.getPositions()
    self._current_velocities = state.getVelocities()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    # Log final energy for diagnostics
    final_energy = state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Stage {stage_index} ({stage_name}): final PE = {final_energy:.2f} kJ/mol")

    # Save checkpoint
    checkpoint_path = phase_dir / f"{stage_name}_checkpoint.chk"
    self._simulation.saveCheckpoint(str(checkpoint_path))

    # Remove EQ_INTERRUPTED marker if present (stage completed successfully)
    eq_interrupted_marker = phase_dir / "EQ_INTERRUPTED"
    if eq_interrupted_marker.exists():
        eq_interrupted_marker.unlink()
        LOGGER.info("Removed EQ_INTERRUPTED marker — stage completed successfully")

    # Remove position restraints from system for next stage
    if restraint_force_indices:
        LOGGER.info(
            f"Removing {len(restraint_force_indices)} position restraint force(s) for next stage"
        )
        remove_position_restraints_from_system(self._system, restraint_force_indices)

    # Build results
    results = {
        "stage_index": stage_index,
        "stage_name": stage.name,
        "ensemble": stage.ensemble.value,
        "duration_ns": stage.duration,
        "total_steps": total_steps,
        "temperature_start_K": stage.get_start_temperature(),
        "temperature_end_K": stage.get_final_temperature(),
        "is_temperature_ramping": stage.is_temperature_ramping,
        "position_restraints": [
            {"group": r.group, "force_constant": r.force_constant}
            for r in stage.position_restraints
        ],
        "final_energy_kJ_mol": final_energy,
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }

    LOGGER.info(f"Equilibration stage '{stage.name}' complete")
    return results
def _find_completed_eq_stages(
    self,
    stages: List["EquilibrationStageConfig"],
) -> List[int]:
    """Return the leading run of stages that look finished on disk.

    A stage counts as finished when its ``equilibration_N_name/`` directory
    holds the stage checkpoint and no ``EQ_INTERRUPTED`` marker. Scanning
    stops at the first stage that fails this test, so the result is always
    contiguous from index 0.

    NOTE(review): ``run_equilibration_stage`` also attaches a periodic
    ``CheckpointReporter`` that writes to this same checkpoint filename, so a
    stage that was hard-killed before any marker was written may look
    finished here — confirm whether that case needs a separate marker.

    Parameters
    ----------
    stages : list of EquilibrationStageConfig
        The full list of stages from the config.

    Returns
    -------
    list of int
        Indices of completed stages (contiguous from 0).
    """
    finished: List[int] = []
    for index, stage_cfg in enumerate(stages):
        label = f"equilibration_{index}_{stage_cfg.name}"
        folder = self._working_dir / label
        has_checkpoint = (folder / f"{label}_checkpoint.chk").exists()
        if not has_checkpoint or (folder / "EQ_INTERRUPTED").exists():
            # First gap ends the scan — later stages cannot be skipped.
            break
        finished.append(index)
    return finished


def _find_interrupted_eq_stage(
    self,
    stages: List["EquilibrationStageConfig"],
    completed_indices: List[int],
) -> Dict[str, Any] | None:
    """Read resume metadata for the stage following the completed ones.

    The stage directory right after the last completed stage is checked for
    an ``EQ_INTERRUPTED`` marker. When present, its ``key=value`` lines are
    parsed; unknown keys are ignored.

    Parameters
    ----------
    stages : list of EquilibrationStageConfig
        The full list of stages from the config.
    completed_indices : list of int
        Indices of fully completed stages (from
        ``_find_completed_eq_stages``).

    Returns
    -------
    dict or None
        Dictionary with ``stage_index`` plus whichever of
        ``steps_completed``, ``total_steps``, ``current_temperature`` and
        ``is_temperature_ramping`` the marker contained. ``None`` when every
        stage is complete, no marker exists, the marker cannot be read or
        parsed, or the stage checkpoint is missing.
    """
    candidate = len(completed_indices)
    if candidate >= len(stages):
        return None

    stage_cfg = stages[candidate]
    label = f"equilibration_{candidate}_{stage_cfg.name}"
    folder = self._working_dir / label
    marker_path = folder / "EQ_INTERRUPTED"
    if not marker_path.exists():
        return None

    # Marker field name -> converter for its textual value.
    converters = {
        "steps_completed": int,
        "total_steps": int,
        "current_temperature": float,
        "is_temperature_ramping": lambda raw: raw.lower() == "true",
    }

    info: Dict[str, Any] = {"stage_index": candidate}
    try:
        for raw_line in marker_path.read_text().strip().splitlines():
            field, _, raw_value = raw_line.partition("=")
            field = field.strip()
            convert = converters.get(field)
            if convert is not None:
                info[field] = convert(raw_value.strip())
    except (ValueError, OSError) as exc:
        LOGGER.warning(f"Could not parse EQ_INTERRUPTED marker {marker_path}: {exc}")
        return None

    # Resuming needs the checkpoint; without it the stage restarts.
    if not (folder / f"{label}_checkpoint.chk").exists():
        LOGGER.warning(
            f"EQ_INTERRUPTED marker found for stage {candidate} but no checkpoint — "
            f"will restart stage from beginning"
        )
        return None

    LOGGER.info(
        f"Found interrupted equilibration stage {candidate} ({stage_cfg.name}): "
        f"{info.get('steps_completed', 0)}/{info.get('total_steps', '?')} steps, "
        f"T={info.get('current_temperature', '?')} K"
    )
    return info


def _load_eq_stage_state(
    self,
    stage_index: int,
    stage_name: str,
) -> None:
    """Restore positions, velocities and box vectors from a stage checkpoint.

    A short-lived ``Simulation`` is built only to read the binary
    checkpoint; the extracted state is stored on ``self`` and the
    simulation is then discarded.

    Parameters
    ----------
    stage_index : int
        Index of the stage whose checkpoint is read.
    stage_name : str
        Name of that stage (used in directory/file naming).

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    """
    label = f"equilibration_{stage_index}_{stage_name}"
    chk_path = self._working_dir / label / f"{label}_checkpoint.chk"
    if not chk_path.exists():
        raise FileNotFoundError(f"Equilibration checkpoint not found: {chk_path}")

    # A plain VerletIntegrator is sufficient here: only positions, velocities
    # and box vectors are read back, and run_equilibration_stage() builds its
    # own integrator afterwards.
    reader = Simulation(
        self._topology,
        self._system,
        openmm.VerletIntegrator(1.0 * omm_unit.femtosecond),
        self._get_platform(),
    )
    reader.loadCheckpoint(str(chk_path))

    snapshot = reader.context.getState(getPositions=True, getVelocities=True)
    self._current_positions = snapshot.getPositions()
    self._current_velocities = snapshot.getVelocities()
    self._current_box_vectors = snapshot.getPeriodicBoxVectors()

    # Drop the throwaway simulation and its placeholder integrator.
    del reader

    LOGGER.info(
        f"Loaded state from equilibration stage {stage_index} ({stage_name}) checkpoint"
    )
def run_staged_equilibration(
    self,
    stages: List["EquilibrationStageConfig"],
    atom_group_resolver: "AtomGroupResolver",
    target_temperature: float,
) -> Dict[str, Any]:
    """Run every equilibration stage in order, resuming where possible.

    Each stage may use a different temperature schedule (constant or
    ramping), set of position restraints, and ensemble (NVT/NPT). Positions
    carry over between stages.

    Stages whose checkpoint is already on disk are skipped, a stage carrying
    an ``EQ_INTERRUPTED`` marker is resumed from its checkpoint, and a
    leftover stage directory with neither is deleted and re-run.

    Args:
        stages: List of EquilibrationStageConfig objects
        atom_group_resolver: Resolver for predefined atom group names
        target_temperature: Final target temperature (accepted for the
            caller's convenience; not read by this method)

    Returns:
        Dictionary with all stage results and summary
    """
    import shutil

    LOGGER.info(f"Starting multi-stage equilibration with {len(stages)} stages")

    # Restraints target the positions held *before* any checkpoint is loaded
    # (the post-minimization coordinates), so capture them first;
    # _load_eq_stage_state() overwrites _current_positions.
    reference_positions = self._current_positions

    # Stages already finished in an earlier run.
    completed_indices = self._find_completed_eq_stages(stages)
    if completed_indices:
        last_idx = completed_indices[-1]
        last_stage = stages[last_idx]
        LOGGER.info(
            f"Found {len(completed_indices)} completed equilibration stage(s) "
            f"on disk — resuming after stage {last_idx} ({last_stage.name})"
        )
        self._load_eq_stage_state(last_idx, last_stage.name)

    # Was the first unfinished stage interrupted mid-run?
    interrupted_info = self._find_interrupted_eq_stage(stages, completed_indices)

    results: Dict[str, Any] = {
        "type": "staged_equilibration",
        "num_stages": len(stages),
        "stages": [],
        "total_duration_ns": 0.0,
    }

    already_done = set(completed_indices)
    for index, stage_cfg in enumerate(stages):
        if index in already_done:
            LOGGER.info(f"Skipping completed equilibration stage {index}: {stage_cfg.name}")
            stage_result = {
                "stage_index": index,
                "stage_name": stage_cfg.name,
                "skipped": True,
                "duration_ns": stage_cfg.duration,
            }
        else:
            partial_dir = self._working_dir / f"equilibration_{index}_{stage_cfg.name}"
            resume_kwargs: Dict[str, Any] = {}

            if interrupted_info is not None and interrupted_info["stage_index"] == index:
                # Continue the interrupted stage from its checkpoint.
                resume_step = interrupted_info.get("steps_completed", 0)
                resume_temp = interrupted_info.get("current_temperature")
                LOGGER.info(
                    f"Resuming interrupted equilibration stage {index} ({stage_cfg.name}) "
                    f"from step {resume_step}"
                )
                self._load_eq_stage_state(index, stage_cfg.name)
                resume_kwargs = {
                    "resume_from_step": resume_step,
                    "resume_temperature": resume_temp,
                }
            elif partial_dir.exists():
                # Leftover directory with nothing to resume from: start over.
                LOGGER.warning(
                    f"Stage {index} ({stage_cfg.name}) directory exists without checkpoint "
                    f"— removing partial output and re-running"
                )
                shutil.rmtree(partial_dir)

            stage_result = self.run_equilibration_stage(
                stage=stage_cfg,
                reference_positions=reference_positions,
                atom_group_resolver=atom_group_resolver,
                stage_index=index,
                **resume_kwargs,
            )

        results["stages"].append(stage_result)
        results["total_duration_ns"] += stage_cfg.duration

    # Summary values come from the last stage that actually ran.
    last_result = next(
        (entry for entry in reversed(results["stages"]) if not entry.get("skipped")),
        None,
    )
    if last_result:
        results["final_energy_kJ_mol"] = last_result["final_energy_kJ_mol"]
        results["final_temperature_K"] = last_result["temperature_end_K"]

    self._history["equilibration"] = results

    skipped = len(completed_indices)
    ran = len(stages) - skipped
    LOGGER.info(
        f"Multi-stage equilibration complete: {len(stages)} stages "
        f"({skipped} skipped, {ran} ran), "
        f"{results['total_duration_ns']:.3f} ns total"
    )
    return results
def run_production(
    self,
    temperature: float,
    duration_ns: float,
    num_samples: int = 2500,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    pressure: float = 1.0,
    barostat_frequency: int = 25,
    output_prefix: str = "production",
    segment_index: int = 0,
    *,
    report_interval: int,
    checkpoint_interval_s: float,
) -> Dict[str, Any]:
    """Run NPT production simulation.

    The segment writes all of its artifacts into
    ``<working_dir>/<output_prefix>_<segment_index>/``: trajectory (DCD),
    state data (CSV), topology (PDB), parameters (JSON), system/state XML
    and a checkpoint. The step loop runs in sub-chunks so that interrupt
    flags and wall-time restart checkpoints can be serviced between chunks.

    Args:
        temperature: Temperature in Kelvin.
        duration_ns: Duration in nanoseconds.
        num_samples: Number of trajectory frames. In this method the value
            is only recorded in the parameters JSON and the progress
            record; the actual frame spacing is set by ``report_interval``.
        timestep_fs: Time step in femtoseconds.
        friction: Friction coefficient in 1/ps.
        pressure: Pressure in atmospheres.
        barostat_frequency: Barostat update frequency.
        output_prefix: Prefix for output files.
        segment_index: Segment index for multi-segment production.
        report_interval: Explicit reporter interval in steps. Used for the
            DCD, state-data and checkpoint reporters.
        checkpoint_interval_s: Wall-time interval in seconds between
            portable restart checkpoints. Also controls how frequently
            the loop checks for SLURM preemption signals.

    Returns:
        Dictionary with phase results.

    Raises:
        ValueError: If ``report_interval`` or ``checkpoint_interval_s`` is
            not positive. Raised before any state is modified.
        GracefulExit: If an interrupt is detected during the run, after the
            interrupted state and progress record have been saved.
    """
    # Validate up front so that a bad argument cannot leave the system with
    # a swapped barostat or an empty output directory behind.
    if report_interval <= 0:
        raise ValueError("report_interval must be a positive integer")
    if checkpoint_interval_s <= 0:
        raise ValueError("checkpoint_interval_s must be positive")

    LOGGER.info(
        f"Starting production: {duration_ns} ns at {temperature} K, {pressure} atm (NPT)"
    )

    # Add barostat for NPT
    self._remove_barostat()
    self._add_barostat(
        pressure=pressure,
        temperature=temperature,
        frequency=barostat_frequency,
    )

    # Create output directory
    phase_name = f"{output_prefix}_{segment_index}"
    phase_dir = self._working_dir / phase_name
    phase_dir.mkdir(exist_ok=True)

    # Calculate steps (ns -> fs is a factor of 1e6).
    # NOTE(review): int() truncates, so a float quotient that lands just
    # below an integer loses one step — confirm this matches how other
    # modules derive the step count before changing it.
    total_steps = int(duration_ns * 1e6 / timestep_fs)

    # Create integrator and simulation
    integrator = self._create_integrator(
        temperature=temperature,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()
    self._simulation = Simulation(self._topology, self._system, integrator, platform)

    # Set box vectors BEFORE positions - critical for correct periodic boundary handling
    if self._current_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    self._simulation.context.setPositions(self._current_positions)

    # Log initial energy
    _state = self._simulation.context.getState(getEnergy=True)
    _energy = _state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Production segment {segment_index}: initial PE = {_energy:.2f} kJ/mol")

    # Save system XML early so it exists on disk even if the segment is
    # hard-killed (SIGKILL / OOM). This is required for checkpoint-based
    # recovery: loadCheckpoint() needs a matching System object. The file
    # is overwritten at segment completion with the final state.
    system_xml_path = phase_dir / f"{phase_name}_system.xml"
    with open(system_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(self._system))
    LOGGER.info(f"Saved initial system to {system_xml_path}")

    # Set velocities for production
    # - If we have velocities from equilibration, use them (physical continuity)
    # - Otherwise generate new velocities at target temperature
    # Note: For continuation segments (segment > 0), ContinuationManager uses
    # loadState() which restores both positions and velocities from the XML state file
    if segment_index == 0:
        if self._current_velocities is not None:
            self._simulation.context.setVelocities(self._current_velocities)
            LOGGER.info("Using velocities preserved from equilibration")
        else:
            self._simulation.context.setVelocitiesToTemperature(temperature * omm_unit.kelvin)
            LOGGER.info("Initialized velocities from Maxwell-Boltzmann distribution")

    # Add reporters
    traj_path = phase_dir / f"{phase_name}_trajectory.dcd"
    state_path = phase_dir / f"{phase_name}_state_data.csv"
    pdb_path = phase_dir / f"{phase_name}_topology.pdb"

    self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
    self._simulation.reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )

    # Periodic checkpoint reporter — ensures a .chk file is written every
    # report_interval steps. Without this, segment 0 has NO checkpoint on
    # disk until the very end, so a hard kill (SIGKILL / OOM / node failure)
    # would lose all progress. Matches the behaviour already present in
    # continuation.py for segments >= 1.
    prod_chk_path = phase_dir / f"{phase_name}_checkpoint.chk"
    self._simulation.reporters.append(CheckpointReporter(str(prod_chk_path), report_interval))

    # Save topology
    with open(pdb_path, "w") as f:
        PDBFile.writeFile(
            self._topology,
            self._current_positions,
            f,
        )

    # Save parameters JSON before the simulation loop so it exists even
    # if the segment is interrupted (continuation.py requires this file
    # for continuation across segments).
    params_dict = {
        "__class__": "SimulationParameters",
        "__values__": {
            "thermo_params": {
                "__class__": "ThermoParameters",
                "__values__": {
                    "temperature": {
                        "__class__": "Quantity",
                        "__values__": {"value": temperature, "unit": "kelvin"},
                    },
                    "thermostat_params": {
                        "__class__": "ThermostatParameters",
                        "__values__": {
                            "temperature": {
                                "__class__": "Quantity",
                                "__values__": {"value": temperature, "unit": "kelvin"},
                            },
                            "timescale": {
                                "__class__": "Quantity",
                                "__values__": {"value": friction, "unit": "/picosecond"},
                            },
                        },
                    },
                    "barostat_params": {
                        "__class__": "BarostatParameters",
                        "__values__": {
                            "pressure": {
                                "__class__": "Quantity",
                                "__values__": {"value": pressure, "unit": "atmosphere"},
                            },
                            "temperature": {
                                "__class__": "Quantity",
                                "__values__": {"value": temperature, "unit": "kelvin"},
                            },
                            "frequency": barostat_frequency,
                        },
                    },
                },
            },
            "integ_params": {
                "__class__": "IntegratorParameters",
                "__values__": {
                    "time_step": {
                        "__class__": "Quantity",
                        "__values__": {"value": timestep_fs, "unit": "femtosecond"},
                    },
                    "total_time": {
                        "__class__": "Quantity",
                        "__values__": {"value": duration_ns, "unit": "nanosecond"},
                    },
                    "num_samples": num_samples,
                },
            },
            "reporter_params": {
                "__class__": "ReporterParameters",
                "__values__": {
                    "report_interval": report_interval,
                    "report_trajectory": True,
                    "report_state_data": True,
                },
            },
        },
    }
    params_path = phase_dir / f"{phase_name}_parameters.json"
    with open(params_path, "w") as f:
        json.dump(params_dict, f, indent=2)
    LOGGER.info(f"Saved parameters to {params_path}")

    # Install signal handlers for graceful shutdown (SIGUSR1 / SIGTERM)
    from polyzymd.simulation.signals import (
        GracefulExit,
        get_interrupt_signal,
        install_handlers,
        interrupted_state_save_exceptions,
        is_interrupted,
        save_interrupted_state,
        save_restart_checkpoint,
    )

    install_handlers()

    # Mark this segment as RUNNING in progress.json so that
    # check-progress can distinguish actively running simulations
    # from interrupted ones.
    self._write_segment_started(segment_index, total_steps)

    # Run simulation with adaptive sub-chunks for interrupt responsiveness
    # and periodic wall-time restart checkpoints for preemption resilience.
    LOGGER.info(f"Running {total_steps} steps...")
    steps_done = 0

    import time as _time

    from polyzymd.simulation.progress import PROGRESS_UPDATE_INTERVAL_SECONDS

    _last_progress_write = _time.monotonic()
    _last_checkpoint_write = _time.monotonic()
    _loop_start = _time.monotonic()

    # Adaptive sub-chunk sizing: start with report_interval (the original
    # chunk_size). After the first checkpoint interval elapses, measure
    # actual steps/second and adapt sub_chunk to target
    # checkpoint_interval_s / 4 seconds worth of steps. This gives about
    # four interrupt checks per checkpoint interval regardless of system
    # size or hardware speed. The sub-chunk is never larger than
    # report_interval.
    sub_chunk = min(report_interval, total_steps)
    _adapted = False

    try:
        while steps_done < total_steps:
            remaining = total_steps - steps_done
            this_chunk = min(sub_chunk, remaining)
            self._simulation.step(this_chunk)
            steps_done += this_chunk

            _now = _time.monotonic()

            # Adaptive sub-chunk calibration (once, after first interval)
            if not _adapted and (_now - _loop_start) >= checkpoint_interval_s:
                elapsed = _now - _loop_start
                steps_per_sec = steps_done / elapsed if elapsed > 0 else 1.0
                # Target sub-chunk duration = checkpoint_interval / 4
                target_seconds = checkpoint_interval_s / 4.0
                new_sub_chunk = max(10, int(steps_per_sec * target_seconds))
                if new_sub_chunk >= report_interval:
                    # Never step more than one report interval per chunk.
                    new_sub_chunk = report_interval
                else:
                    # Prefer a sub-chunk that divides report_interval
                    # evenly. Only the quotients report_interval // k
                    # are probed, largest first; the first one that is
                    # <= new_sub_chunk and an exact divisor wins. If
                    # none qualifies, the unaligned value is kept.
                    # NOTE(review): the probe can miss smaller exact
                    # divisors, and the floor of 10 below can undo the
                    # alignment — confirm whether alignment is needed.
                    best = new_sub_chunk
                    for candidate in [
                        report_interval // k
                        for k in range(1, report_interval // max(1, new_sub_chunk) + 2)
                    ]:
                        if candidate <= new_sub_chunk and report_interval % candidate == 0:
                            best = candidate
                            break
                    new_sub_chunk = max(10, best)

                if new_sub_chunk != sub_chunk:
                    LOGGER.info(
                        f"Adaptive sub-chunk: {sub_chunk} -> {new_sub_chunk} steps "
                        f"(~{new_sub_chunk / steps_per_sec:.1f}s at "
                        f"{steps_per_sec:.0f} steps/s)"
                    )
                    sub_chunk = new_sub_chunk
                _adapted = True

            # Periodically update RUNNING record so check-progress
            # can display real-time remaining nanoseconds.
            if _now - _last_progress_write >= PROGRESS_UPDATE_INTERVAL_SECONDS:
                self._update_progress_running(
                    segment_index=segment_index,
                    steps_done=steps_done,
                    timestep_fs=timestep_fs,
                )
                _last_progress_write = _now

            # Wall-time restart checkpoint
            if (
                (_now - _last_checkpoint_write) >= checkpoint_interval_s
                and steps_done < total_steps  # skip if we're about to finish
            ):
                save_restart_checkpoint(
                    simulation=self._simulation,
                    output_dir=phase_dir,
                )
                _last_checkpoint_write = _now

            if is_interrupted():
                LOGGER.warning(f"Interrupt detected at step {steps_done}/{total_steps}")
                save_interrupted_state(
                    simulation=self._simulation,
                    output_dir=phase_dir,
                    segment_index=segment_index,
                    steps_completed=steps_done,
                    total_steps=total_steps,
                )
                # Update progress tracker (interrupted)
                self._update_progress_interrupted(
                    segment_index=segment_index,
                    steps_done=steps_done,
                    total_steps=total_steps,
                    duration_ns=duration_ns,
                    timestep_fs=timestep_fs,
                )
                raise GracefulExit(
                    signal_number=get_interrupt_signal(), steps_completed=steps_done
                )
    except GracefulExit:
        raise  # Re-raise so caller can handle exit code
    except interrupted_state_save_exceptions():
        # On unexpected crash, still try to save interrupted state
        try:
            save_interrupted_state(
                simulation=self._simulation,
                output_dir=phase_dir,
                segment_index=segment_index,
                steps_completed=steps_done,
                total_steps=total_steps,
            )
        except interrupted_state_save_exceptions() as save_exc:
            LOGGER.exception(
                "Failed to save interrupted state after crash: %s",
                save_exc,
            )
        raise

    # Get final state (no enforcePeriodicBox to preserve molecular continuity)
    state = self._simulation.context.getState(
        getPositions=True,
        getVelocities=True,
        getEnergy=True,
        getForces=True,
        getParameters=True,
    )
    # Carry the complete dynamical state forward. Under NPT the barostat
    # rescales the box, so positions must be stored together with the box
    # they belong to; velocities are kept for physical continuity.
    self._current_positions = state.getPositions()
    self._current_velocities = state.getVelocities()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    # Save checkpoint (same path as the periodic checkpoint reporter)
    checkpoint_path = phase_dir / f"{phase_name}_checkpoint.chk"
    self._simulation.saveCheckpoint(str(checkpoint_path))

    # Save state XML (needed for continuation across segments)
    state_xml_path = phase_dir / f"{phase_name}_state.xml"
    with open(state_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(state))
    LOGGER.info(f"Saved state to {state_xml_path}")

    # Save system XML (needed for continuation across segments)
    system_xml_path = phase_dir / f"{phase_name}_system.xml"
    with open(system_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(self._system))
    LOGGER.info(f"Saved system to {system_xml_path}")

    # Update progress tracker (successful completion)
    self._update_progress_completed(
        segment_index=segment_index,
        total_steps=total_steps,
        num_samples=num_samples,
        duration_ns=duration_ns,
        timestep_fs=timestep_fs,
    )

    results = {
        "phase": "production",
        "segment": segment_index,
        "ensemble": "NPT",
        "temperature_K": temperature,
        "pressure_atm": pressure,
        "duration_ns": duration_ns,
        "total_steps": total_steps,
        "final_energy_kJ_mol": state.getPotentialEnergy().value_in_unit(
            omm_unit.kilojoule_per_mole
        ),
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }

    self._history[phase_name] = results
    LOGGER.info(f"Production segment {segment_index} complete")

    return results
def _write_segment_started(
    self,
    segment_index: int,
    total_steps: int,
) -> None:
    """Write a RUNNING segment record to progress.json at segment start.

    This marks the segment as actively executing so that ``check-progress``
    can distinguish a running simulation from one that was interrupted.
    The record is later updated to COMPLETED or INTERRUPTED by the
    corresponding handler. If no progress file can be loaded, a warning is
    logged and nothing is written.

    Parameters
    ----------
    segment_index : int
        Production segment index (0 for initial production).
    total_steps : int
        Total steps planned for this segment.
    """
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping segment-started write")
        return

    # Fresh record: nothing completed or written yet.
    record = SegmentRecord(
        index=segment_index,
        steps_completed=0,
        steps_requested=total_steps,
        samples_written=0,
        status=SegmentStatus.RUNNING,
    )
    _update_or_append_segment(progress, record)
    progress.status = SimulationStatus.RUNNING

    save_progress(self._working_dir, progress)
    LOGGER.info(f"Marked segment {segment_index} as RUNNING in progress file")

def _update_progress_running(
    self,
    segment_index: int,
    steps_done: int,
    timestep_fs: float,
) -> None:
    """Periodically update the RUNNING segment's step count.

    Called during the simulation loop to keep ``progress.json`` up to date
    so that ``check-progress`` can show real-time remaining nanoseconds.
    Returns silently when there is no progress file or when no RUNNING
    record exists for ``segment_index``.

    Parameters
    ----------
    segment_index : int
        Production segment index.
    steps_done : int
        Steps completed so far in this segment.
    timestep_fs : float
        Integration timestep in femtoseconds.
    """
    from polyzymd.simulation.progress import (
        SegmentStatus,
        SimulationStatus,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        return

    # Find and update the existing RUNNING record
    for seg in progress.segments:
        if seg.index == segment_index and seg.status == SegmentStatus.RUNNING:
            seg.steps_completed = steps_done
            # steps * fs per step, converted to ns
            seg.duration_ns = (steps_done * timestep_fs) / 1e6
            break
    else:
        # No RUNNING record found — shouldn't happen, but be safe
        return

    progress.status = SimulationStatus.RUNNING
    save_progress(self._working_dir, progress)
    LOGGER.debug(f"Updated running progress: segment {segment_index}, {steps_done} steps")

def _update_progress_completed(
    self,
    segment_index: int,
    total_steps: int,
    num_samples: int,
    duration_ns: float,
    timestep_fs: float,
) -> None:
    """Update progress file after successful segment completion.

    Replaces (or appends) the segment record with a COMPLETED one, stamps
    its finish time, and sets the overall status to COMPLETED when the
    progress object reports ``is_complete``, otherwise RUNNING.

    Parameters
    ----------
    segment_index : int
        Production segment index (0 for initial production).
    total_steps : int
        Steps completed in this segment.
    num_samples : int
        Samples written in this segment.
    duration_ns : float
        Simulation time of this segment in nanoseconds.
    timestep_fs : float
        Integration timestep in femtoseconds. Accepted for signature
        symmetry with the other progress handlers; not used here.
    """
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _now_iso,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping progress update")
        return

    record = SegmentRecord(
        index=segment_index,
        steps_completed=total_steps,
        steps_requested=total_steps,
        samples_written=num_samples,
        status=SegmentStatus.COMPLETED,
        duration_ns=duration_ns,
    )
    record.finished_at = _now_iso()

    _update_or_append_segment(progress, record)

    # Update overall status
    if progress.is_complete:
        progress.status = SimulationStatus.COMPLETED
    else:
        progress.status = SimulationStatus.RUNNING

    save_progress(self._working_dir, progress)
    LOGGER.info(
        f"Progress updated: {progress.total_steps_completed}/"
        f"{progress.total_steps_requested} steps "
        f"({progress.fraction_complete():.1%})"
    )

def _update_progress_interrupted(
    self,
    segment_index: int,
    steps_done: int,
    total_steps: int,
    duration_ns: float,
    timestep_fs: float,
) -> None:
    """Update progress file after segment interruption.

    Replaces (or appends) the segment record with an INTERRUPTED one whose
    duration reflects the steps actually completed, and marks the overall
    simulation status as INTERRUPTED.

    Parameters
    ----------
    segment_index : int
        Production segment index.
    steps_done : int
        Steps completed before interruption.
    total_steps : int
        Steps that were planned for this segment.
    duration_ns : float
        Planned simulation time of this segment in nanoseconds. Not used
        here; the recorded duration is derived from ``steps_done``.
    timestep_fs : float
        Integration timestep in femtoseconds.
    """
    from polyzymd.simulation.progress import (
        SegmentRecord,
        SegmentStatus,
        SimulationStatus,
        _update_or_append_segment,
        load_progress,
        save_progress,
    )

    progress = load_progress(self._working_dir)
    if progress is None:
        LOGGER.warning("No progress file found — skipping progress update")
        return

    # Calculate actual duration completed
    actual_duration_ns = (steps_done * timestep_fs) / 1e6

    record = SegmentRecord(
        index=segment_index,
        steps_completed=steps_done,
        steps_requested=total_steps,
        samples_written=0,  # Interrupted — samples may be partial
        status=SegmentStatus.INTERRUPTED,
        duration_ns=actual_duration_ns,
    )

    _update_or_append_segment(progress, record)
    progress.status = SimulationStatus.INTERRUPTED

    save_progress(self._working_dir, progress)
    LOGGER.info(
        f"Progress updated (interrupted): {steps_done}/{total_steps} steps "
        f"in segment {segment_index}"
    )
def save_history(self, path: Optional[Union[str, Path]] = None) -> None:
    """Write the accumulated phase history to a JSON file.

    Args:
        path: Destination file. When omitted, the history is written to
            ``simulation_history.json`` inside the working directory.
    """
    # Resolve the destination in one expression: explicit path wins,
    # otherwise fall back to the default file in the working directory.
    path = Path(path) if path is not None else self._working_dir / "simulation_history.json"

    with open(path, "w") as history_file:
        json.dump(self._history, history_file, indent=2)

    LOGGER.info(f"Saved simulation history to {path}")
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
    """Restore the active simulation from a checkpoint file.

    After loading, the runner's cached positions, velocities and periodic
    box vectors are refreshed from the simulation context.

    Args:
        checkpoint_path: Path to checkpoint file.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If no simulation has been created yet.
    """
    checkpoint_path = Path(checkpoint_path)

    # The file check comes first, so a missing file is reported even when
    # no simulation exists yet.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    simulation = self._simulation
    if simulation is None:
        raise RuntimeError(
            "No active simulation. Create a simulation first with "
            "run_equilibration or run_production."
        )

    simulation.loadCheckpoint(str(checkpoint_path))

    # Update current positions, velocities, and box vectors
    restored = simulation.context.getState(getPositions=True, getVelocities=True)
    self._current_positions = restored.getPositions()
    self._current_velocities = restored.getVelocities()
    self._current_box_vectors = restored.getPeriodicBoxVectors()

    LOGGER.info(f"Loaded checkpoint from {checkpoint_path}")