Source code for polyzymd.core.restraints

"""
Restraint definitions and application for OpenMM systems.

This module provides classes for defining and applying various types
of restraints (flat-bottom, harmonic, etc.) to OpenMM simulations.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from openmm import CustomBondForce, HarmonicBondForce, System
from openmm.app import Topology as OpenMMTopology
from openmm.unit import Quantity, angstrom, kilojoule_per_mole, nanometer

logger = logging.getLogger(__name__)


class RestraintType(str, Enum):
    """Enumerate the kinds of distance restraint this module can build.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value and can be looked up from it, e.g.
    ``RestraintType("harmonic") is RestraintType.HARMONIC``.
    """

    # Penalises distances beyond the threshold, no penalty below it.
    FLAT_BOTTOM = "flat_bottom"

    # Plain harmonic spring around the target distance.
    HARMONIC = "harmonic"

    # Keeps the distance from exceeding the threshold.
    UPPER_WALL = "upper_wall"

    # Keeps the distance from dropping below the threshold.
    LOWER_WALL = "lower_wall"
@dataclass
class AtomSelection:
    """A group of atoms described by an MDAnalysis-style selection string.

    The selection text is stored as-is and only turned into atom indices
    when :meth:`resolve` is called with a concrete topology.

    Attributes:
        selection: MDAnalysis-compatible selection string
        description: Optional human-readable note on what is selected

    Example:
        >>> sel = AtomSelection("resid 77 and name OG", "Catalytic serine oxygen")
        >>> indices = sel.resolve(topology)
    """

    selection: str
    description: Optional[str] = None

    def resolve(self, topology: OpenMMTopology) -> List[int]:
        """Translate the stored selection string into atom indices.

        The work is delegated to the module-level selection parser, which
        understands a basic subset of the MDAnalysis selection syntax.

        Args:
            topology: OpenMM Topology object to search

        Returns:
            List of atom indices matching the selection

        Raises:
            ValueError: If selection syntax is invalid or no atoms match
        """
        matched_indices = _parse_selection(self.selection, topology)
        return matched_indices
def _parse_selection(selection: str, topology: OpenMMTopology) -> List[int]:
    """Parse an MDAnalysis-style selection string for an OpenMM topology.

    Supports a subset of MDAnalysis selection syntax:
        - resid N: Select atoms in residue number N (1-indexed, like PDB)
        - resname XXX: Select atoms in residue with name XXX
        - name XXX: Select atoms with name XXX
        - index N: Select atom with OpenMM index N (0-indexed)
        - pdbindex N: Select atom with PDB serial number N (1-indexed,
          converted to ``N - 1`` internally)
        - chain X / chainid X: Select atoms in the chain with id X
        - and: Intersection of selections (binds tighter than ``or``)
        - or: Union of selections
        - ( ... ): Grouping

    Matching is case-insensitive: the selection string and the atom,
    residue and chain names are all lower-cased before comparison.

    Index conventions:
        - `resid` is compared against the residue's position in the
          topology plus one (``residue.index + 1``)
        - `index` uses 0-indexed atom indices (OpenMM internal indexing)
        - `pdbindex` uses 1-indexed atom serials

    Example:
        If PyMOL shows atom serial 2740 and residue 77:
        - Use "pdbindex 2740" or "index 2739" to select that atom
        - Use "resid 77" to select all atoms in that residue

    Args:
        selection: Selection string
        topology: OpenMM Topology

    Returns:
        Sorted list of matching atom indices

    Raises:
        ValueError: If the selection is empty, is syntactically invalid
            (unknown keyword, missing or non-integer value, unbalanced
            parentheses, dangling operator, or unparsed trailing tokens),
            or matches no atoms.
    """
    # Pad parentheses with spaces so they always become their own tokens.
    tokens = selection.lower().replace("(", " ( ").replace(")", " ) ").split()
    if not tokens:
        raise ValueError("Empty selection string")

    # Snapshot the per-atom properties once, already lower-cased for matching.
    atoms_data = []
    for atom in topology.atoms():
        residue = atom.residue
        chain = residue.chain if residue is not None else None
        atoms_data.append(
            {
                "index": atom.index,
                "name": atom.name.lower() if atom.name else "",
                "resname": residue.name.lower() if residue is not None else "",
                # 0-indexed internally; None never equals an integer target.
                "resid": residue.index if residue is not None else None,
                "chain": chain.id.lower() if chain is not None and chain.id else "",
            }
        )

    def select(field_name: str, wanted: Any) -> set:
        """Return indices of all atoms whose *field_name* equals *wanted*."""
        return {atom["index"] for atom in atoms_data if atom[field_name] == wanted}

    def parse_int(keyword: str, value: str) -> int:
        """Convert *value* to int, reporting which keyword it belonged to."""
        try:
            return int(value)
        except ValueError:
            raise ValueError(
                f"Keyword '{keyword}' requires an integer value, got '{value}'"
            ) from None

    def evaluate_simple(start: int) -> Tuple[set, int]:
        """Evaluate one ``keyword value`` pair or a parenthesised group."""
        if start >= len(tokens):
            # Reached after a dangling "and"/"or" or an opening "(".
            raise ValueError(f"Unexpected end of selection: '{selection}'")

        keyword = tokens[start]

        if keyword == "(":
            result, end = evaluate_or(start + 1)
            if end >= len(tokens) or tokens[end] != ")":
                raise ValueError(f"Unbalanced parentheses in selection: '{selection}'")
            return result, end + 1

        if start + 1 >= len(tokens):
            raise ValueError(f"Missing value after keyword '{keyword}'")

        value = tokens[start + 1]

        if keyword == "resid":
            # resid is 1-indexed in the selection, 0-indexed internally.
            matching = select("resid", parse_int(keyword, value) - 1)
        elif keyword == "resname":
            matching = select("resname", value)
        elif keyword == "name":
            matching = select("name", value)
        elif keyword in ("index", "pdbindex"):
            target_idx = parse_int(keyword, value)
            if keyword == "pdbindex":
                # pdbindex is a 1-indexed serial; shift to 0-indexed.
                target_idx -= 1
            # Out-of-range indices simply match nothing.
            matching = {target_idx} if 0 <= target_idx < len(atoms_data) else set()
        elif keyword in ("chainid", "chain"):
            matching = select("chain", value)
        else:
            raise ValueError(f"Unknown selection keyword: '{keyword}'")

        return matching, start + 2

    def evaluate_and(start: int) -> Tuple[set, int]:
        """Evaluate a chain of terms joined by ``and`` (set intersection)."""
        result, pos = evaluate_simple(start)
        while pos < len(tokens) and tokens[pos] == "and":
            right, pos = evaluate_simple(pos + 1)
            result = result & right
        return result, pos

    def evaluate_or(start: int) -> Tuple[set, int]:
        """Evaluate a chain of AND-expressions joined by ``or`` (set union)."""
        result, pos = evaluate_and(start)
        while pos < len(tokens) and tokens[pos] == "or":
            right, pos = evaluate_and(pos + 1)
            result = result | right
        return result, pos

    result, end = evaluate_or(0)

    # Anything left over (e.g. a forgotten "and", a stray ")") means the
    # selection was only partially understood; refuse rather than guess.
    if end != len(tokens):
        raise ValueError(f"Unexpected token '{tokens[end]}' in selection: '{selection}'")

    if not result:
        raise ValueError(f"No atoms match selection: '{selection}'")

    return sorted(result)
@dataclass
class RestraintDefinition:
    """Definition of a single two-atom distance restraint.

    Attributes:
        restraint_type: Type of restraint (flat_bottom, harmonic, etc.)
        name: Human-readable identifier
        atom1: First atom selection (must resolve to exactly one atom)
        atom2: Second atom selection (must resolve to exactly one atom)
        distance: Target or threshold distance
        force_constant: Force constant for the restraint
        enabled: Whether this restraint should be applied

    Example:
        >>> restraint = RestraintDefinition(
        ...     restraint_type=RestraintType.FLAT_BOTTOM,
        ...     name="catalytic_serine",
        ...     atom1=AtomSelection("resid 77 and name OG"),
        ...     atom2=AtomSelection("resname LIG and name C12"),
        ...     distance=3.3 * angstrom,
        ...     force_constant=10000 * kilojoule_per_mole / nanometer**2
        ... )
    """

    restraint_type: RestraintType
    name: str
    atom1: AtomSelection
    atom2: AtomSelection
    distance: Quantity = field(default_factory=lambda: 3.3 * angstrom)
    force_constant: Quantity = field(
        default_factory=lambda: 10000 * kilojoule_per_mole / nanometer**2
    )
    enabled: bool = True

    def apply(self, topology: OpenMMTopology, system: System) -> Optional[int]:
        """Apply this restraint to an OpenMM system.

        Args:
            topology: OpenMM Topology for resolving atom selections
            system: OpenMM System to add the force to

        Returns:
            Index of the added force, or None if restraint is disabled

        Raises:
            ValueError: If either selection does not resolve to exactly one
                atom, or if the restraint type is not recognised.
        """
        if not self.enabled:
            logger.info(f"Skipping disabled restraint: {self.name}")
            return None

        # Resolve atom selections
        atom1_indices = self.atom1.resolve(topology)
        atom2_indices = self.atom2.resolve(topology)

        if len(atom1_indices) != 1 or len(atom2_indices) != 1:
            raise ValueError(
                f"Restraint '{self.name}' requires exactly one atom per selection. "
                f"Got {len(atom1_indices)} for atom1, {len(atom2_indices)} for atom2"
            )

        atom1_idx = atom1_indices[0]
        atom2_idx = atom2_indices[0]

        # Create the appropriate force. Equality (not a dict lookup) is used
        # on purpose so that a plain string such as "harmonic" still matches.
        if self.restraint_type == RestraintType.FLAT_BOTTOM:
            force = self._create_flat_bottom_force(atom1_idx, atom2_idx)
        elif self.restraint_type == RestraintType.HARMONIC:
            force = self._create_harmonic_force(atom1_idx, atom2_idx)
        elif self.restraint_type == RestraintType.UPPER_WALL:
            force = self._create_upper_wall_force(atom1_idx, atom2_idx)
        elif self.restraint_type == RestraintType.LOWER_WALL:
            force = self._create_lower_wall_force(atom1_idx, atom2_idx)
        else:
            raise ValueError(f"Unknown restraint type: {self.restraint_type}")

        force_idx = system.addForce(force)
        logger.info(
            f"Applied {self.restraint_type.value} restraint '{self.name}' "
            f"between atoms {atom1_idx} and {atom2_idx} "
            f"(r0={self.distance}, k={self.force_constant})"
        )

        return force_idx

    @staticmethod
    def _value_in(value: Any, unit: Any) -> float:
        """Return *value* as a bare number expressed in *unit*.

        Unit-less numbers are passed through unchanged.
        NOTE(review): this assumes a bare number is already in nm or
        kJ/mol/nm^2 -- confirm no caller passes bare Angstrom values.
        """
        if isinstance(value, Quantity):
            return value.value_in_unit(unit)
        return float(value)

    def _k_and_r0(self) -> tuple:
        """Return ``(k, r0)`` as floats in kJ/mol/nm^2 and nm."""
        k_value = self._value_in(self.force_constant, kilojoule_per_mole / nanometer**2)
        r0_nm = self._value_in(self.distance, nanometer)
        return k_value, r0_nm

    def _create_custom_bond_force(
        self, expression: str, atom1_idx: int, atom2_idx: int
    ) -> CustomBondForce:
        """Build a one-bond CustomBondForce for *expression*.

        ``k`` and ``r0`` are registered as per-bond parameters rather than
        global ones: global parameters are shared by name across all forces
        in a Context, so several restraints with different distances or
        force constants would otherwise collide.
        """
        k_value, r0_nm = self._k_and_r0()
        force = CustomBondForce(expression)
        # Order here fixes the order of the value list passed to addBond.
        force.addPerBondParameter("k")
        force.addPerBondParameter("r0")
        force.addBond(atom1_idx, atom2_idx, [k_value, r0_nm])
        return force

    def _create_flat_bottom_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create a flat-bottom potential force.

        U(r) = 0                        if r < r0
               0.5 * k * (r - r0)^2     if r >= r0
        """
        expression = "step(r - r0) * 0.5 * k * (r - r0)^2"
        return self._create_custom_bond_force(expression, atom1_idx, atom2_idx)

    def _create_harmonic_force(self, atom1_idx: int, atom2_idx: int) -> HarmonicBondForce:
        """Create a harmonic bond force.

        U(r) = 0.5 * k * (r - r0)^2
        """
        k_value, r0_nm = self._k_and_r0()
        force = HarmonicBondForce()
        # HarmonicBondForce.addBond takes (p1, p2, length, k) in nm and
        # kJ/mol/nm^2.
        force.addBond(atom1_idx, atom2_idx, r0_nm, k_value)
        return force

    def _create_upper_wall_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create an upper wall potential (prevent distance exceeding r0).

        U(r) = 0                        if r < r0
               0.5 * k * (r - r0)^2     if r >= r0

        (Same functional form as the flat-bottom restraint.)
        """
        return self._create_flat_bottom_force(atom1_idx, atom2_idx)

    def _create_lower_wall_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create a lower wall potential (prevent distance below r0).

        U(r) = 0.5 * k * (r0 - r)^2     if r < r0
               0                        if r >= r0
        """
        expression = "step(r0 - r) * 0.5 * k * (r0 - r)^2"
        return self._create_custom_bond_force(expression, atom1_idx, atom2_idx)
class RestraintFactory:
    """Factory for creating restraints from configuration.

    This class bridges the configuration schema with the restraint
    implementation, creating RestraintDefinition objects from plain values.
    Bare numeric distances are interpreted as Angstroms and bare numeric
    force constants as kJ/mol/nm^2.
    """

    @staticmethod
    def _build(
        restraint_type: RestraintType,
        name: str,
        atom1: AtomSelection,
        atom2: AtomSelection,
        distance: float,
        force_constant: float,
        enabled: bool = True,
    ) -> RestraintDefinition:
        """Attach units to the raw numbers and assemble the definition.

        Single place where the unit conventions of this factory are applied,
        so every public constructor stays consistent.
        """
        return RestraintDefinition(
            restraint_type=restraint_type,
            name=name,
            atom1=atom1,
            atom2=atom2,
            distance=distance * angstrom,
            force_constant=force_constant * kilojoule_per_mole / nanometer**2,
            enabled=enabled,
        )

    @staticmethod
    def from_config(config: Dict[str, Any]) -> RestraintDefinition:
        """Create a RestraintDefinition from a configuration dictionary.

        Recognised keys (all optional): ``type`` (default "flat_bottom"),
        ``name`` (default "unnamed_restraint"), ``atom1`` / ``atom2``
        (mappings with ``selection`` and optional ``description``),
        ``distance`` (default 3.3), ``force_constant`` (default 10000.0)
        and ``enabled`` (default True).

        Args:
            config: Dictionary with restraint configuration

        Returns:
            RestraintDefinition instance

        Raises:
            ValueError: If ``type`` is not a valid RestraintType value
        """
        # Parse restraint type
        type_str = config.get("type", "flat_bottom")
        try:
            restraint_type = RestraintType(type_str)
        except ValueError as err:
            # Chain explicitly so the original lookup failure stays visible.
            raise ValueError(f"Unknown restraint type: {type_str}") from err

        # Parse atom selections; a missing selection becomes "" and is only
        # rejected later, when the selection is resolved.
        atom1_config = config.get("atom1", {})
        atom2_config = config.get("atom2", {})
        atom1 = AtomSelection(
            selection=atom1_config.get("selection", ""),
            description=atom1_config.get("description"),
        )
        atom2 = AtomSelection(
            selection=atom2_config.get("selection", ""),
            description=atom2_config.get("description"),
        )

        return RestraintFactory._build(
            restraint_type=restraint_type,
            name=config.get("name", "unnamed_restraint"),
            atom1=atom1,
            atom2=atom2,
            distance=config.get("distance", 3.3),
            force_constant=config.get("force_constant", 10000.0),
            enabled=config.get("enabled", True),
        )

    @staticmethod
    def create_flat_bottom(
        name: str,
        atom1_selection: str,
        atom2_selection: str,
        distance: float,
        force_constant: float = 10000.0,
    ) -> RestraintDefinition:
        """Convenience method to create a flat-bottom restraint.

        Args:
            name: Restraint identifier
            atom1_selection: Selection string for first atom
            atom2_selection: Selection string for second atom
            distance: Threshold distance in Angstroms
            force_constant: Force constant in kJ/mol/nm^2

        Returns:
            RestraintDefinition for flat-bottom potential
        """
        return RestraintFactory._build(
            restraint_type=RestraintType.FLAT_BOTTOM,
            name=name,
            atom1=AtomSelection(atom1_selection),
            atom2=AtomSelection(atom2_selection),
            distance=distance,
            force_constant=force_constant,
        )

    @staticmethod
    def create_harmonic(
        name: str,
        atom1_selection: str,
        atom2_selection: str,
        distance: float,
        force_constant: float = 10000.0,
    ) -> RestraintDefinition:
        """Convenience method to create a harmonic restraint.

        Args:
            name: Restraint identifier
            atom1_selection: Selection string for first atom
            atom2_selection: Selection string for second atom
            distance: Equilibrium distance in Angstroms
            force_constant: Force constant in kJ/mol/nm^2

        Returns:
            RestraintDefinition for harmonic potential
        """
        return RestraintFactory._build(
            restraint_type=RestraintType.HARMONIC,
            name=name,
            atom1=AtomSelection(atom1_selection),
            atom2=AtomSelection(atom2_selection),
            distance=distance,
            force_constant=force_constant,
        )
def apply_restraints(
    restraints: List[RestraintDefinition], topology: OpenMMTopology, system: System
) -> List[int]:
    """Apply a batch of restraints to one system, in the order given.

    Each restraint's ``apply`` is called with the same topology and system.
    Restraints whose ``apply`` returns ``None`` contribute nothing to the
    result.

    Args:
        restraints: List of restraint definitions
        topology: OpenMM Topology for resolving selections
        system: OpenMM System to modify

    Returns:
        List of force indices for the added restraints
    """
    # Lazy generator: restraints are still applied one by one, in order,
    # as the list comprehension below consumes it.
    outcomes = (definition.apply(topology, system) for definition in restraints)
    # Compare against None explicitly -- force index 0 is a valid result.
    return [force_idx for force_idx in outcomes if force_idx is not None]