Source code for polyzymd.workflow.daisy_chain

"""
Daisy-chain job submission for HPC SLURM scheduler.

This module provides utilities for breaking long MD simulations into
smaller dependent jobs that are automatically chained together using
SLURM job dependencies.
"""

from __future__ import annotations

import logging
import os
import re
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

from polyzymd.config.schema import SimulationConfig
from polyzymd.workflow.slurm import (
    JobContext,
    SlurmConfig,
    SlurmScriptGenerator,
    parse_replicate_range,
    validate_replicate_range,
)

# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)


@dataclass
class SegmentInfo:
    """Describe one segment of a simulation that was split into chained jobs.

    Attributes:
        index: Position of the segment in the chain (0 is the initial
            segment, 1 and above are continuations).
        duration_ns: Length of this segment, in nanoseconds.
        samples: How many trajectory frames this segment saves.
        is_initial: True when this is the initial segment
            (build + equilibration + first production run).
        cumulative_time_ns: Simulated time reached once this segment has
            finished, in nanoseconds.
    """

    index: int
    duration_ns: float
    samples: int
    is_initial: bool
    cumulative_time_ns: float
@dataclass
class DaisyChainConfig:
    """Configuration for daisy-chain submission.

    Attributes:
        slurm_config: SLURM job configuration
        total_production_time_ns: Total production time in nanoseconds
        total_segments: Number of segments to split production into (must be >= 1)
        total_samples: Total trajectory frames across all segments
        equilibration_time_ns: Equilibration time (only for initial segment)
        replicates: List of replicate numbers to run
        dry_run: If True, create scripts but don't submit
        output_script_dir: Directory for generated job scripts
        config_path: Path to the YAML configuration file

    Raises:
        ValueError: If ``total_segments`` is smaller than 1.
    """

    slurm_config: SlurmConfig
    total_production_time_ns: float
    total_segments: int = 10
    total_samples: int = 2500
    equilibration_time_ns: float = 0.5
    replicates: List[int] = field(default_factory=lambda: [1])
    dry_run: bool = False
    output_script_dir: Path = Path("daisy_chain_scripts")
    config_path: str = "config.yaml"

    def __post_init__(self) -> None:
        """Reject segment counts that cannot be used to split the run."""
        # Every per-segment quantity divides by total_segments; fail here with
        # a clear message instead of a ZeroDivisionError (or negative
        # durations) much later, in the middle of a submission.
        if self.total_segments < 1:
            raise ValueError(f"total_segments must be at least 1, got {self.total_segments}")

    @property
    def segment_duration_ns(self) -> float:
        """Get the duration of each segment in nanoseconds."""
        return self.total_production_time_ns / self.total_segments

    @property
    def samples_per_segment(self) -> int:
        """Get the number of frames per segment.

        Uses floor division, so any remainder of
        ``total_samples / total_segments`` is not assigned to a segment.
        """
        return self.total_samples // self.total_segments

    def get_segments(self) -> List[SegmentInfo]:
        """Generate segment information for all segments.

        Returns:
            List of SegmentInfo objects, one per segment, in chain order.
        """
        count = self.total_segments
        duration = self.segment_duration_ns
        samples = self.samples_per_segment
        return [
            SegmentInfo(
                index=i,
                duration_ns=duration,
                samples=samples,
                is_initial=(i == 0),
                # Scale the total by the completed fraction rather than summing
                # durations: repeated float addition drifts, while this form
                # makes the last segment end exactly at the total time.
                cumulative_time_ns=self.total_production_time_ns * ((i + 1) / count),
            )
            for i in range(count)
        ]

    @classmethod
    def from_simulation_config(
        cls,
        sim_config: SimulationConfig,
        slurm_config: SlurmConfig,
        replicates: Union[str, List[int]] = "1",
        dry_run: bool = False,
        output_script_dir: Union[str, Path] = "daisy_chain_scripts",
        config_path: str = "config.yaml",
    ) -> "DaisyChainConfig":
        """Create DaisyChainConfig from a SimulationConfig.

        Args:
            sim_config: Simulation configuration
            slurm_config: SLURM configuration
            replicates: Replicate range string (e.g., "1-5") or list of ints
            dry_run: If True, don't submit jobs
            output_script_dir: Directory for job scripts
            config_path: Path to the YAML configuration file

        Returns:
            Configured DaisyChainConfig

        Raises:
            ValueError: If the simulation config requests fewer than 1 segment.
        """
        if isinstance(replicates, str):
            validate_replicate_range(replicates)
            replicate_list = parse_replicate_range(replicates)
        else:
            # Copy so later changes to the caller's list cannot silently
            # alter this configuration (and vice versa).
            replicate_list = list(replicates)

        phases = sim_config.simulation_phases
        return cls(
            slurm_config=slurm_config,
            total_production_time_ns=phases.production.duration,
            total_segments=phases.segments,
            total_samples=phases.production.samples,
            equilibration_time_ns=phases.total_equilibration_duration,
            replicates=replicate_list,
            dry_run=dry_run,
            output_script_dir=Path(output_script_dir),
            config_path=config_path,
        )
@dataclass
class SubmissionResult:
    """Record describing one submitted (or dry-run) job.

    Attributes:
        job_id: SLURM job ID, or a placeholder ID when nothing was submitted.
        script_path: Location of the job script that was generated.
        segment_index: Index of the segment this job runs.
        replicate: Replicate number this job belongs to.
        is_dry_run: True when the job was only prepared, not submitted.
    """

    job_id: str
    script_path: Path
    segment_index: int
    replicate: int
    is_dry_run: bool = False
class DaisyChainSubmitter:
    """Handles daisy-chain job submission for MD simulations.

    This class generates SLURM job scripts and submits them with proper
    dependencies so that continuation jobs run after their prerequisites.

    Example:
        >>> sim_config = SimulationConfig.from_yaml("config.yaml")
        >>> slurm_config = SlurmConfig.from_preset("aa100", email="user@example.com")
        >>> dc_config = DaisyChainConfig.from_simulation_config(
        ...     sim_config, slurm_config, replicates="1-3"
        ... )
        >>> submitter = DaisyChainSubmitter(sim_config, dc_config)
        >>> results = submitter.submit_all()
    """

    def __init__(
        self,
        sim_config: SimulationConfig,
        dc_config: DaisyChainConfig,
        conda_env: str = "polymerist-env",
        openff_logs: bool = False,
        skip_build: bool = False,
    ) -> None:
        """Initialize the DaisyChainSubmitter.

        Args:
            sim_config: Simulation configuration
            dc_config: Daisy-chain configuration
            conda_env: Conda environment name
            openff_logs: Enable verbose OpenFF logs in generated scripts
            skip_build: Skip system building in generated scripts (use pre-built system)
        """
        self._sim_config = sim_config
        self._dc_config = dc_config
        self._openff_logs = openff_logs
        self._skip_build = skip_build
        self._generator = SlurmScriptGenerator(
            dc_config.slurm_config, conda_env, openff_logs=openff_logs, skip_build=skip_build
        )
        # Submitted jobs, keyed by replicate number, in segment order.
        self._job_chains: Dict[int, List[SubmissionResult]] = {}

    @property
    def sim_config(self) -> SimulationConfig:
        """Get the simulation configuration."""
        return self._sim_config

    @property
    def dc_config(self) -> DaisyChainConfig:
        """Get the daisy-chain configuration."""
        return self._dc_config

    @property
    def job_chains(self) -> Dict[int, List[SubmissionResult]]:
        """Get the job chains for all replicates."""
        return self._job_chains

    def _create_job_name(self, segment_index: int, replicate: int) -> str:
        """Create a descriptive job name.

        Args:
            segment_index: Segment index
            replicate: Replicate number

        Returns:
            Formatted job name
        """
        enzyme = self._sim_config.enzyme.name
        temp = int(self._sim_config.thermodynamics.temperature)
        polymers = self._sim_config.polymers

        polymer_info = ""
        if polymers and polymers.enabled:
            probs = [m.probability for m in polymers.monomers]
            # round(), not int(): 0.29 * 100 is 28.999999999999996 in binary
            # floating point, and truncating it would label the job "28%".
            minority_pct = round(min(probs) * 100)
            polymer_info = f"_{polymers.type_prefix}-{minority_pct}%"

        return f"s{segment_index}_r{replicate}_{temp}K_{enzyme}{polymer_info}"

    def _create_output_file_pattern(self, segment_index: int, replicate: int) -> str:
        """Create output file pattern for SLURM logs.

        SLURM logs go to the slurm_logs subdirectory within projects.

        Args:
            segment_index: Segment index
            replicate: Replicate number

        Returns:
            Output file pattern (relative to projects_dir)
        """
        job_name = self._create_job_name(segment_index, replicate)
        logs_subdir = self._sim_config.output.slurm_logs_subdir
        return f"{logs_subdir}/{job_name}.%A_%a.out"

    def _get_scratch_dir(self, replicate: int) -> str:
        """Get the scratch directory path for a replicate.

        This is where simulation output (trajectories, checkpoints) goes.

        Args:
            replicate: Replicate number

        Returns:
            Scratch directory path (absolute)
        """
        scratch_dir = self._sim_config.get_working_directory(replicate)
        return str(scratch_dir.resolve())

    def _get_projects_dir(self) -> str:
        """Get the projects directory path.

        This is where scripts, configs, and logs live.

        Returns:
            Projects directory path (absolute)
        """
        projects_dir = self._sim_config.get_projects_directory()
        return str(projects_dir.resolve())

    def generate_initial_script(self, replicate: int) -> str:
        """Generate the initial job script content.

        Args:
            replicate: Replicate number

        Returns:
            Script content string
        """
        context = JobContext(
            job_name=self._create_job_name(0, replicate),
            output_file=self._create_output_file_pattern(0, replicate),
            scratch_dir=self._get_scratch_dir(replicate),
            projects_dir=self._get_projects_dir(),
            segment_index=0,
            replicate_num=replicate,
        )
        return self._generator.generate_initial_job(
            context=context,
            config_path=self._dc_config.config_path,
            replicate=replicate,
            segment_time=self._dc_config.segment_duration_ns,
            segment_frames=self._dc_config.samples_per_segment,
        )

    def generate_continuation_script(self, segment_index: int, replicate: int) -> str:
        """Generate a continuation job script content.

        Args:
            segment_index: Segment index (1 or higher)
            replicate: Replicate number

        Returns:
            Script content string
        """
        context = JobContext(
            job_name=self._create_job_name(segment_index, replicate),
            output_file=self._create_output_file_pattern(segment_index, replicate),
            scratch_dir=self._get_scratch_dir(replicate),
            projects_dir=self._get_projects_dir(),
            segment_index=segment_index,
            replicate_num=replicate,
        )
        return self._generator.generate_continuation_job(
            context=context,
            segment_time=self._dc_config.segment_duration_ns,
            num_samples=self._dc_config.samples_per_segment,
        )

    def _save_script(self, content: str, filename: str) -> Path:
        """Save a script to the output directory and mark it executable.

        Args:
            content: Script content
            filename: Script filename

        Returns:
            Path to saved script
        """
        output_dir = self._dc_config.output_script_dir
        output_dir.mkdir(parents=True, exist_ok=True)

        script_path = output_dir / filename
        # Explicit encoding so the script bytes do not depend on the locale of
        # the machine the submission is run from.
        script_path.write_text(content, encoding="utf-8")
        script_path.chmod(0o755)
        return script_path

    @staticmethod
    def _parse_job_id(stdout: str) -> str:
        """Extract the job ID from the text printed by ``sbatch``.

        Args:
            stdout: Captured standard output of the ``sbatch`` call

        Returns:
            The job ID as a string

        Raises:
            RuntimeError: If ``stdout`` contains nothing to parse.
        """
        # Match the standard message first: on multi-cluster setups sbatch
        # prints "Submitted batch job <id> on cluster <name>", where the last
        # token is the cluster name rather than the job ID.
        match = re.search(r"Submitted batch job (\d+)", stdout)
        if match:
            return match.group(1)

        # Unrecognised output: keep the previous "last token" behaviour, but
        # fail with a clear message instead of an IndexError on empty output.
        tokens = stdout.split()
        if not tokens:
            raise RuntimeError("Failed to submit job: sbatch produced no output to read a job ID from")
        return tokens[-1]

    def _submit_job(
        self,
        script_path: Path,
        segment_index: int,
        replicate: int,
        dependency_job_id: Optional[str] = None,
    ) -> SubmissionResult:
        """Submit a job to SLURM.

        Args:
            script_path: Path to the job script
            segment_index: Segment index
            replicate: Replicate number
            dependency_job_id: Job ID to depend on (for continuation jobs)

        Returns:
            SubmissionResult with job information

        Raises:
            RuntimeError: If ``sbatch`` cannot be executed, exits with a
                non-zero status, or prints no output to parse.
        """
        if self._dc_config.dry_run:
            job_id = f"DRY_RUN_{replicate}_{segment_index}"
            LOGGER.info(f"[DRY RUN] Would submit {script_path}")
            return SubmissionResult(
                job_id=job_id,
                script_path=script_path,
                segment_index=segment_index,
                replicate=replicate,
                is_dry_run=True,
            )

        # Build sbatch command
        # Use --export=NONE to start with clean environment, letting the script's
        # module/conda initialization work properly regardless of submission context
        cmd = ["sbatch", "--export=NONE"]

        if dependency_job_id:
            cmd.extend(["--dependency", f"afterok:{dependency_job_id}"])

        # Add exclude if configured
        exclude = self._dc_config.slurm_config.exclude
        if exclude:
            cmd.extend(["--exclude", exclude])

        cmd.append(str(script_path))

        try:
            completed = subprocess.run(cmd, capture_output=True, text=True, check=True)
        except FileNotFoundError as e:
            # Raised when the sbatch executable itself cannot be started.
            # Convert it so it is not mistaken for a missing configuration
            # file by callers that handle FileNotFoundError.
            LOGGER.error(f"Error submitting job: {e}")
            raise RuntimeError(f"Failed to submit job: could not execute 'sbatch' ({e})") from e
        except subprocess.CalledProcessError as e:
            LOGGER.error(f"Error submitting job: {e}")
            LOGGER.error(f"STDOUT: {e.stdout}")
            LOGGER.error(f"STDERR: {e.stderr}")
            raise RuntimeError(f"Failed to submit job: {e.stderr}") from e

        job_id = self._parse_job_id(completed.stdout)
        LOGGER.info(f"Submitted job {job_id} from {script_path}")
        return SubmissionResult(
            job_id=job_id,
            script_path=script_path,
            segment_index=segment_index,
            replicate=replicate,
            is_dry_run=False,
        )

    def submit_replicate_chain(self, replicate: int) -> List[SubmissionResult]:
        """Submit all jobs for a single replicate.

        Args:
            replicate: Replicate number

        Returns:
            List of SubmissionResults for all segments
        """
        LOGGER.info(f"Submitting job chain for replicate {replicate}")

        results: List[SubmissionResult] = []
        for segment in self._dc_config.get_segments():
            if segment.is_initial:
                # Initial job: nothing to wait for.
                script_content = self.generate_initial_script(replicate)
                filename = f"initial_seg{segment.index}_rep{replicate}.sh"
                prev_job_id = None
            else:
                # Continuation job: runs only after the previous segment's job.
                script_content = self.generate_continuation_script(segment.index, replicate)
                filename = f"continue_seg{segment.index}_rep{replicate}.sh"
                prev_job_id = results[-1].job_id

            script_path = self._save_script(script_content, filename)
            results.append(
                self._submit_job(
                    script_path=script_path,
                    segment_index=segment.index,
                    replicate=replicate,
                    dependency_job_id=prev_job_id,
                )
            )

        self._job_chains[replicate] = results
        return results

    def submit_all(self) -> Dict[int, List[SubmissionResult]]:
        """Submit jobs for all replicates.

        Returns:
            Dictionary mapping replicate numbers to their job chains
        """
        self._print_submission_summary()

        for replicate in self._dc_config.replicates:
            self.submit_replicate_chain(replicate)

        self._print_completion_summary()
        return self._job_chains

    def _print_submission_summary(self) -> None:
        """Print a summary before submission."""
        config = self._dc_config
        num_replicates = len(config.replicates)
        total_jobs = num_replicates * config.total_segments

        print(f"\nPreparing {config.total_segments}-segment simulation jobs")
        print(f" Enzyme: {self._sim_config.enzyme.name}")
        if self._sim_config.polymers and self._sim_config.polymers.enabled:
            print(f" Polymer: {self._sim_config.polymers.type_prefix}")
            print(f" Polymer count: {self._sim_config.polymers.count}")
        print(f" Temperature: {self._sim_config.thermodynamics.temperature} K")
        print(f" Total production time: {config.total_production_time_ns} ns")
        print(f" Time per segment: {config.segment_duration_ns} ns")
        print(f" Samples per segment: {config.samples_per_segment}")
        print(f" Replicates: {config.replicates} ({num_replicates} total)")
        print(f" Total jobs to submit: {total_jobs}")
        print(f" Dependency chains: {num_replicates} independent chains")
        print()
        print("SLURM Configuration:")
        print(f" Partition: {config.slurm_config.partition}")
        print(f" QoS: {config.slurm_config.qos}")
        print(f" Account: {config.slurm_config.account}")
        print(f" Time limit: {config.slurm_config.time_limit}")
        print()

        if config.dry_run:
            print("*** DRY RUN MODE - Scripts will be created but not submitted ***")
            print()

    def _print_completion_summary(self) -> None:
        """Print a summary after submission."""
        config = self._dc_config
        total_jobs = sum(len(chain) for chain in self._job_chains.values())

        if config.dry_run:
            print(f"\nDry run completed. {total_jobs} job scripts created.")
            print(f"Scripts saved to: {config.output_script_dir}")
            print("Review the scripts and run without --dry-run to submit them.")
        else:
            print(f"\nAll {total_jobs} jobs submitted successfully!")
            print("\nDependency chains:")
            for replicate, results in sorted(self._job_chains.items()):
                job_ids = [r.job_id for r in results]
                print(f" Replicate {replicate}: {' -> '.join(job_ids)}")
            print("\nMonitor progress with: squeue -u $USER")
            print("Check job details with: scontrol show job <job_id>")
def submit_daisy_chain(
    config_path: Union[str, Path],
    slurm_preset: str = "aa100",
    replicates: str = "1",
    email: str = "",
    dry_run: bool = False,
    conda_env: str = "polymerist-env",
    output_dir: Optional[Union[str, Path]] = None,
    scratch_dir: Optional[Union[str, Path]] = None,
    projects_dir: Optional[Union[str, Path]] = None,
    time_limit: Optional[str] = None,
    memory: Optional[str] = None,
    openff_logs: bool = False,
    skip_build: bool = False,
) -> Dict[int, List[SubmissionResult]]:
    """Convenience function to submit daisy-chain jobs from a YAML config.

    Args:
        config_path: Path to simulation YAML config
        slurm_preset: SLURM preset name (aa100, al40, blanca-shirts, testing)
        replicates: Replicate range string (e.g., "1-5", "1,3,5")
        email: Email for job notifications
        dry_run: If True, don't submit jobs
        conda_env: Conda environment name
        output_dir: Directory for job scripts; when not given, the job scripts
            directory reported by the simulation config's output settings is used
        scratch_dir: Override scratch directory for simulation output
        projects_dir: Override projects directory for scripts/logs
        time_limit: Override SLURM time limit of the preset
        memory: Override SLURM memory allocation (e.g., "4G", "8G")
        openff_logs: Enable verbose OpenFF logs in generated scripts
        skip_build: Skip system building in generated scripts (use pre-built system)

    Returns:
        Dictionary mapping replicate numbers to submission results

    Example:
        >>> results = submit_daisy_chain(
        ...     config_path="simulation.yaml",
        ...     slurm_preset="aa100",
        ...     replicates="1-5",
        ...     email="user@example.com",
        ...     dry_run=True,
        ... )
    """
    # Load simulation config
    sim_config = SimulationConfig.from_yaml(config_path)

    # Apply directory overrides before anything derives paths from the config
    if scratch_dir:
        sim_config.output.scratch_directory = Path(scratch_dir)
    if projects_dir:
        sim_config.output.projects_directory = Path(projects_dir)

    # Determine output script directory
    if output_dir:
        script_output_dir = Path(output_dir)
    else:
        script_output_dir = sim_config.output.get_job_scripts_directory()

    # Create SLURM config from preset. The preset arrives as a plain string
    # (the CLI restricts it via argparse choices), hence the type-checker
    # suppression.
    slurm_config = SlurmConfig.from_preset(slurm_preset, email=email)  # type: ignore[arg-type]

    # Override preset resources if provided
    if time_limit:
        slurm_config.time_limit = time_limit
    if memory:
        slurm_config.memory = memory

    # Create daisy-chain config
    dc_config = DaisyChainConfig.from_simulation_config(
        sim_config=sim_config,
        slurm_config=slurm_config,
        replicates=replicates,
        dry_run=dry_run,
        output_script_dir=script_output_dir,
        config_path=str(config_path),
    )

    # Create submitter and submit
    submitter = DaisyChainSubmitter(
        sim_config, dc_config, conda_env=conda_env, openff_logs=openff_logs, skip_build=skip_build
    )
    return submitter.submit_all()
def main(argv: Optional[List[str]] = None) -> int:
    """Main entry point for daisy-chain submission CLI.

    Args:
        argv: Command-line arguments to parse, without the program name.
            When None (the default), argparse reads them from ``sys.argv``.

    Returns:
        Exit code (0 for success, 1 for failure).

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Submit daisy-chained MD simulation jobs to SLURM")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="Path to simulation YAML configuration file",
    )
    parser.add_argument(
        "-r",
        "--replicates",
        type=str,
        default="1",
        help="Replicate range (e.g., '1-5', '1,3,5'). Default: 1",
    )
    parser.add_argument(
        "--preset",
        type=str,
        choices=["aa100", "al40", "blanca-shirts", "testing"],
        default="aa100",
        help="SLURM partition preset. Default: aa100",
    )
    parser.add_argument(
        "--email",
        type=str,
        default="",
        help="Email for job notifications",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Generate scripts but don't submit them",
    )
    parser.add_argument(
        "--conda-env",
        type=str,
        default="polymerist-env",
        help="Conda environment name. Default: polymerist-env",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="daisy_chain_scripts",
        help="Output directory for job scripts. Default: daisy_chain_scripts",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args(argv)

    # Setup logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    try:
        submit_daisy_chain(
            config_path=args.config,
            slurm_preset=args.preset,
            replicates=args.replicates,
            email=args.email,
            dry_run=args.dry_run,
            conda_env=args.conda_env,
            output_dir=args.output_dir,
        )
        return 0
    except FileNotFoundError as e:
        LOGGER.error(f"Configuration file not found: {e}")
        return 1
    except ValueError as e:
        LOGGER.error(f"Invalid configuration: {e}")
        return 1
    except Exception as e:
        # Top-level boundary: report anything unexpected with its traceback
        # and turn it into a failing exit code instead of crashing.
        LOGGER.error(f"Error during submission: {e}")
        import traceback

        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Running the module as a script: hand main()'s return value to the
    # interpreter as the process exit status.
    raise SystemExit(main())