"""
Simulation runner for executing MD simulations with OpenMM.
This module handles running equilibration and production phases
with configurable parameters, reporters, and checkpoint management.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import openmm
from openmm import XmlSerializer
from openmm import unit as omm_unit
from openmm.app import DCDReporter, PDBFile, Simulation, StateDataReporter
if TYPE_CHECKING:
from polyzymd.config.schema import (
EquilibrationStageConfig,
SimulationConfig,
SimulationPhaseConfig,
SimulationPhasesConfig,
)
from polyzymd.core.atom_groups import AtomGroupResolver
from polyzymd.core.parameters import SimulationParameters
LOGGER = logging.getLogger(__name__)
# Phase types
PhaseType = Literal["equilibration", "production"]
[docs]
class SimulationRunner:
"""Runner for executing OpenMM molecular dynamics simulations.
This class manages:
- Equilibration and production phase execution
- Checkpoint saving and state data reporting
- Energy minimization
- Unique force group assignment for energy decomposition
Example:
>>> runner = SimulationRunner(
... topology=omm_topology,
... system=omm_system,
... positions=omm_positions,
... working_dir="output/",
... )
>>> runner.minimize()
>>> runner.run_equilibration(temperature=300, duration_ns=0.5)
>>> runner.run_production(temperature=300, duration_ns=100)
"""
[docs]
def __init__(
    self,
    topology: Any,
    system: openmm.System,
    positions: Any,
    working_dir: Union[str, Path],
    platform: str = "CUDA",
) -> None:
    """Set up a runner around an already-built OpenMM system.

    The output directory is created (parents included) and every force in
    ``system`` is given its own force group so energies can be decomposed.

    Args:
        topology: OpenMM Topology.
        system: OpenMM System.
        positions: Initial positions with units.
        working_dir: Directory that receives all output files.
        platform: Preferred compute platform name (CUDA, OpenCL, CPU).
    """
    output_root = Path(working_dir)

    # Inputs describing the system to simulate.
    self._topology = topology
    self._system = system
    self._positions = positions
    self._platform_name = platform
    self._working_dir = output_root

    # Mutable run state: the live simulation plus the coordinates and box
    # handed from one phase to the next.
    self._simulation: Optional[Simulation] = None
    self._current_positions = positions
    self._current_box_vectors = None  # filled in once a phase has produced a state
    self._history: Dict[str, Any] = {}

    # The output directory must exist before any phase writes into it.
    output_root.mkdir(parents=True, exist_ok=True)

    # One force group per force enables per-term energy decomposition.
    self._impose_unique_force_groups()
@property
def simulation(self) -> Optional[Simulation]:
    """Return the current OpenMM Simulation, or ``None`` if none has been created yet."""
    active_simulation = self._simulation
    return active_simulation
@property
def working_dir(self) -> Path:
    """Return the directory under which all phase output is written."""
    output_root = self._working_dir
    return output_root
@property
def history(self) -> Dict[str, Any]:
    """Return the per-phase results recorded so far (the live dict, not a copy)."""
    recorded = self._history
    return recorded
def _impose_unique_force_groups(self) -> None:
    """Give every force in the system its own force group.

    Delegates to ``polyzymd.utils.impose_unique_force_groups``; the import
    is done lazily inside the method.
    """
    from polyzymd.utils import impose_unique_force_groups as assign_groups

    assign_groups(self._system)
    LOGGER.debug("Assigned unique force groups")
def _get_platform(self) -> openmm.Platform:
    """Resolve the requested compute platform, degrading to CPU on failure.

    Returns:
        The OpenMM Platform named at construction time, or the CPU platform
        when that lookup raises.
    """
    requested = self._platform_name
    try:
        chosen = openmm.Platform.getPlatformByName(requested)
    except Exception:
        # Best-effort fallback: any lookup failure is treated as "unavailable".
        LOGGER.warning(f"Platform {requested} not available, falling back to CPU")
        return openmm.Platform.getPlatformByName("CPU")
    LOGGER.info(f"Using {requested} platform")
    return chosen
def _create_integrator(
    self,
    temperature: float,
    friction: float = 1.0,
    timestep: float = 2.0,
    thermostat: str = "LangevinMiddle",
) -> openmm.Integrator:
    """Build a Langevin-type integrator from plain numeric settings.

    Args:
        temperature: Temperature in Kelvin.
        friction: Friction coefficient in 1/ps.
        timestep: Time step in femtoseconds.
        thermostat: ``"LangevinMiddle"`` or ``"Langevin"``; any other value
            logs a warning and is treated as ``"LangevinMiddle"``.

    Returns:
        OpenMM Integrator.
    """
    if thermostat not in ("LangevinMiddle", "Langevin"):
        LOGGER.warning(f"Unknown thermostat {thermostat}, using LangevinMiddle")

    integrator_cls = (
        openmm.LangevinIntegrator
        if thermostat == "Langevin"
        else openmm.LangevinMiddleIntegrator
    )
    # Attach units here so callers can pass bare floats.
    return integrator_cls(
        temperature * omm_unit.kelvin,
        friction / omm_unit.picosecond,
        timestep * omm_unit.femtosecond,
    )
def _add_barostat(
    self,
    pressure: float = 1.0,
    temperature: float = 300.0,
    frequency: int = 25,
) -> None:
    """Append a Monte Carlo barostat force to the system.

    Existing barostats are not removed here; callers in this class call
    ``_remove_barostat`` first.

    Args:
        pressure: Pressure in atmospheres.
        temperature: Temperature in Kelvin.
        frequency: Update frequency in steps.
    """
    target_pressure = pressure * omm_unit.atmosphere
    bath_temperature = temperature * omm_unit.kelvin
    self._system.addForce(
        openmm.MonteCarloBarostat(target_pressure, bath_temperature, frequency)
    )
    LOGGER.info(f"Added MC barostat: {pressure} atm, {temperature} K")
def _remove_barostat(self) -> None:
    """Strip every ``MonteCarloBarostat`` force from the system.

    The debug message is emitted whether or not a barostat was present.
    """
    # Walk from the highest index down so removals never shift the indices
    # still to be inspected.
    for slot in range(self._system.getNumForces() - 1, -1, -1):
        if isinstance(self._system.getForce(slot), openmm.MonteCarloBarostat):
            self._system.removeForce(slot)
    LOGGER.debug("Removed barostat")
[docs]
def minimize(
    self,
    max_iterations: int = 1000,
    tolerance: float = 10.0,
) -> float:
    """Minimize the potential energy starting from the current positions.

    A throwaway Simulation is used; ``self.simulation`` is not touched. The
    minimized positions and the box vectors of the final state are stored
    for the next phase.

    Args:
        max_iterations: Maximum iterations (0 = until convergence).
        tolerance: Energy tolerance in kJ/mol/nm.

    Returns:
        Final potential energy in kJ/mol.
    """
    LOGGER.info("Running energy minimization")

    # The integrator is only needed to construct a Simulation/Context.
    scratch = Simulation(
        self._topology,
        self._system,
        openmm.VerletIntegrator(1.0 * omm_unit.femtosecond),
        self._get_platform(),
    )
    scratch.context.setPositions(self._current_positions)

    force_tolerance = tolerance * omm_unit.kilojoule_per_mole / omm_unit.nanometer
    scratch.minimizeEnergy(tolerance=force_tolerance, maxIterations=max_iterations)

    # Hand the minimized coordinates and box over to the next phase.
    minimized = scratch.context.getState(getEnergy=True, getPositions=True)
    self._current_positions = minimized.getPositions()
    self._current_box_vectors = minimized.getPeriodicBoxVectors()

    energy = minimized.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Minimization complete: E = {energy:.2f} kJ/mol")
    return energy
[docs]
def run_equilibration(
    self,
    temperature: float,
    duration_ns: Optional[float] = None,
    num_samples: int = 10,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    output_prefix: str = "equilibration",
    config: Optional["SimulationPhasesConfig"] = None,
) -> Dict[str, Any]:
    """Dispatch to simple or staged equilibration.

    Three call styles are supported:

    1. Legacy: no ``config``; ``duration_ns`` and the other keyword
       arguments drive a single simple equilibration.
    2. Config, simple: ``config.equilibration`` supplies duration, sample
       count and (optionally) the time step.
    3. Config, staged: ``config.equilibration_stages`` is logged and run
       through ``run_staged_equilibration`` with an atom-group resolver
       built from the topology.

    Args:
        temperature: Temperature in Kelvin
        duration_ns: Duration in nanoseconds (required for legacy mode)
        num_samples: Number of trajectory frames to save
        timestep_fs: Time step in femtoseconds
        friction: Friction coefficient in 1/ps
        output_prefix: Prefix for output files
        config: SimulationPhasesConfig for config-based dispatch

    Returns:
        Dictionary with equilibration results

    Raises:
        ValueError: If neither config nor duration_ns is provided
    """
    # --- Legacy parameter-based mode -------------------------------------
    if config is None:
        if duration_ns is None:
            raise ValueError(
                "duration_ns is required when config is not provided. "
                "Either pass duration_ns or pass a SimulationPhasesConfig."
            )
        return self._run_simple_equilibration(
            temperature=temperature,
            duration_ns=duration_ns,
            num_samples=num_samples,
            timestep_fs=timestep_fs,
            friction=friction,
            output_prefix=output_prefix,
        )

    # --- Config-based, single simple equilibration -----------------------
    if not config.uses_staged_equilibration:
        simple_cfg = config.equilibration
        return self._run_simple_equilibration(
            temperature=temperature,
            duration_ns=simple_cfg.duration,
            num_samples=simple_cfg.samples,
            # A falsy/unset time step in the config defers to the argument.
            timestep_fs=simple_cfg.time_step or timestep_fs,
            friction=friction,
            output_prefix=output_prefix,
        )

    # --- Config-based, multi-stage protocol with restraints --------------
    from polyzymd.core.atom_groups import AtomGroupResolver, SystemComponentInfo

    # Describe the whole protocol up front so the log is self-contained.
    LOGGER.info(
        f"Starting multi-stage equilibration with "
        f"{len(config.equilibration_stages)} stages:"
    )
    for i, stage in enumerate(config.equilibration_stages):
        restraint_parts = [
            f"{r.group}@{r.force_constant:.0f}" for r in stage.position_restraints
        ]
        restraint_info = ", ".join(restraint_parts) or "none"
        temp_info = (
            f"{stage.temperature_start}K -> {stage.temperature_end}K"
            if stage.is_temperature_ramping
            else f"{stage.temperature}K"
        )
        LOGGER.info(
            f" Stage {i}: {stage.name} - {stage.duration} ns, "
            f"{stage.ensemble.value}, {temp_info}, restraints: [{restraint_info}]"
        )

    components = SystemComponentInfo.from_topology(self._topology)
    group_resolver = AtomGroupResolver(self._topology, components)
    return self.run_staged_equilibration(
        stages=config.equilibration_stages,
        atom_group_resolver=group_resolver,
        target_temperature=temperature,
    )
def _run_simple_equilibration(
    self,
    temperature: float,
    duration_ns: float,
    num_samples: int = 10,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    output_prefix: str = "equilibration",
) -> Dict[str, Any]:
    """Run simple NVT equilibration (internal implementation).

    Any barostat is removed, a fresh Simulation is created with velocities
    drawn at ``temperature``, and a DCD trajectory, CSV state data, a PDB of
    the starting structure and a checkpoint are written under
    ``working_dir/output_prefix``. The final positions and box vectors are
    stored for the next phase and the result is recorded in the history
    under ``"equilibration"``.

    Args:
        temperature: Temperature in Kelvin.
        duration_ns: Duration in nanoseconds.
        num_samples: Number of trajectory frames to save. Values below 1
            are treated as 1 when computing the report interval.
        timestep_fs: Time step in femtoseconds.
        friction: Friction coefficient in 1/ps.
        output_prefix: Prefix for output files.

    Returns:
        Dictionary with phase results.
    """
    LOGGER.info(f"Starting equilibration: {duration_ns} ns at {temperature} K (NVT)")

    # Remove any barostat for NVT
    self._remove_barostat()

    # Create output directory
    phase_dir = self._working_dir / output_prefix
    phase_dir.mkdir(exist_ok=True)

    # Calculate steps; guard the divisor so num_samples=0 cannot raise.
    total_steps = int(duration_ns * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // max(1, num_samples))

    # Create integrator and simulation
    integrator = self._create_integrator(
        temperature=temperature,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()
    self._simulation = Simulation(self._topology, self._system, integrator, platform)

    # Set box vectors BEFORE positions, as the staged and production paths
    # do, so positions from an earlier phase or loaded checkpoint are paired
    # with the box they were generated in.
    if self._current_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    self._simulation.context.setPositions(self._current_positions)
    self._simulation.context.setVelocitiesToTemperature(temperature * omm_unit.kelvin)

    # Add reporters
    traj_path = phase_dir / f"{output_prefix}_trajectory.dcd"
    state_path = phase_dir / f"{output_prefix}_state_data.csv"
    pdb_path = phase_dir / f"{output_prefix}_topology.pdb"
    self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
    self._simulation.reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )

    # Save topology (with the starting coordinates of this phase)
    with open(pdb_path, "w") as f:
        PDBFile.writeFile(
            self._topology,
            self._current_positions,
            f,
        )

    # Run simulation
    LOGGER.info(f"Running {total_steps} steps...")
    self._simulation.step(total_steps)

    # Get final state (including box vectors for potential NPT follow-up).
    # Velocities are not requested: nothing here reads them, and the
    # checkpoint below stores them independently.
    state = self._simulation.context.getState(getPositions=True, getEnergy=True)
    self._current_positions = state.getPositions()
    self._current_box_vectors = state.getPeriodicBoxVectors()

    # Save checkpoint
    checkpoint_path = phase_dir / f"{output_prefix}_checkpoint.chk"
    self._simulation.saveCheckpoint(str(checkpoint_path))

    results = {
        "phase": "equilibration",
        "ensemble": "NVT",
        "temperature_K": temperature,
        "duration_ns": duration_ns,
        "total_steps": total_steps,
        "final_energy_kJ_mol": state.getPotentialEnergy().value_in_unit(
            omm_unit.kilojoule_per_mole
        ),
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    self._history["equilibration"] = results
    LOGGER.info("Equilibration complete")
    return results
[docs]
def run_equilibration_stage(
    self,
    stage: "EquilibrationStageConfig",
    reference_positions: Any,
    atom_group_resolver: "AtomGroupResolver",
    stage_index: int,
    default_timestep: float = 2.0,
    default_friction: float = 1.0,
) -> Dict[str, Any]:
    """Run a single equilibration stage with optional position restraints.

    This method runs one stage of a multi-stage equilibration protocol.
    It supports:

    - Position restraints on predefined atom groups
    - Temperature ramping (heating or cooling) in fixed increments
    - NVT or NPT ensembles

    Restraint forces added for the stage are always removed from the system
    again, even when the stage raises, so a failed stage does not leave
    stale restraints behind.

    Args:
        stage: EquilibrationStageConfig with stage settings
        reference_positions: Positions to restrain atoms to (typically post-minimization)
        atom_group_resolver: Resolver for predefined atom group names
        stage_index: Index of this stage (for output naming)
        default_timestep: Default time step in fs if not specified in stage
        default_friction: Default friction coefficient in 1/ps

    Returns:
        Dictionary with stage results

    Raises:
        ValueError: If the stage ramps temperature with a zero increment.
    """
    import math

    from polyzymd.config.schema import Ensemble
    from polyzymd.core.position_restraints import (
        add_position_restraints_to_system,
        remove_position_restraints_from_system,
    )

    stage_name = f"equilibration_{stage_index}_{stage.name}"
    LOGGER.info(f"Starting equilibration stage: {stage.name} ({stage.duration} ns)")

    # Create output directory for this stage
    phase_dir = self._working_dir / stage_name
    phase_dir.mkdir(exist_ok=True)

    # Get stage parameters (with defaults)
    timestep_fs = stage.time_step if stage.time_step is not None else default_timestep
    # NOTE(review): the original also read stage.thermostat_timescale but never
    # used it; friction always comes from default_friction. Confirm whether
    # the stage value was meant to override it.
    friction = default_friction

    # Calculate steps and reporting interval (divisor guarded against 0)
    total_steps = int(stage.duration * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // max(1, stage.samples))

    # Handle ensemble: always start from a barostat-free system, then add a
    # fresh barostat for NPT.
    start_temp = stage.get_start_temperature()
    is_npt = stage.ensemble == Ensemble.NPT
    self._remove_barostat()
    if is_npt:
        pressure = 1.0  # Default pressure for NPT stages
        barostat_freq = stage.barostat_frequency if stage.barostat_frequency else 25
        self._add_barostat(
            pressure=pressure,
            temperature=start_temp,
            frequency=barostat_freq,
        )

    restraint_force_indices = []
    try:
        # Add position restraints
        for restraint_config in stage.position_restraints:
            atom_indices = atom_group_resolver.resolve(restraint_config.group)
            if atom_indices:
                force_idx = add_position_restraints_to_system(
                    system=self._system,
                    atom_indices=atom_indices,
                    positions=reference_positions,
                    force_constant=restraint_config.force_constant,
                )
                if force_idx >= 0:
                    restraint_force_indices.append(force_idx)
                LOGGER.info(
                    f"Added position restraints to {len(atom_indices)} atoms "
                    f"in group '{restraint_config.group}' "
                    f"(k={restraint_config.force_constant:.1f} kJ/mol/nm^2)"
                )
            else:
                LOGGER.warning(
                    f"No atoms found for group '{restraint_config.group}' - skipping restraint"
                )

        # Create integrator with starting temperature
        integrator = self._create_integrator(
            temperature=start_temp,
            friction=friction,
            timestep=timestep_fs,
        )
        platform = self._get_platform()

        # Create simulation
        self._simulation = Simulation(self._topology, self._system, integrator, platform)

        # Set box vectors BEFORE positions - critical for NPT stage transitions
        # where box dimensions may have changed from previous stage
        if self._current_box_vectors is not None:
            self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
        self._simulation.context.setPositions(self._current_positions)
        self._simulation.context.setVelocitiesToTemperature(start_temp * omm_unit.kelvin)

        # Log initial energy
        _state = self._simulation.context.getState(getEnergy=True)
        _energy = _state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
        LOGGER.info(f"Stage {stage_index} ({stage_name}): initial PE = {_energy:.2f} kJ/mol")

        # Set up reporters
        traj_path = phase_dir / f"{stage_name}_trajectory.dcd"
        state_path = phase_dir / f"{stage_name}_state_data.csv"
        pdb_path = phase_dir / f"{stage_name}_topology.pdb"
        self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
        self._simulation.reporters.append(
            StateDataReporter(
                str(state_path),
                report_interval,
                step=True,
                time=True,
                potentialEnergy=True,
                kineticEnergy=True,
                totalEnergy=True,
                temperature=True,
                volume=True,
                density=True,
                speed=True,
            )
        )

        # Save initial topology
        with open(pdb_path, "w") as f:
            PDBFile.writeFile(
                self._topology,
                self._current_positions,
                f,
            )

        def set_bath_temperature(kelvin: float) -> None:
            """Point the thermostat (and, for NPT, the barostat) at ``kelvin``."""
            integrator.setTemperature(kelvin * omm_unit.kelvin)
            if is_npt:
                # The barostat keeps its own temperature as a context
                # parameter (in Kelvin); without this it would stay at the
                # ramp's starting temperature.
                self._simulation.context.setParameter(
                    openmm.MonteCarloBarostat.Temperature(), kelvin
                )

        # Run simulation with temperature ramping if needed
        if stage.is_temperature_ramping:
            LOGGER.info(
                f"Temperature ramping: {stage.temperature_start} K -> {stage.temperature_end} K "
                f"(increment={stage.temperature_increment} K every {stage.temperature_interval} fs)"
            )
            ramp_start = stage.temperature_start
            ramp_end = stage.temperature_end
            increment = abs(stage.temperature_increment)
            if increment == 0:
                raise ValueError(
                    f"Stage '{stage.name}': temperature_increment must be non-zero "
                    "for a temperature ramp"
                )
            # Direction comes from the endpoints, so cooling ramps work and
            # the sign of the configured increment does not matter.
            signed_increment = increment if ramp_end >= ramp_start else -increment
            steps_per_update = max(1, int(stage.temperature_interval / timestep_fs))

            # One schedule drives both the loop and the step budget. The
            # count is the number of plateaus strictly before the final
            # temperature; the small epsilon keeps float noise from adding
            # a spurious extra plateau.
            num_updates = max(0, math.ceil(abs(ramp_end - ramp_start) / increment - 1e-9))
            steps_for_ramping = num_updates * steps_per_update
            remaining_steps = total_steps - steps_for_ramping
            if remaining_steps < 0:
                LOGGER.warning(
                    f"Temperature ramp needs {steps_for_ramping} steps but stage "
                    f"'{stage.name}' is only {total_steps} steps long; the ramp is "
                    "run in full and exceeds the configured duration"
                )

            # Temperature ramping phase. Each plateau temperature is
            # computed from the index to avoid accumulating rounding error.
            for update in range(num_updates):
                set_bath_temperature(ramp_start + update * signed_increment)
                self._simulation.step(steps_per_update)

            # Final temperature - run remaining steps
            set_bath_temperature(ramp_end)
            if remaining_steps > 0:
                LOGGER.info(
                    f"Running {remaining_steps} steps at final temperature {stage.temperature_end} K"
                )
                self._simulation.step(remaining_steps)
        else:
            # Constant temperature - just run all steps
            LOGGER.info(f"Running {total_steps} steps at {stage.temperature} K")
            self._simulation.step(total_steps)

        # Get final state (including box vectors for NPT stages)
        state = self._simulation.context.getState(
            getPositions=True, getVelocities=True, getEnergy=True
        )
        self._current_positions = state.getPositions()
        self._current_box_vectors = state.getPeriodicBoxVectors()

        # Log final energy for diagnostics
        final_energy = state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
        LOGGER.info(f"Stage {stage_index} ({stage_name}): final PE = {final_energy:.2f} kJ/mol")

        # Save checkpoint
        checkpoint_path = phase_dir / f"{stage_name}_checkpoint.chk"
        self._simulation.saveCheckpoint(str(checkpoint_path))
    finally:
        # Remove position restraints from system for next stage, including
        # when the stage failed part-way through.
        if restraint_force_indices:
            LOGGER.info(
                f"Removing {len(restraint_force_indices)} position restraint force(s) for next stage"
            )
            remove_position_restraints_from_system(self._system, restraint_force_indices)

    # Build results
    final_temp = stage.get_final_temperature()
    results = {
        "stage_index": stage_index,
        "stage_name": stage.name,
        "ensemble": stage.ensemble.value,
        "duration_ns": stage.duration,
        "total_steps": total_steps,
        "temperature_start_K": stage.get_start_temperature(),
        "temperature_end_K": final_temp,
        "is_temperature_ramping": stage.is_temperature_ramping,
        "position_restraints": [
            {"group": r.group, "force_constant": r.force_constant}
            for r in stage.position_restraints
        ],
        "final_energy_kJ_mol": final_energy,
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    LOGGER.info(f"Equilibration stage '{stage.name}' complete")
    return results
[docs]
def run_staged_equilibration(
    self,
    stages: List["EquilibrationStageConfig"],
    atom_group_resolver: "AtomGroupResolver",
    target_temperature: float,
) -> Dict[str, Any]:
    """Run every stage of a multi-stage equilibration protocol in order.

    Each stage is delegated to ``run_equilibration_stage``. All stages are
    restrained against the positions held when this method is entered. The
    combined result is stored in the history under ``"equilibration"``.

    Args:
        stages: List of EquilibrationStageConfig objects
        atom_group_resolver: Resolver for predefined atom group names
        target_temperature: Final target temperature (not used by the
            implementation itself)

    Returns:
        Dictionary with all stage results and summary
    """
    stage_count = len(stages)
    LOGGER.info(f"Starting multi-stage equilibration with {stage_count} stages")

    # Snapshot the restraint reference once; later stages overwrite
    # self._current_positions.
    restraint_reference = self._current_positions

    per_stage: List[Dict[str, Any]] = []
    elapsed_ns = 0.0
    for index, stage_cfg in enumerate(stages):
        per_stage.append(
            self.run_equilibration_stage(
                stage=stage_cfg,
                reference_positions=restraint_reference,
                atom_group_resolver=atom_group_resolver,
                stage_index=index,
            )
        )
        elapsed_ns += stage_cfg.duration

    results = {
        "type": "staged_equilibration",
        "num_stages": stage_count,
        "stages": per_stage,
        "total_duration_ns": elapsed_ns,
    }
    # The summary fields mirror the last stage, when there is one.
    if per_stage:
        last = per_stage[-1]
        results["final_energy_kJ_mol"] = last["final_energy_kJ_mol"]
        results["final_temperature_K"] = last["temperature_end_K"]

    self._history["equilibration"] = results
    LOGGER.info(
        f"Multi-stage equilibration complete: {len(stages)} stages, "
        f"{results['total_duration_ns']:.3f} ns total"
    )
    return results
[docs]
def run_production(
    self,
    temperature: float,
    duration_ns: float,
    num_samples: int = 2500,
    timestep_fs: float = 2.0,
    friction: float = 1.0,
    pressure: float = 1.0,
    barostat_frequency: int = 25,
    output_prefix: str = "production",
    segment_index: int = 0,
) -> Dict[str, Any]:
    """Run NPT production simulation.

    A fresh barostat and Simulation are created for the segment. Output
    (DCD trajectory, CSV state data, starting-structure PDB, checkpoint,
    state XML, system XML and a parameters JSON) is written under
    ``working_dir/{output_prefix}_{segment_index}``. The final positions
    AND box vectors are stored so that whatever runs next starts from a
    consistent coordinates/box pair.

    For ``segment_index == 0`` velocities are taken from the previous
    simulation when one exists, otherwise drawn at ``temperature``. For
    later segments this method sets no velocities itself.

    Args:
        temperature: Temperature in Kelvin.
        duration_ns: Duration in nanoseconds.
        num_samples: Number of trajectory frames to save. Values below 1
            are treated as 1 when computing the report interval.
        timestep_fs: Time step in femtoseconds.
        friction: Friction coefficient in 1/ps.
        pressure: Pressure in atmospheres.
        barostat_frequency: Barostat update frequency.
        output_prefix: Prefix for output files.
        segment_index: Segment index for daisy-chaining.

    Returns:
        Dictionary with phase results.
    """

    def _quantity(value: Any, unit_name: str) -> Dict[str, Any]:
        """Build the serialized form of a unit-bearing value."""
        return {
            "__class__": "Quantity",
            "__values__": {"value": value, "unit": unit_name},
        }

    LOGGER.info(
        f"Starting production: {duration_ns} ns at {temperature} K, {pressure} atm (NPT)"
    )

    # Add barostat for NPT (removing any left over from an earlier phase)
    self._remove_barostat()
    self._add_barostat(
        pressure=pressure,
        temperature=temperature,
        frequency=barostat_frequency,
    )

    # Create output directory
    phase_name = f"{output_prefix}_{segment_index}"
    phase_dir = self._working_dir / phase_name
    phase_dir.mkdir(exist_ok=True)

    # Calculate steps; guard the divisor so num_samples=0 cannot raise.
    total_steps = int(duration_ns * 1e6 / timestep_fs)
    report_interval = max(1, total_steps // max(1, num_samples))

    # Create integrator and simulation
    integrator = self._create_integrator(
        temperature=temperature,
        friction=friction,
        timestep=timestep_fs,
    )
    platform = self._get_platform()

    # Capture velocities and box vectors from equilibration before creating new Simulation
    # (creating new Simulation destroys the old context)
    # Box vectors are critical for NPT stages where box dimensions change
    equilibration_velocities = None
    equilibration_box_vectors = None
    if self._simulation is not None and segment_index == 0:
        state = self._simulation.context.getState(getVelocities=True)
        equilibration_velocities = state.getVelocities()
        equilibration_box_vectors = state.getPeriodicBoxVectors()

    self._simulation = Simulation(self._topology, self._system, integrator, platform)

    # Set box vectors BEFORE positions - critical for correct periodic boundary handling
    # Use captured box vectors from equilibration, or fall back to the stored ones
    if equilibration_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*equilibration_box_vectors)
    elif self._current_box_vectors is not None:
        self._simulation.context.setPeriodicBoxVectors(*self._current_box_vectors)
    self._simulation.context.setPositions(self._current_positions)

    # Log initial energy
    _state = self._simulation.context.getState(getEnergy=True)
    _energy = _state.getPotentialEnergy().value_in_unit(omm_unit.kilojoule_per_mole)
    LOGGER.info(f"Production segment {segment_index}: initial PE = {_energy:.2f} kJ/mol")

    # Set velocities for production
    # - If we have velocities from equilibration, use them (physical continuity)
    # - Otherwise generate new velocities at target temperature
    # Note: For daisy-chain continuation (segment > 0), ContinuationManager uses
    # loadState() which restores both positions and velocities from the XML state file
    if segment_index == 0:
        if equilibration_velocities is not None:
            self._simulation.context.setVelocities(equilibration_velocities)
            LOGGER.info("Using velocities preserved from equilibration")
        else:
            self._simulation.context.setVelocitiesToTemperature(temperature * omm_unit.kelvin)
            LOGGER.info("Initialized velocities from Maxwell-Boltzmann distribution")

    # Add reporters
    traj_path = phase_dir / f"{phase_name}_trajectory.dcd"
    state_path = phase_dir / f"{phase_name}_state_data.csv"
    pdb_path = phase_dir / f"{phase_name}_topology.pdb"
    self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))
    self._simulation.reporters.append(
        StateDataReporter(
            str(state_path),
            report_interval,
            step=True,
            time=True,
            potentialEnergy=True,
            kineticEnergy=True,
            totalEnergy=True,
            temperature=True,
            volume=True,
            density=True,
            speed=True,
        )
    )

    # Save topology (with the starting coordinates of this segment)
    with open(pdb_path, "w") as f:
        PDBFile.writeFile(
            self._topology,
            self._current_positions,
            f,
        )

    # Run simulation
    LOGGER.info(f"Running {total_steps} steps...")
    self._simulation.step(total_steps)

    # Get final state (no enforcePeriodicBox to preserve molecular continuity)
    state = self._simulation.context.getState(
        getPositions=True,
        getVelocities=True,
        getEnergy=True,
        getForces=True,
        getParameters=True,
    )
    self._current_positions = state.getPositions()
    # The barostat changes the box during NPT; keep the stored box in step
    # with the stored positions so a following segment or phase does not
    # combine new coordinates with the pre-production box.
    self._current_box_vectors = state.getPeriodicBoxVectors()

    # Save checkpoint
    checkpoint_path = phase_dir / f"{phase_name}_checkpoint.chk"
    self._simulation.saveCheckpoint(str(checkpoint_path))

    # Save state XML (needed for continuation/daisy-chain)
    state_xml_path = phase_dir / f"{phase_name}_state.xml"
    with open(state_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(state))
    LOGGER.info(f"Saved state to {state_xml_path}")

    # Save system XML (needed for continuation/daisy-chain)
    system_xml_path = phase_dir / f"{phase_name}_system.xml"
    with open(system_xml_path, "w") as f:
        f.write(XmlSerializer.serialize(self._system))
    LOGGER.info(f"Saved system to {system_xml_path}")

    # Save parameters JSON (needed for continuation/daisy-chain).
    # Key order is kept stable so the written file is unchanged.
    params_dict = {
        "__class__": "SimulationParameters",
        "__values__": {
            "thermo_params": {
                "__class__": "ThermoParameters",
                "__values__": {
                    "temperature": _quantity(temperature, "kelvin"),
                    "thermostat_params": {
                        "__class__": "ThermostatParameters",
                        "__values__": {
                            "temperature": _quantity(temperature, "kelvin"),
                            "timescale": _quantity(friction, "/picosecond"),
                        },
                    },
                    "barostat_params": {
                        "__class__": "BarostatParameters",
                        "__values__": {
                            "pressure": _quantity(pressure, "atmosphere"),
                            "temperature": _quantity(temperature, "kelvin"),
                            "frequency": barostat_frequency,
                        },
                    },
                },
            },
            "integ_params": {
                "__class__": "IntegratorParameters",
                "__values__": {
                    "time_step": _quantity(timestep_fs, "femtosecond"),
                    "total_time": _quantity(duration_ns, "nanosecond"),
                    "num_samples": num_samples,
                },
            },
            "reporter_params": {
                "__class__": "ReporterParameters",
                "__values__": {
                    "report_interval": report_interval,
                    "report_trajectory": True,
                    "report_state_data": True,
                },
            },
        },
    }
    params_path = phase_dir / f"{phase_name}_parameters.json"
    with open(params_path, "w") as f:
        json.dump(params_dict, f, indent=2)
    LOGGER.info(f"Saved parameters to {params_path}")

    results = {
        "phase": "production",
        "segment": segment_index,
        "ensemble": "NPT",
        "temperature_K": temperature,
        "pressure_atm": pressure,
        "duration_ns": duration_ns,
        "total_steps": total_steps,
        "final_energy_kJ_mol": state.getPotentialEnergy().value_in_unit(
            omm_unit.kilojoule_per_mole
        ),
        "trajectory_path": str(traj_path),
        "checkpoint_path": str(checkpoint_path),
    }
    self._history[phase_name] = results
    LOGGER.info(f"Production segment {segment_index} complete")
    return results
[docs]
def save_history(self, path: Optional[Union[str, Path]] = None) -> None:
    """Write the recorded phase results to a JSON file.

    Args:
        path: Output path (defaults to working_dir/simulation_history.json).
    """
    destination = (
        self._working_dir / "simulation_history.json" if path is None else Path(path)
    )
    with destination.open("w") as handle:
        json.dump(self._history, handle, indent=2)
    LOGGER.info(f"Saved simulation history to {destination}")
[docs]
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
    """Restore the active simulation from a checkpoint file.

    After loading, the stored current positions and box vectors are
    refreshed from the restored context.

    Args:
        checkpoint_path: Path to checkpoint file.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If no simulation has been created yet.
    """
    checkpoint_file = Path(checkpoint_path)
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_file}")

    active = self._simulation
    if active is None:
        raise RuntimeError(
            "No active simulation. Create a simulation first with "
            "run_equilibration or run_production."
        )

    active.loadCheckpoint(str(checkpoint_file))

    # Mirror the restored coordinates and box into the runner's own state.
    restored = active.context.getState(getPositions=True)
    self._current_positions = restored.getPositions()
    self._current_box_vectors = restored.getPeriodicBoxVectors()
    LOGGER.info(f"Loaded checkpoint from {checkpoint_file}")