Source code for polyzymd.workflow.slurm

"""
SLURM job script generation for HPC cluster submission.

This module provides templates and utilities for generating SLURM
batch scripts for MD simulations.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)

# Preset types: the preset names that SlurmConfig.from_preset has entries for.
PresetType = Literal["aa100", "al40", "blanca-shirts", "testing"]


@dataclass
class SlurmConfig:
    """Configuration for SLURM job submission.

    Attributes:
        partition: SLURM partition(s) to use.
        qos: Quality of service.
        account: Account for resource allocation.
        time_limit: Wall time limit (HH:MM:SS).
        email: Email for notifications.
        nodes: Number of nodes.
        ntasks: Number of tasks.
        memory: Memory allocation (e.g., "3G").
        gpus: Number of GPUs.
        exclude: Nodes to exclude, or None to exclude nothing.
    """

    partition: str = "aa100"
    qos: str = "normal"
    account: str = "ucb625_asc1"
    time_limit: str = "23:59:59"
    email: str = ""
    nodes: int = 1
    ntasks: int = 1
    memory: str = "3G"
    gpus: int = 1
    exclude: Optional[str] = None

    @classmethod
    def from_preset(cls, preset: PresetType, email: str = "") -> "SlurmConfig":
        """Create a SlurmConfig from a preset.

        Fields that a preset does not mention (nodes, ntasks, memory, gpus,
        and usually exclude) keep the dataclass defaults.

        Args:
            preset: Preset name. An unrecognised name falls back to the
                "aa100" preset; a warning is logged so that typos do not go
                unnoticed.
            email: Email for notifications.

        Returns:
            SlurmConfig with preset values.
        """
        presets: Dict[PresetType, Dict] = {
            "aa100": {
                "partition": "aa100",
                "qos": "normal",
                "account": "ucb625_asc1",
                "time_limit": "23:59:59",
            },
            "al40": {
                "partition": "al40",
                "qos": "normal",
                "account": "ucb625_asc1",
                "time_limit": "23:59:59",
            },
            "blanca-shirts": {
                "partition": "blanca,blanca-shirts",
                "qos": "preemptable",
                "account": "blanca-shirts",
                "time_limit": "23:59:59",
                "exclude": "bgpu-bortz1",
            },
            "testing": {
                "partition": "atesting_a100",
                "qos": "testing",
                "account": "ucb625_asc1",
                "time_limit": "0:05:59",
            },
        }

        if preset not in presets:
            # Keep the historical fallback, but make it visible. Looking the
            # logger up by module name yields the same logger object as the
            # module-level LOGGER.
            logging.getLogger(__name__).warning(
                "Unknown SLURM preset %r; falling back to 'aa100' (valid presets: %s)",
                preset,
                ", ".join(presets),
            )
            preset = "aa100"

        return cls(email=email, **presets[preset])
@dataclass
class JobContext:
    """Per-job values used when rendering a job script template.

    Attributes:
        job_name: SLURM job name.
        output_file: Output file pattern (for SLURM logs).
        scratch_dir: Directory for simulation output (trajectories,
            checkpoints).
        projects_dir: Directory for scripts and logs.
        segment_index: Current segment index.
        replicate_num: Replicate number.
        extra_vars: Additional template variables.
    """

    job_name: str
    output_file: str
    # Where simulation data goes (trajectories, checkpoints)
    scratch_dir: str
    # Where scripts and logs live
    projects_dir: str = "."
    segment_index: int = 0
    replicate_num: int = 1
    # A fresh dict per instance, never shared between contexts
    extra_vars: Dict = field(default_factory=dict)

    @property
    def working_dir(self) -> str:
        """Legacy name for scratch_dir, kept for backwards compatibility."""
        return self.scratch_dir
class SlurmScriptGenerator:
    """Generator for SLURM batch scripts.

    Supports separate directories for:
    - projects_dir: Where scripts live and jobs are submitted from
    - scratch_dir: Where simulation output goes (trajectories, checkpoints)

    Example:
        >>> config = SlurmConfig.from_preset("aa100", email="user@example.com")
        >>> generator = SlurmScriptGenerator(config)
        >>> script = generator.generate_initial_job(
        ...     context=JobContext(
        ...         job_name="my_sim",
        ...         output_file="logs/output.log",
        ...         scratch_dir="/scratch/user/sim_output",
        ...         projects_dir="/projects/user/polyzymd",
        ...     ),
        ...     config_path="config.yaml",
        ...     replicate=1,
        ...     segment_time=10.0,
        ...     segment_frames=250,
        ... )
    """

    # Template for initial simulation jobs
    # - Job is submitted from projects_dir
    # - SLURM logs go to projects_dir/slurm_logs/
    # - Simulation output goes to scratch_dir
    INITIAL_JOB_TEMPLATE = """#!/bin/bash
#SBATCH --partition={partition}
#SBATCH --job-name=i_{job_name}
#SBATCH --output={output_file}
#SBATCH --qos={qos}
#SBATCH --nodes={nodes}
#SBATCH --ntasks={ntasks}
#SBATCH --mem={memory}
#SBATCH --time={time_limit}
#SBATCH --gres=gpu:{gpus}
#SBATCH --mail-type=FAIL
#SBATCH --mail-user={email}
#SBATCH --account={account}
{exclude_line}

# =============================================================================
# PolyzyMD Initial Simulation Job
# Segment: {segment_index}
# =============================================================================

# Load conda environment (ignore module warnings on some HPC systems)
module purge 2>/dev/null || true
module load miniforge 2>/dev/null || true

# Initialize conda/mamba for non-interactive shell
eval "$(conda shell.bash hook)"
mamba activate {conda_env}

# Enable strict error handling after environment setup
set -e

# Projects directory (scripts, configs, logs)
PROJECTS_DIR="{projects_dir}"

# Scratch directory (simulation output)
SCRATCH_DIR="{scratch_dir}"

# Ensure scratch directory exists
mkdir -p "$SCRATCH_DIR"

# Change to projects directory where config and scripts live
cd "$PROJECTS_DIR"

echo "Starting initial simulation segment {segment_index}"
echo "Projects dir: $PROJECTS_DIR"
echo "Scratch dir: $SCRATCH_DIR"
echo "Config: {config_path}"
echo "Replicate: {replicate}"
echo "Timestamp: $(date)"

# Run the initial simulation using polyzymd CLI
# This builds the system, runs equilibration, and runs the first production segment
polyzymd{openff_logs_flag} run -c "{config_path}" \\
    --replicate {replicate} \\
    --scratch-dir "$SCRATCH_DIR" \\
    --segment-time {segment_time} \\
    --segment-frames {segment_frames}{skip_build_flag}

echo "Segment {segment_index} completed successfully at $(date)"
"""

    # Template for continuation jobs
    CONTINUATION_JOB_TEMPLATE = """#!/bin/bash
#SBATCH --partition={partition}
#SBATCH --job-name=c_{job_name}
#SBATCH --output={output_file}
#SBATCH --qos={qos}
#SBATCH --nodes={nodes}
#SBATCH --ntasks={ntasks}
#SBATCH --mem={memory}
#SBATCH --time={time_limit}
#SBATCH --gres=gpu:{gpus}
#SBATCH --mail-type=FAIL
#SBATCH --mail-user={email}
#SBATCH --account={account}
{exclude_line}

# =============================================================================
# PolyzyMD Continuation Job
# Segment: {segment_index}
# =============================================================================

# Load conda environment (ignore module warnings on some HPC systems)
module purge 2>/dev/null || true
module load miniforge 2>/dev/null || true

# Initialize conda/mamba for non-interactive shell
eval "$(conda shell.bash hook)"
mamba activate {conda_env}

# Enable strict error handling after environment setup
set -e

# Projects directory (scripts, configs, logs)
PROJECTS_DIR="{projects_dir}"

# Scratch directory (simulation output - where previous segment data lives)
SCRATCH_DIR="{scratch_dir}"

# Change to projects directory
cd "$PROJECTS_DIR"

echo "Starting continuation segment {segment_index}"
echo "Projects dir: $PROJECTS_DIR"
echo "Scratch dir: $SCRATCH_DIR"
echo "Timestamp: $(date)"

# Continue simulation from previous segment using polyzymd CLI
# Reads checkpoint from previous segment in SCRATCH_DIR
# Writes new trajectory and checkpoint to SCRATCH_DIR
polyzymd{openff_logs_flag} continue \\
    -w "$SCRATCH_DIR" \\
    -s {segment_index} \\
    -t {segment_time} \\
    -n {num_samples}

echo "Segment {segment_index} completed successfully at $(date)"
"""

    def __init__(
        self,
        config: SlurmConfig,
        conda_env: str = "polymerist-env",
        openff_logs: bool = False,
        skip_build: bool = False,
    ) -> None:
        """Initialize the generator.

        Args:
            config: SLURM configuration.
            conda_env: Conda environment name.
            openff_logs: Enable verbose OpenFF logs in generated scripts.
            skip_build: Skip system building in generated scripts
                (use pre-built system).
        """
        self._config = config
        self._conda_env = conda_env
        self._openff_logs = openff_logs
        self._skip_build = skip_build

    @property
    def config(self) -> SlurmConfig:
        """Get the SLURM configuration."""
        return self._config

    def _shared_template_fields(self, context: JobContext) -> Dict[str, object]:
        """Collect the placeholder values common to both job templates."""
        cfg = self._config
        return {
            "partition": cfg.partition,
            "job_name": context.job_name,
            "output_file": context.output_file,
            "qos": cfg.qos,
            "nodes": cfg.nodes,
            "ntasks": cfg.ntasks,
            "memory": cfg.memory,
            "time_limit": cfg.time_limit,
            "gpus": cfg.gpus,
            "email": cfg.email,
            "account": cfg.account,
            # Rendered as an empty line when no nodes are excluded.
            "exclude_line": f"#SBATCH --exclude={cfg.exclude}" if cfg.exclude else "",
            "conda_env": self._conda_env,
            "projects_dir": context.projects_dir,
            "scratch_dir": context.scratch_dir,
            "segment_index": context.segment_index,
            # Inserted directly after the "polyzymd" executable name.
            "openff_logs_flag": " --openff-logs" if self._openff_logs else "",
        }

    def generate_initial_job(
        self,
        context: JobContext,
        config_path: str,
        replicate: int,
        segment_time: float,
        segment_frames: int,
    ) -> str:
        """Generate an initial simulation job script.

        Args:
            context: Job context information.
            config_path: Path to the YAML configuration file.
            replicate: Replicate number.
            segment_time: Duration of first segment in nanoseconds.
            segment_frames: Number of frames to save in first segment.

        Returns:
            SLURM batch script content.
        """
        # Only the initial job takes this flag - continuation jobs don't build.
        # It is appended as one more backslash-continued line of the command.
        skip_build_flag = " \\\n --skip-build" if self._skip_build else ""

        return self.INITIAL_JOB_TEMPLATE.format(
            config_path=config_path,
            replicate=replicate,
            segment_time=segment_time,
            segment_frames=segment_frames,
            skip_build_flag=skip_build_flag,
            **self._shared_template_fields(context),
        )

    def generate_continuation_job(
        self,
        context: JobContext,
        segment_time: float,
        num_samples: int,
    ) -> str:
        """Generate a continuation job script.

        Args:
            context: Job context information.
            segment_time: Duration of this segment in nanoseconds.
            num_samples: Number of frames to save.

        Returns:
            SLURM batch script content.
        """
        return self.CONTINUATION_JOB_TEMPLATE.format(
            segment_time=segment_time,
            num_samples=num_samples,
            **self._shared_template_fields(context),
        )

    def save_script(
        self,
        script_content: str,
        output_path: Union[str, Path],
        make_executable: bool = True,
    ) -> Path:
        """Save a script to a file.

        Missing parent directories are created. The file is written as UTF-8
        with LF line endings regardless of platform, since the content is a
        bash script.

        Args:
            script_content: Script content.
            output_path: Output file path.
            make_executable: Whether to make the script executable (mode 0o755).

        Returns:
            Path to the saved script.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # newline="\n" disables platform newline translation on write.
        with open(output_path, "w", encoding="utf-8", newline="\n") as handle:
            handle.write(script_content)

        if make_executable:
            output_path.chmod(0o755)

        LOGGER.info("Saved script to %s", output_path)
        return output_path
def parse_replicate_range(replicate_range: str) -> List[int]:
    """Parse a SLURM array range into a list of replicate numbers.

    The input is a comma-separated list of components, each either a single
    number ("3"), an inclusive range ("1-5"), or an inclusive range with a
    step ("1-10:2"). Whitespace around components is ignored. A range whose
    start exceeds its end contributes no values.

    Args:
        replicate_range: SLURM array format (e.g., "1-5", "1,3,5", "1-10:2").

    Returns:
        Sorted list of unique replicate numbers.

    Raises:
        ValueError: If a component is empty, non-numeric, otherwise
            malformed, or has a step smaller than 1.

    Example:
        >>> parse_replicate_range("1-5")
        [1, 2, 3, 4, 5]
        >>> parse_replicate_range("1,3,5")
        [1, 3, 5]
        >>> parse_replicate_range("1-10:2")
        [1, 3, 5, 7, 9]
    """
    replicates = set()

    for raw_part in replicate_range.split(","):
        part = raw_part.strip()
        try:
            if "-" not in part:
                replicates.add(int(part))
                continue

            if ":" in part:
                range_part, step_text = part.split(":")
                step = int(step_text)
            else:
                range_part, step = part, 1

            if step < 1:
                raise ValueError("step must be a positive integer")

            start, end = map(int, range_part.split("-"))
        except ValueError as err:
            # int() and tuple-unpacking failures carry no hint about which
            # component was bad, so re-raise with the offending text.
            raise ValueError(
                f"Invalid replicate range component {part!r} in {replicate_range!r}"
            ) from err

        # SLURM ranges are inclusive of the end value.
        replicates.update(range(start, end + 1, step))

    return sorted(replicates)
def validate_replicate_range(replicate_range: str) -> bool:
    """Validate that a replicate range is in proper SLURM array format.

    Accepted input is a comma-separated list of components, each of the form
    "N", "N-M" or "N-M:S" where N, M and S are runs of digits. No whitespace
    is allowed, and the whole string must match (a trailing newline is
    rejected).

    Args:
        replicate_range: Range string to validate.

    Returns:
        True if valid.

    Raises:
        ValueError: If the format is invalid.
    """
    import re

    component = r"\d+(-\d+(:\d+)?)?"
    pattern = rf"{component}(,{component})*"

    # fullmatch rather than match with "$": "$" also matches just before a
    # trailing newline, which would let "1-5\n" through.
    if re.fullmatch(pattern, replicate_range) is None:
        raise ValueError(f"Invalid replicate range format: {replicate_range}")

    return True