# Source code for polyzymd.simulation.continuation

"""
Continuation manager for resuming MD simulations from checkpoints.

This module handles loading simulation state from previous segments
and continuing production simulations for daisy-chain workflows.
"""

from __future__ import annotations

import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Union

import openmm
from openmm import XmlSerializer
from openmm import unit as u
from openmm.app import (
    CheckpointReporter,
    DCDReporter,
    PDBFile,
    Simulation,
    StateDataReporter,
)
from openmm.unit import Quantity

LOGGER = logging.getLogger(__name__)  # module logger shared by ContinuationManager and main()


def quantity_from_dict(qdict: Dict[str, Any]) -> Quantity:
    """Rebuild an OpenMM ``Quantity`` from its serialized dictionary form.

    The payload lives under ``qdict["__values__"]`` as a ``value`` together
    with a ``unit`` name. A unit name with a leading ``/`` (for example
    ``"/picosecond"``) denotes the reciprocal of that unit.

    Args:
        qdict: Serialized quantity whose ``__values__`` entry holds ``value``
            and ``unit``.

    Returns:
        ``value`` multiplied by the named ``openmm.unit`` unit, or divided by
        it for ``/unit`` names.

    Raises:
        KeyError: If ``__values__``, ``value`` or ``unit`` is missing.
        AttributeError: If the unit name is not an attribute of ``openmm.unit``.
    """
    payload = qdict["__values__"]
    magnitude = payload["value"]
    unit_name = payload["unit"]

    # Reciprocal units ("/picosecond") are resolved directly on openmm.unit.
    if unit_name.startswith("/"):
        reciprocal_of = getattr(u, unit_name[1:])
        return magnitude / reciprocal_of

    # Singular/plural spellings that are pinned to a specific openmm.unit name.
    aliases = {
        "atmosphere": u.atmospheres,
        "atmospheres": u.atmospheres,
        "kelvin": u.kelvin,
        "femtosecond": u.femtoseconds,
        "femtoseconds": u.femtoseconds,
        "nanosecond": u.nanoseconds,
        "nanoseconds": u.nanoseconds,
        "picosecond": u.picoseconds,
        "picoseconds": u.picoseconds,
    }
    resolved = aliases.get(unit_name)
    if resolved is None:
        # Anything else is assumed to be spelled exactly as in openmm.unit.
        resolved = getattr(u, unit_name)
    return magnitude * resolved
class ContinuationManager:
    """Manager for continuing MD simulations from previous segments.

    This class handles loading state from previous production segments and
    continuing the simulation, primarily for daisy-chain workflows on HPC
    clusters. Segment ``k`` reads its inputs from ``production_{k-1}/`` and
    writes its outputs to ``production_{k}/`` inside the working directory.

    Example:
        >>> manager = ContinuationManager(
        ...     working_dir="simulation_output/",
        ...     segment_index=2,  # Continuing to segment 2
        ... )
        >>> manager.load_previous_state()
        >>> manager.run_segment(duration_ns=20.0, num_samples=250)
    """

    # Names of OpenMM barostat force classes. Resolved lazily with getattr so
    # classes absent from older OpenMM releases are simply skipped.
    _BAROSTAT_CLASS_NAMES = (
        "MonteCarloBarostat",
        "MonteCarloAnisotropicBarostat",
        "MonteCarloMembraneBarostat",
        "MonteCarloFlexibleBarostat",
    )

    def __init__(
        self,
        working_dir: Union[str, Path],
        segment_index: int,
    ) -> None:
        """Initialize the ContinuationManager.

        Args:
            working_dir: Working directory containing simulation outputs.
            segment_index: Current segment index (1-based). The previous
                segment, ``segment_index - 1``, is the one continued from.

        Raises:
            ValueError: If ``segment_index`` is smaller than 1, i.e. there is
                no previous segment to continue from.
        """
        if segment_index < 1:
            raise ValueError(
                "segment_index must be >= 1 to continue from a previous "
                f"segment, got {segment_index}"
            )
        self._working_dir = Path(working_dir)
        self._segment_index = segment_index
        self._prev_segment = segment_index - 1

        # State populated by load_previous_state() / run_segment()
        self._system: Optional[openmm.System] = None
        self._topology: Optional[Any] = None
        self._simulation: Optional[Simulation] = None
        self._param_dict: Optional[Dict[str, Any]] = None

    @property
    def working_dir(self) -> Path:
        """Get the working directory."""
        return self._working_dir

    @property
    def segment_index(self) -> int:
        """Get the current segment index."""
        return self._segment_index

    @property
    def simulation(self) -> Optional[Simulation]:
        """Get the OpenMM Simulation object (``None`` until run_segment)."""
        return self._simulation

    def _find_solvated_pdb(self) -> Path:
        """Find the solvated PDB file in the working directory.

        Patterns are tried in priority order. Within a pattern, matches are
        sorted so the choice is deterministic: ``Path.glob`` order depends on
        the filesystem.

        Returns:
            Path to the solvated PDB file.

        Raises:
            FileNotFoundError: If no suitable PDB file is found.
        """
        patterns = (
            # Also covers "*_solvated.pdb" and "solvated_*.pdb".
            "*solvated*.pdb",
            "equilibration/*_topology.pdb",
            "production_0/*_topology.pdb",
            # Last resort: any PDB anywhere below the working directory.
            "**/*.pdb",
        )
        for pattern in patterns:
            matches = sorted(self._working_dir.glob(pattern))
            if matches:
                return matches[0]

        raise FileNotFoundError(f"Could not find solvated PDB file in {self._working_dir}")

    def _get_previous_paths(self) -> Dict[str, Path]:
        """Get paths to files from the previous segment.

        Returns:
            Dictionary with paths to the state, system, parameter and
            checkpoint files of segment ``segment_index - 1``.
        """
        prefix = f"production_{self._prev_segment}"
        prev_dir = self._working_dir / prefix
        return {
            "state": prev_dir / f"{prefix}_state.xml",
            "system": prev_dir / f"{prefix}_system.xml",
            "params": prev_dir / f"{prefix}_parameters.json",
            "checkpoint": prev_dir / f"{prefix}_checkpoint.chk",
        }

    def load_previous_state(self) -> None:
        """Load state from the previous segment.

        This loads the serialized system, the topology (from a PDB found in
        the working directory), and the JSON parameters from the previous
        production segment. The state XML itself is loaded later, in
        :meth:`run_segment`.

        Raises:
            FileNotFoundError: If a required file (state, system, parameters)
                or a topology PDB is missing. The checkpoint is optional.
        """
        LOGGER.info(f"Loading state from segment {self._prev_segment}")
        paths = self._get_previous_paths()

        # The checkpoint is not needed for an XML-state restart.
        for name, path in paths.items():
            if name != "checkpoint" and not path.exists():
                raise FileNotFoundError(f"Required file not found: {path}")

        LOGGER.info(f"Loading system from {paths['system']}")
        with open(paths["system"], "r", encoding="utf-8") as f:
            self._system = XmlSerializer.deserialize(f.read())

        pdb_path = self._find_solvated_pdb()
        LOGGER.info(f"Loading topology from {pdb_path}")
        self._topology = PDBFile(str(pdb_path)).topology

        LOGGER.info(f"Loading parameters from {paths['params']}")
        with open(paths["params"], "r", encoding="utf-8") as f:
            self._param_dict = json.load(f)

        LOGGER.info("Previous state loaded successfully")

    def _create_integrator(self) -> openmm.Integrator:
        """Create an integrator from the parameter dictionary.

        Returns:
            OpenMM LangevinMiddleIntegrator built from the stored thermostat
            temperature, thermostat timescale and integration time step.

        Raises:
            RuntimeError: If parameters have not been loaded.
        """
        if self._param_dict is None:
            raise RuntimeError("Parameters not loaded. Call load_previous_state first.")

        integ_raw = self._param_dict["__values__"]["integ_params"]["__values__"]
        time_step = quantity_from_dict(integ_raw["time_step"])

        thermo_raw = self._param_dict["__values__"]["thermo_params"]["__values__"]
        thermostat_raw = thermo_raw["thermostat_params"]["__values__"]
        temperature = quantity_from_dict(thermostat_raw["temperature"])
        # NOTE(review): "timescale" is passed as the friction coefficient,
        # which OpenMM expects in inverse time. This presumably relies on it
        # being serialized with a "/picosecond"-style unit; confirm against
        # the writer.
        friction_coeff = quantity_from_dict(thermostat_raw["timescale"])

        return openmm.LangevinMiddleIntegrator(temperature, friction_coeff, time_step)

    def _add_barostat_if_needed(self) -> None:
        """Add a MonteCarloBarostat if the parameters specify NPT.

        Nothing is added when there are no ``barostat_params``, or when the
        system already contains any known OpenMM barostat force. Two
        barostats in one system would both act on the box.

        Raises:
            RuntimeError: If the system or parameters have not been loaded.
        """
        if self._system is None or self._param_dict is None:
            raise RuntimeError("System/parameters not loaded")

        thermo_raw = self._param_dict["__values__"]["thermo_params"]["__values__"]
        if "barostat_params" not in thermo_raw:
            return

        barostat_types = tuple(
            getattr(openmm, name)
            for name in self._BAROSTAT_CLASS_NAMES
            if hasattr(openmm, name)
        )
        has_barostat = any(
            isinstance(self._system.getForce(i), barostat_types)
            for i in range(self._system.getNumForces())
        )
        if has_barostat:
            LOGGER.debug("Barostat already present")
            return

        barostat_raw = thermo_raw["barostat_params"]["__values__"]
        temperature = quantity_from_dict(barostat_raw["temperature"])
        pressure = quantity_from_dict(barostat_raw["pressure"])
        frequency = barostat_raw.get("update_frequency", 25)

        barostat = openmm.MonteCarloBarostat(pressure, temperature, frequency)
        self._system.addForce(barostat)
        LOGGER.info(f"Added barostat: {pressure} at {temperature}")

    def _setup_reporters(
        self,
        total_steps: int,
        num_samples: int,
        output_dir: Path,
    ) -> None:
        """Setup reporters for the simulation.

        Trajectory (DCD), state data (CSV) and checkpoint reporters all use
        the same interval of ``total_steps // num_samples`` (at least 1).

        Args:
            total_steps: Total steps for this segment.
            num_samples: Number of trajectory frames to save (must be >= 1).
            output_dir: Output directory for this segment.

        Raises:
            RuntimeError: If the simulation has not been created.
        """
        if self._simulation is None:
            raise RuntimeError("Simulation not created")

        report_interval = max(1, total_steps // num_samples)
        prefix = f"production_{self._segment_index}"

        traj_path = output_dir / f"{prefix}_trajectory.dcd"
        self._simulation.reporters.append(DCDReporter(str(traj_path), report_interval))

        state_path = output_dir / f"{prefix}_state_data.csv"
        self._simulation.reporters.append(
            StateDataReporter(
                str(state_path),
                report_interval,
                step=True,
                time=True,
                potentialEnergy=True,
                kineticEnergy=True,
                totalEnergy=True,
                temperature=True,
                volume=True,
                density=True,
                speed=True,
            )
        )

        checkpoint_path = output_dir / f"{prefix}_checkpoint.chk"
        self._simulation.reporters.append(CheckpointReporter(str(checkpoint_path), report_interval))

        LOGGER.info(f"Setup reporters with interval {report_interval}")

    def _save_final_state(self, output_dir: Path) -> None:
        """Save the final state and system after simulation.

        These files are what the next segment's load_previous_state() reads.

        Args:
            output_dir: Output directory for this segment.

        Raises:
            RuntimeError: If the simulation has not been created.
        """
        if self._simulation is None:
            raise RuntimeError("Simulation not available")

        prefix = f"production_{self._segment_index}"

        # Save state (no enforcePeriodicBox to preserve molecular continuity)
        state_path = output_dir / f"{prefix}_state.xml"
        state = self._simulation.context.getState(
            getPositions=True,
            getVelocities=True,
            getForces=True,
            getEnergy=True,
            getParameters=True,
        )
        with open(state_path, "w", encoding="utf-8") as f:
            f.write(XmlSerializer.serialize(state))

        system_path = output_dir / f"{prefix}_system.xml"
        with open(system_path, "w", encoding="utf-8") as f:
            f.write(XmlSerializer.serialize(self._simulation.system))

        LOGGER.info(f"Saved final state to {state_path}")
        LOGGER.info(f"Saved system to {system_path}")

    def run_segment(
        self,
        duration_ns: float,
        num_samples: int = 250,
        timestep_fs: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Run the continuation segment.

        Args:
            duration_ns: Duration of this segment in nanoseconds (must be > 0).
            num_samples: Number of trajectory frames to save (must be >= 1).
            timestep_fs: Optional time step in femtoseconds. When ``None``,
                the time step stored in the previous segment's parameters is
                used. When given, it overrides that value and is recorded in
                this segment's parameter file. Either way, the number of MD
                steps is derived from the step size the integrator actually
                uses, so the simulated time equals ``duration_ns``.

        Returns:
            Dictionary with ``segment_index``, ``duration_ns``,
            ``total_steps``, ``num_samples``, ``output_dir`` and
            ``timestep_fs`` (the step size actually used).

        Raises:
            RuntimeError: If load_previous_state() has not been called.
            ValueError: If ``duration_ns``, ``num_samples`` or ``timestep_fs``
                is out of range.
        """
        if self._system is None or self._topology is None or self._param_dict is None:
            raise RuntimeError("State not loaded. Call load_previous_state first.")
        if duration_ns <= 0:
            raise ValueError(f"duration_ns must be positive, got {duration_ns}")
        if num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {num_samples}")
        if timestep_fs is not None and timestep_fs <= 0:
            raise ValueError(f"timestep_fs must be positive, got {timestep_fs}")

        LOGGER.info(
            f"Starting segment {self._segment_index}: {duration_ns} ns, {num_samples} frames"
        )

        # Update the parameters saved alongside this segment.
        integ_values = self._param_dict["__values__"]["integ_params"]["__values__"]
        integ_values["total_time"] = {
            "__class__": "Quantity",
            "__values__": {"value": duration_ns, "unit": "nanosecond"},
        }
        integ_values["num_samples"] = num_samples
        if timestep_fs is not None:
            # Written before the integrator is built so the override is used
            # for the dynamics AND recorded in this segment's parameter file.
            integ_values["time_step"] = {
                "__class__": "Quantity",
                "__values__": {"value": timestep_fs, "unit": "femtosecond"},
            }

        self._add_barostat_if_needed()

        integrator = self._create_integrator()
        self._simulation = Simulation(self._topology, self._system, integrator)

        paths = self._get_previous_paths()
        LOGGER.info(f"Loading state from {paths['state']}")
        self._simulation.loadState(str(paths["state"]))

        output_dir = self._working_dir / f"production_{self._segment_index}"
        output_dir.mkdir(exist_ok=True)

        # Use the integrator's real step size, not an assumed one. Otherwise
        # the simulated time drifts from duration_ns whenever the stored
        # time step differs. round() guards against float results landing
        # just below an integer, which int() would truncate.
        step_size_fs = integrator.getStepSize().value_in_unit(u.femtoseconds)
        total_steps = round(duration_ns * 1e6 / step_size_fs)

        self._setup_reporters(total_steps, num_samples, output_dir)

        param_path = output_dir / f"production_{self._segment_index}_parameters.json"
        with open(param_path, "w", encoding="utf-8") as f:
            json.dump(self._param_dict, f, indent=2)

        LOGGER.info(f"Running {total_steps} steps...")
        self._simulation.step(total_steps)

        self._save_final_state(output_dir)

        results = {
            "segment_index": self._segment_index,
            "duration_ns": duration_ns,
            "total_steps": total_steps,
            "num_samples": num_samples,
            "output_dir": str(output_dir),
            "timestep_fs": step_size_fs,
        }

        LOGGER.info(f"Segment {self._segment_index} completed successfully")
        return results
def main(argv: Optional[list[str]] = None) -> int:
    """Main entry point for continuation script.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        Exit code (0 for success, 1 for failure). Invalid command-line values
        make argparse exit with status 2 instead.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Continue MD simulation from previous segment")
    parser.add_argument(
        "-s",
        "--segment_index",
        type=int,
        required=True,
        help="Current segment index (1-based)",
    )
    parser.add_argument(
        "-w",
        "--working_dir",
        type=str,
        required=True,
        help="Working directory path",
    )
    parser.add_argument(
        "-t",
        "--segment_time",
        type=float,
        required=True,
        help="Time for this segment in nanoseconds",
    )
    parser.add_argument(
        "-n",
        "--num_samples",
        type=int,
        default=250,
        help="Number of frames to save for this segment",
    )

    args = parser.parse_args(argv)

    # Validate up front so bad input gives a clean usage error. Without the
    # directory check, the log FileHandler below would crash outside the
    # try block.
    working_dir = Path(args.working_dir)
    if not working_dir.is_dir():
        parser.error(f"working directory does not exist: {working_dir}")
    if args.segment_index < 1:
        parser.error("--segment_index must be >= 1")
    if args.segment_time <= 0:
        parser.error("--segment_time must be positive")
    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(working_dir / f"simulation_status_segment_{args.segment_index}.log"),
            logging.StreamHandler(sys.stdout),
        ],
    )

    try:
        manager = ContinuationManager(
            working_dir=args.working_dir,
            segment_index=args.segment_index,
        )
        manager.load_previous_state()
        manager.run_segment(
            duration_ns=args.segment_time,
            num_samples=args.num_samples,
        )
    except Exception as e:
        # Top-level boundary: record the full traceback through the logging
        # handlers, so it also lands in the segment's status log file, and
        # report failure via the exit code.
        LOGGER.exception("Error during simulation: %s", e)
        return 1
    return 0
if __name__ == "__main__":
    # Script entry: main()'s return value becomes the process exit status.
    raise SystemExit(main())