"""
Restraint definitions and application for OpenMM systems.
This module provides classes for defining and applying various types
of restraints (flat-bottom, harmonic, etc.) to OpenMM simulations.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
if TYPE_CHECKING:
from openmm import CustomBondForce, HarmonicBondForce, System
from openmm.app import Topology as OpenMMTopology
from openmm.unit import Quantity
logger = logging.getLogger(__name__)
def _distance_in_angstroms(value: float) -> Quantity:
    """Wrap a bare number as an OpenMM length expressed in Angstroms.

    Parameters
    ----------
    value : float
        Magnitude of the distance, in Angstroms.

    Returns
    -------
    Quantity
        The same magnitude with Angstrom units attached.
    """
    # Imported lazily so the module can be loaded without OpenMM installed.
    from openmm import unit

    return value * unit.angstrom
def _force_constant_in_kj_per_mol_nm2(value: float) -> Quantity:
    """Wrap a bare number as an OpenMM force constant.

    Parameters
    ----------
    value : float
        Magnitude of the force constant, in kJ/mol/nm^2.

    Returns
    -------
    Quantity
        The same magnitude with kJ/mol/nm^2 units attached.
    """
    # Imported lazily so the module can be loaded without OpenMM installed.
    from openmm import unit

    # Attach the energy unit first, then divide by the squared length unit.
    energy = value * unit.kilojoule_per_mole
    return energy / unit.nanometer**2
def _quantity_in_nanometers(quantity: Quantity) -> float:
    """Strip the units from an OpenMM length, expressing it in nanometers.

    Parameters
    ----------
    quantity : Quantity
        Length quantity to convert.

    Returns
    -------
    float
        Magnitude of ``quantity`` in nanometers.
    """
    # Imported lazily so the module can be loaded without OpenMM installed.
    from openmm import unit

    magnitude_nm = quantity.value_in_unit(unit.nanometer)
    return magnitude_nm
def _force_constant_value_in_openmm_units(quantity: Quantity) -> float:
    """Strip the units from a force constant, expressing it in kJ/mol/nm^2.

    Parameters
    ----------
    quantity : Quantity
        Force constant quantity to convert.

    Returns
    -------
    float
        Magnitude of ``quantity`` in kJ/mol/nm^2.
    """
    # Imported lazily so the module can be loaded without OpenMM installed.
    from openmm import unit

    target_unit = unit.kilojoule_per_mole / unit.nanometer**2
    return quantity.value_in_unit(target_unit)
[docs]
class RestraintType(str, Enum):
    """Closed set of restraint kinds understood by this module.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (e.g. ``RestraintType.HARMONIC == "harmonic"``).
    """

    # Penalty only beyond the threshold distance.
    FLAT_BOTTOM = "flat_bottom"
    # Penalty on any deviation from the target distance.
    HARMONIC = "harmonic"
    # Keeps the distance from exceeding the threshold.
    UPPER_WALL = "upper_wall"
    # Keeps the distance from dropping below the threshold.
    LOWER_WALL = "lower_wall"
[docs]
@dataclass
class AtomSelection:
    """An atom selection written in MDAnalysis-style syntax.

    The selection is stored as text and only turned into atom indices
    when :meth:`resolve` is called against a concrete topology.

    Attributes:
        selection: MDAnalysis-compatible selection string
        description: Optional human-readable note on what is selected

    Example:
        >>> sel = AtomSelection("resid 77 and name OG", "Catalytic serine oxygen")
        >>> indices = sel.resolve(topology)
    """

    selection: str
    description: Optional[str] = None

    def resolve(self, topology: OpenMMTopology) -> List[int]:
        """Look up the atom indices this selection refers to.

        The work is delegated to the module-level ``_parse_selection``
        parser, which handles a subset of the MDAnalysis selection language.

        Args:
            topology: OpenMM Topology object to search

        Returns:
            List of atom indices matching the selection

        Raises:
            ValueError: If selection syntax is invalid or no atoms match
        """
        matched_indices = _parse_selection(self.selection, topology)
        return matched_indices
def _parse_selection(selection: str, topology: OpenMMTopology) -> List[int]:
    """Parse an MDAnalysis-style selection string for OpenMM topology.

    Supports a subset of MDAnalysis selection syntax:

    - resid N: Select atoms in residue with ID N (1-indexed, like PDB)
    - resname XXX: Select atoms in residue with name XXX
    - name XXX: Select atoms with name XXX
    - index N: Select atom with OpenMM index N (0-indexed)
    - pdbindex N: Select atom with PDB serial number N (1-indexed, auto-converts)
    - chain X / chainid X: Select atoms in the chain with ID X
    - and: Intersection of selections
    - or: Union of selections
    - ( ... ): Grouping

    ``and`` binds tighter than ``or``. Keywords and values are matched
    case-insensitively.

    Index conventions:

    - `resid` uses 1-indexed residue numbers (matches PDB/PyMOL display)
    - `index` uses 0-indexed atom indices (matches OpenMM internal indexing)
    - `pdbindex` uses 1-indexed atom serial (matches PDB ATOM column, PyMOL display)

    Example: If PyMOL shows atom serial 2740 and residue 77:

    - Use "pdbindex 2740" or "index 2739" to select that atom
    - Use "resid 77" to select all atoms in that residue

    Args:
        selection: Selection string
        topology: OpenMM Topology

    Returns:
        Sorted list of matching atom indices

    Raises:
        ValueError: If the selection is empty, malformed (unknown keyword,
            missing value, non-integer where an integer is required,
            unbalanced parentheses, dangling operator, leftover tokens),
            or matches no atoms.
    """
    # Lower-casing the whole string makes keywords and values case-insensitive.
    tokens = selection.lower().replace("(", " ( ").replace(")", " ) ").split()
    if not tokens:
        raise ValueError("Empty selection string")

    def residue_number(residue) -> Optional[int]:
        """Return the residue id as an int, or None when it is not numeric."""
        if residue is None:
            return None
        try:
            return int(residue.id)
        except (TypeError, ValueError):
            # A non-numeric id must not break selections that never use resid;
            # such a residue simply cannot be matched by "resid N".
            return None

    # Snapshot the per-atom properties once so each keyword is a cheap scan.
    atoms_data = []
    for atom in topology.atoms():
        residue = atom.residue
        chain = residue.chain if residue else None
        atoms_data.append(
            {
                "index": atom.index,
                "name": atom.name.lower() if atom.name else "",
                "resname": residue.name.lower() if residue and residue.name else "",
                "resid": residue_number(residue),
                "chain": (chain.id or "").lower() if chain else "",
            }
        )

    # Keywords that compare a (lower-cased) text value against one atom field.
    text_fields = {
        "resname": "resname",
        "name": "name",
        "chain": "chain",
        "chainid": "chain",
    }

    def as_int(keyword: str, value: str) -> int:
        """Convert a token to int with an error message naming the keyword."""
        try:
            return int(value)
        except ValueError:
            raise ValueError(
                f"Keyword '{keyword}' requires an integer value, got '{value}'"
            ) from None

    def evaluate_simple(start: int) -> Tuple[set, int]:
        """Evaluate one ``keyword value`` term or a parenthesised group."""
        if start >= len(tokens):
            # Reached by a dangling operator such as "resid 77 and".
            raise ValueError(f"Incomplete selection: '{selection}'")

        keyword = tokens[start]
        if keyword == "(":
            group, end = evaluate_or(start + 1)
            if end >= len(tokens) or tokens[end] != ")":
                raise ValueError(f"Unbalanced parentheses in selection: '{selection}'")
            return group, end + 1
        if keyword in (")", "and", "or"):
            raise ValueError(f"Unexpected token '{keyword}' in selection: '{selection}'")

        if start + 1 >= len(tokens):
            raise ValueError(f"Missing value after keyword '{keyword}'")
        value = tokens[start + 1]

        if keyword in text_fields:
            field_name = text_fields[keyword]
            matching = {a["index"] for a in atoms_data if a[field_name] == value}
        elif keyword == "resid":
            # resid matches PDB residue number (resSeq) directly
            target_resid = as_int(keyword, value)
            matching = {a["index"] for a in atoms_data if a["resid"] == target_resid}
        elif keyword in ("index", "pdbindex"):
            # index is 0-based (OpenMM); pdbindex is the 1-based PDB serial.
            target_idx = as_int(keyword, value)
            if keyword == "pdbindex":
                target_idx -= 1
            matching = {target_idx} if 0 <= target_idx < len(atoms_data) else set()
        else:
            raise ValueError(f"Unknown selection keyword: '{keyword}'")
        return matching, start + 2

    def evaluate_and(start: int) -> Tuple[set, int]:
        """Evaluate a chain of terms joined by ``and`` (intersection)."""
        result, pos = evaluate_simple(start)
        while pos < len(tokens) and tokens[pos] == "and":
            right, pos = evaluate_simple(pos + 1)
            result &= right
        return result, pos

    def evaluate_or(start: int) -> Tuple[set, int]:
        """Evaluate a chain of and-expressions joined by ``or`` (union)."""
        result, pos = evaluate_and(start)
        while pos < len(tokens) and tokens[pos] == "or":
            right, pos = evaluate_and(pos + 1)
            result |= right
        return result, pos

    result, consumed = evaluate_or(0)
    if consumed != len(tokens):
        # Leftover tokens mean a missing operator or a stray ")"; silently
        # ignoring them would select a different atom set than was written.
        raise ValueError(
            f"Unexpected token '{tokens[consumed]}' in selection: '{selection}'"
        )
    if not result:
        raise ValueError(f"No atoms match selection: '{selection}'")
    return sorted(result)
[docs]
@dataclass
class RestraintDefinition:
    """Definition of a single restraint to be applied to a system.

    Attributes:
        restraint_type: Type of restraint (flat_bottom, harmonic, etc.)
        name: Human-readable identifier
        atom1: First atom selection
        atom2: Second atom selection
        distance: Target or threshold distance
        force_constant: Force constant for the restraint
        enabled: Whether this restraint should be applied

    Example:
        >>> restraint = RestraintDefinition(
        ...     restraint_type=RestraintType.FLAT_BOTTOM,
        ...     name="catalytic_serine",
        ...     atom1=AtomSelection("resid 77 and name OG"),
        ...     atom2=AtomSelection("resname LIG and name C12"),
        ...     distance=3.3 * angstrom,
        ...     force_constant=10000 * kilojoule_per_mole / nanometer**2
        ... )
    """

    restraint_type: RestraintType
    name: str
    atom1: AtomSelection
    atom2: AtomSelection
    # Defaults are built lazily so OpenMM is only imported when they are needed.
    distance: Quantity = field(default_factory=lambda: _distance_in_angstroms(3.3))
    force_constant: Quantity = field(
        default_factory=lambda: _force_constant_in_kj_per_mol_nm2(10000)
    )
    enabled: bool = True

    def apply(self, topology: OpenMMTopology, system: System) -> Optional[int]:
        """Apply this restraint to an OpenMM system.

        Args:
            topology: OpenMM Topology for resolving atom selections
            system: OpenMM System to add the force to

        Returns:
            Index of the added force, or None if restraint is disabled

        Raises:
            ValueError: If either selection does not resolve to exactly one
                atom, or if the restraint type is not recognised.
        """
        if not self.enabled:
            logger.info("Skipping disabled restraint: %s", self.name)
            return None

        # Resolve atom selections; a distance restraint needs one atom per end.
        atom1_indices = self.atom1.resolve(topology)
        atom2_indices = self.atom2.resolve(topology)
        if len(atom1_indices) != 1 or len(atom2_indices) != 1:
            raise ValueError(
                f"Restraint '{self.name}' requires exactly one atom per selection. "
                f"Got {len(atom1_indices)} for atom1, {len(atom2_indices)} for atom2"
            )
        atom1_idx = atom1_indices[0]
        atom2_idx = atom2_indices[0]

        # Normalise first so a plain string such as "harmonic" behaves exactly
        # like the enum member (dispatch below and ``.value`` in the log line).
        try:
            kind = RestraintType(self.restraint_type)
        except ValueError:
            raise ValueError(f"Unknown restraint type: {self.restraint_type}") from None

        builders = {
            RestraintType.FLAT_BOTTOM: self._create_flat_bottom_force,
            RestraintType.HARMONIC: self._create_harmonic_force,
            RestraintType.UPPER_WALL: self._create_upper_wall_force,
            RestraintType.LOWER_WALL: self._create_lower_wall_force,
        }
        force = builders[kind](atom1_idx, atom2_idx)
        force_idx = system.addForce(force)
        logger.info(
            "Applied %s restraint '%s' between atoms %s and %s (r0=%s, k=%s)",
            kind.value,
            self.name,
            atom1_idx,
            atom2_idx,
            self.distance,
            self.force_constant,
        )
        return force_idx

    def _create_one_sided_force(
        self, expression: str, atom1_idx: int, atom2_idx: int
    ) -> CustomBondForce:
        """Build a CustomBondForce for ``expression`` using per-bond k and r0.

        Per-bond parameters are used instead of global ones because global
        parameters are shared by name across a whole Context: several
        restraints each defining global "k"/"r0" with different values
        would conflict with one another.
        """
        from openmm import CustomBondForce

        force = CustomBondForce(expression)
        force.addPerBondParameter("k")
        force.addPerBondParameter("r0")
        # Values are passed as plain floats in kJ/mol/nm^2 and nm, in the same
        # order the per-bond parameters were declared above.
        force.addBond(
            atom1_idx,
            atom2_idx,
            [
                _force_constant_value_in_openmm_units(self.force_constant),
                _quantity_in_nanometers(self.distance),
            ],
        )
        return force

    def _create_flat_bottom_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create a flat-bottom potential force.

        U(r) = 0                        if r < r0
               0.5 * k * (r - r0)^2     if r >= r0
        """
        return self._create_one_sided_force(
            "step(r - r0) * 0.5 * k * (r - r0)^2", atom1_idx, atom2_idx
        )

    def _create_harmonic_force(self, atom1_idx: int, atom2_idx: int) -> HarmonicBondForce:
        """Create a harmonic bond force.

        U(r) = 0.5 * k * (r - r0)^2
        """
        from openmm import HarmonicBondForce

        force = HarmonicBondForce()
        # Convert distance to nanometers for OpenMM
        r0_nm = _quantity_in_nanometers(self.distance)
        # Convert force constant to kJ/mol/nm^2
        k_value = _force_constant_value_in_openmm_units(self.force_constant)
        force.addBond(atom1_idx, atom2_idx, r0_nm, k_value)
        return force

    def _create_upper_wall_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create an upper wall potential (prevent distance exceeding r0).

        U(r) = 0                        if r < r0
               0.5 * k * (r - r0)^2     if r >= r0

        (Same as flat bottom)
        """
        return self._create_flat_bottom_force(atom1_idx, atom2_idx)

    def _create_lower_wall_force(self, atom1_idx: int, atom2_idx: int) -> CustomBondForce:
        """Create a lower wall potential (prevent distance below r0).

        U(r) = 0.5 * k * (r0 - r)^2     if r < r0
               0                        if r >= r0
        """
        return self._create_one_sided_force(
            "step(r0 - r) * 0.5 * k * (r0 - r)^2", atom1_idx, atom2_idx
        )
[docs]
class RestraintFactory:
    """Factory for creating restraints from configuration.

    This class bridges the configuration schema with the restraint
    implementation, creating RestraintDefinition objects from config.
    """

    @staticmethod
    def from_config(config: Dict[str, Any]) -> RestraintDefinition:
        """Create a RestraintDefinition from a configuration dictionary.

        Recognised keys (all optional): ``type`` (default "flat_bottom"),
        ``name`` (default "unnamed_restraint"), ``atom1`` / ``atom2``,
        ``distance`` in Angstroms (default 3.3), ``force_constant`` in
        kJ/mol/nm^2 (default 10000.0) and ``enabled`` (default True).
        Each atom entry is either a mapping with ``selection`` and optional
        ``description`` keys, or a bare selection string.

        Args:
            config: Dictionary with restraint configuration

        Returns:
            RestraintDefinition instance

        Raises:
            ValueError: If ``type`` is not a known restraint type.
        """
        # Parse restraint type
        type_str = config.get("type", "flat_bottom")
        try:
            restraint_type = RestraintType(type_str)
        except ValueError as err:
            # Keep the original failure attached for easier debugging.
            raise ValueError(f"Unknown restraint type: {type_str}") from err

        # A key that is present but null (e.g. an empty YAML entry) is
        # treated the same as a missing key and falls back to the default.
        distance_value = config.get("distance")
        if distance_value is None:
            distance_value = 3.3
        k_value = config.get("force_constant")
        if k_value is None:
            k_value = 10000.0

        return RestraintDefinition(
            restraint_type=restraint_type,
            name=config.get("name", "unnamed_restraint"),
            atom1=RestraintFactory._selection_from_config(config.get("atom1")),
            atom2=RestraintFactory._selection_from_config(config.get("atom2")),
            distance=_distance_in_angstroms(distance_value),
            force_constant=_force_constant_in_kj_per_mol_nm2(k_value),
            enabled=config.get("enabled", True),
        )

    @staticmethod
    def _selection_from_config(spec: Any) -> AtomSelection:
        """Build an AtomSelection from one ``atom1``/``atom2`` config entry."""
        if isinstance(spec, str):
            # Shorthand form: the entry is the selection string itself.
            return AtomSelection(selection=spec)
        # Missing or null entries yield an empty selection, which is only
        # rejected later when the selection is actually resolved.
        spec = spec or {}
        return AtomSelection(
            selection=spec.get("selection", ""),
            description=spec.get("description"),
        )

    @staticmethod
    def _create_pair_restraint(
        restraint_type: RestraintType,
        name: str,
        atom1_selection: str,
        atom2_selection: str,
        distance: float,
        force_constant: float,
    ) -> RestraintDefinition:
        """Shared builder behind the ``create_*`` convenience constructors."""
        return RestraintDefinition(
            restraint_type=restraint_type,
            name=name,
            atom1=AtomSelection(atom1_selection),
            atom2=AtomSelection(atom2_selection),
            distance=_distance_in_angstroms(distance),
            force_constant=_force_constant_in_kj_per_mol_nm2(force_constant),
        )

    @staticmethod
    def create_flat_bottom(
        name: str,
        atom1_selection: str,
        atom2_selection: str,
        distance: float,
        force_constant: float = 10000.0,
    ) -> RestraintDefinition:
        """Convenience method to create a flat-bottom restraint.

        Args:
            name: Restraint identifier
            atom1_selection: Selection string for first atom
            atom2_selection: Selection string for second atom
            distance: Threshold distance in Angstroms
            force_constant: Force constant in kJ/mol/nm^2

        Returns:
            RestraintDefinition for flat-bottom potential
        """
        return RestraintFactory._create_pair_restraint(
            RestraintType.FLAT_BOTTOM,
            name,
            atom1_selection,
            atom2_selection,
            distance,
            force_constant,
        )

    @staticmethod
    def create_harmonic(
        name: str,
        atom1_selection: str,
        atom2_selection: str,
        distance: float,
        force_constant: float = 10000.0,
    ) -> RestraintDefinition:
        """Convenience method to create a harmonic restraint.

        Args:
            name: Restraint identifier
            atom1_selection: Selection string for first atom
            atom2_selection: Selection string for second atom
            distance: Equilibrium distance in Angstroms
            force_constant: Force constant in kJ/mol/nm^2

        Returns:
            RestraintDefinition for harmonic potential
        """
        return RestraintFactory._create_pair_restraint(
            RestraintType.HARMONIC,
            name,
            atom1_selection,
            atom2_selection,
            distance,
            force_constant,
        )
[docs]
def apply_restraints(
    restraints: List[RestraintDefinition], topology: OpenMMTopology, system: System
) -> List[int]:
    """Add every restraint in ``restraints`` to ``system``, in order.

    Args:
        restraints: List of restraint definitions
        topology: OpenMM Topology for resolving selections
        system: OpenMM System to modify

    Returns:
        List of force indices for the added restraints; restraints whose
        ``apply`` returns None (i.e. disabled ones) contribute no entry
    """
    # The generator is consumed lazily, so restraints are still applied one
    # at a time and in the order given.
    outcomes = (restraint.apply(topology, system) for restraint in restraints)
    return [force_index for force_index in outcomes if force_index is not None]