# Source code for polyzymd.analyses.mda.universe

"""Universe loading and provenance helpers for the MDAnalysis extension layer."""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Protocol

if TYPE_CHECKING:
    from MDAnalysis.core.universe import Universe

    from polyzymd.analyses.shared.loader import TrajectoryInfo
    from polyzymd.config.schema import SimulationConfig


class _TrajectoryLoaderLike(Protocol):
    """Duck-typed loader interface that ``UniverseProvider`` relies on.

    Any object exposing the two methods below can be used as a loader; no
    inheritance from this class is required.
    """

    def load_universe(self, replicate: int, cache: bool = True) -> Universe:
        """Return the MDAnalysis universe for one replicate.

        Parameters
        ----------
        replicate : int
            Index of the replicate whose universe is requested.
        cache : bool, optional
            When True (the default) the loader is allowed to hand back a
            universe from its own cache.

        Returns
        -------
        Universe
            The MDAnalysis universe for the replicate.
        """
        ...

    def get_trajectory_info(self, replicate: int) -> TrajectoryInfo:
        """Describe the trajectory files of one replicate without reading coordinates.

        Parameters
        ----------
        replicate : int
            Index of the replicate to describe.

        Returns
        -------
        TrajectoryInfo
            Metadata about the resolved trajectory files.
        """
        ...


# Type of a callable that builds a trajectory loader.
# ``UniverseProvider._get_loader`` invokes it as
# ``factory(config, engine_override=...)`` and keeps the returned loader.
LoaderFactory = Callable[..., _TrajectoryLoaderLike]

# Warning text added to provenance when the topology is a GRO file.
# The ``{path}`` placeholder is filled with the topology path via ``str.format``.
GRO_CHAIN_ID_WARNING_TEMPLATE = (
    "Using GRO topology {path} — GRO files may not preserve chain identifiers. "
    "Chain-based selections (chainid A/B/C) used by analysis plugins may be unreliable. "
    "Prefer a PDB topology when available."
)


@dataclass(frozen=True)
class FileIdentity:
    """Immutable snapshot identifying a topology or trajectory file on disk."""

    path: Path
    format: str | None
    size_bytes: int
    mtime_ns: int

    @classmethod
    def from_path(cls, path: Path, file_format: str | None = None) -> FileIdentity:
        """Build an identity record by stat-ing a file.

        Parameters
        ----------
        path : Path
            Location of the file to describe.
        file_format : str or None, optional
            Format reported by the trajectory layout. It is stored as given.
            When it is missing or empty, the lowercase file suffix (without
            the leading dot) is used instead.

        Returns
        -------
        FileIdentity
            Record holding the path, format, size in bytes, and modification
            time in nanoseconds.
        """
        target = Path(path)
        stat_result = target.stat()

        if file_format:
            chosen_format: str | None = file_format
        else:
            # Fall back to the suffix; a file without one has no known format.
            suffix = target.suffix.removeprefix(".").lower()
            chosen_format = suffix if suffix else None

        return cls(target, chosen_format, stat_result.st_size, stat_result.st_mtime_ns)

    def as_dict(self) -> dict[str, Any]:
        """Convert the record into JSON-friendly primitives.

        Returns
        -------
        dict[str, Any]
            Mapping of field names to values, with the path rendered as a
            string.
        """
        return dict(
            path=str(self.path),
            format=self.format,
            size_bytes=self.size_bytes,
            mtime_ns=self.mtime_ns,
        )
@dataclass(frozen=True)
class UniverseProvenance:
    """Record of the inputs behind one replicate universe built from PolyzyMD outputs."""

    replicate: int
    working_directory: Path
    topology: FileIdentity
    trajectories: tuple[FileIdentity, ...]
    n_segments: int
    loader_class: str
    config_engine: str | None
    engine_override: str | None = None
    warnings: tuple[str, ...] = field(default_factory=tuple)

    def as_dict(self) -> dict[str, Any]:
        """Convert the provenance record into JSON-friendly primitives.

        Returns
        -------
        dict[str, Any]
            Mapping of field names to values. Paths become strings, file
            identities become dictionaries, and tuples become lists.
        """
        trajectory_payloads = []
        for identity in self.trajectories:
            trajectory_payloads.append(identity.as_dict())

        return dict(
            replicate=self.replicate,
            working_directory=str(self.working_directory),
            topology=self.topology.as_dict(),
            trajectories=trajectory_payloads,
            n_segments=self.n_segments,
            loader_class=self.loader_class,
            config_engine=self.config_engine,
            engine_override=self.engine_override,
            warnings=list(self.warnings),
        )
@dataclass
class UniverseProvider:
    """Config-aware provider for MDAnalysis universes and input provenance.

    A loader may be injected directly (``loader``) or built on first use by a
    factory (``loader_factory``); supplying both is an error. When neither is
    given, the shared ``TrajectoryLoader`` is imported and built on first use.
    """

    config: SimulationConfig
    engine_override: str | None = None
    loader: _TrajectoryLoaderLike | None = None
    loader_factory: LoaderFactory | None = None
    # Per-replicate provenance, filled by ``provenance_for``.
    _provenance_cache: dict[int, UniverseProvenance] = field(default_factory=dict, init=False)

    def __post_init__(self) -> None:
        """Validate loader injection settings after dataclass construction.

        Raises
        ------
        ValueError
            If both ``loader`` and ``loader_factory`` are provided.
        """
        if self.loader is not None and self.loader_factory is not None:
            raise ValueError("Provide either loader or loader_factory, not both.")

    @classmethod
    def from_config(cls, config: SimulationConfig, **kwargs: Any) -> UniverseProvider:
        """Create a universe provider from a simulation configuration.

        Parameters
        ----------
        config : SimulationConfig
            PolyzyMD simulation configuration.
        **kwargs : Any
            Optional provider settings such as ``engine_override``,
            ``loader``, or ``loader_factory``.

        Returns
        -------
        UniverseProvider
            Provider that lazily instantiates the trajectory loader.
        """
        return cls(config=config, **kwargs)

    def load_universe(self, replicate: int, *, cache: bool = True) -> Universe:
        """Load an MDAnalysis universe for a replicate through the loader.

        Provenance for the replicate is computed (or refreshed when
        ``cache`` is False) before the universe itself is loaded.

        Parameters
        ----------
        replicate : int
            Replicate index to load.
        cache : bool, optional
            Whether the underlying loader may reuse its universe cache, by
            default True. When False, cached provenance is recomputed too.

        Returns
        -------
        Universe
            Loaded MDAnalysis universe from the underlying trajectory loader.
        """
        self.provenance_for(replicate, refresh=not cache)
        return self._get_loader().load_universe(replicate, cache=cache)

    def provenance_for(self, replicate: int, *, refresh: bool = False) -> UniverseProvenance:
        """Return provenance for a replicate, computing it when needed.

        Parameters
        ----------
        replicate : int
            Replicate index to inspect.
        refresh : bool, optional
            Recompute provenance even when cached, by default False.

        Returns
        -------
        UniverseProvenance
            Input file identity and loader metadata for the replicate.
        """
        if not refresh and replicate in self._provenance_cache:
            return self._provenance_cache[replicate]

        loader = self._get_loader()
        info = loader.get_trajectory_info(replicate)
        provenance = self._build_provenance(info=info, loader=loader)
        self._provenance_cache[replicate] = provenance
        return provenance

    def get_provenance(self, replicate: int) -> UniverseProvenance | None:
        """Return cached provenance without triggering trajectory discovery.

        Parameters
        ----------
        replicate : int
            Replicate index whose cached provenance should be returned.

        Returns
        -------
        UniverseProvenance or None
            Cached provenance when available, otherwise ``None``.
        """
        return self._provenance_cache.get(replicate)

    def _get_loader(self) -> _TrajectoryLoaderLike:
        """Return the lazily instantiated trajectory loader.

        Returns
        -------
        _TrajectoryLoaderLike
            Injected or default trajectory loader.
        """
        if self.loader is None:
            factory = self.loader_factory or self._default_loader_factory
            # Store the result so the factory runs at most once per provider.
            self.loader = factory(self.config, engine_override=self.engine_override)
        return self.loader

    @staticmethod
    def _default_loader_factory(
        config: SimulationConfig,
        *,
        engine_override: str | None = None,
    ) -> _TrajectoryLoaderLike:
        """Create the default shared trajectory loader lazily.

        Parameters
        ----------
        config : SimulationConfig
            PolyzyMD simulation configuration.
        engine_override : str or None, optional
            Engine override passed through to ``TrajectoryLoader``.

        Returns
        -------
        _TrajectoryLoaderLike
            Shared trajectory loader instance.
        """
        # Imported here so the shared loader module is only loaded on demand.
        from polyzymd.analyses.shared.loader import TrajectoryLoader

        return TrajectoryLoader(config, engine_override=engine_override)

    def _build_provenance(
        self,
        *,
        info: TrajectoryInfo,
        loader: _TrajectoryLoaderLike,
    ) -> UniverseProvenance:
        """Build provenance metadata from shared loader trajectory info.

        Parameters
        ----------
        info : TrajectoryInfo
            Resolved trajectory metadata from the shared loader.
        loader : _TrajectoryLoaderLike
            Loader instance used to resolve the metadata.

        Returns
        -------
        UniverseProvenance
            Provenance with file identities and warnings.
        """
        topology_format = self._metadata_format(info, "topology_format", info.topology_file)
        # No path fallback here: when this is None, each trajectory file gets
        # its format from its own suffix inside ``FileIdentity.from_path``.
        trajectory_format = self._metadata_format(info, "trajectory_format", None)
        gro_warning = self._gro_chain_id_warning(info, topology_format)
        warnings = self._merge_warnings(info, gro_warning)

        return UniverseProvenance(
            replicate=info.replicate,
            working_directory=Path(info.working_directory),
            topology=FileIdentity.from_path(info.topology_file, topology_format),
            trajectories=tuple(
                FileIdentity.from_path(path, trajectory_format) for path in info.trajectory_files
            ),
            n_segments=info.n_segments,
            loader_class=type(loader).__name__,
            config_engine=self._config_engine(),
            engine_override=self.engine_override,
            warnings=warnings,
        )

    @staticmethod
    def _merge_warnings(info: TrajectoryInfo, extra_warning: str | None) -> tuple[str, ...]:
        """Combine loader-reported warnings with one optional extra warning.

        Parameters
        ----------
        info : TrajectoryInfo
            Trajectory metadata; its optional ``warnings`` attribute may be
            absent or ``None``, and both are treated as "no warnings".
        extra_warning : str or None
            Additional warning to append unless it is already present.

        Returns
        -------
        tuple[str, ...]
            Warnings in their original order, followed by ``extra_warning``
            when it is new.
        """
        # ``or ()`` guards against an explicit ``warnings=None`` on the info
        # object, which ``list(...)`` would reject with a TypeError.
        merged = list(getattr(info, "warnings", None) or ())
        if extra_warning is not None and extra_warning not in merged:
            merged.append(extra_warning)
        return tuple(merged)

    def _config_engine(self) -> str | None:
        """Return the configured simulation engine name when it is concrete.

        Returns
        -------
        str or None
            String engine name from the config, otherwise ``None``.
        """
        engine = getattr(self.config, "engine", None)
        if isinstance(engine, str):
            return engine
        return None

    @staticmethod
    def _metadata_format(info: TrajectoryInfo, field_name: str, path: Path | None) -> str | None:
        """Return a format value from trajectory info with suffix fallback.

        Parameters
        ----------
        info : TrajectoryInfo
            Resolved trajectory metadata.
        field_name : str
            Name of the optional format field on ``TrajectoryInfo``.
        path : Path or None
            File path used to infer the format when metadata is absent.

        Returns
        -------
        str or None
            Lowercase format string or ``None`` when unavailable.
        """
        value = getattr(info, field_name, None)
        if isinstance(value, str) and value:
            return value.lower()
        if path is not None:
            suffix = Path(path).suffix.removeprefix(".").lower()
            if suffix:
                return suffix
        return None

    def _gro_chain_id_warning(
        self, info: TrajectoryInfo, topology_format: str | None
    ) -> str | None:
        """Return an actionable GRO chain-ID warning when applicable.

        Parameters
        ----------
        info : TrajectoryInfo
            Resolved trajectory metadata.
        topology_format : str or None
            Resolved topology format.

        Returns
        -------
        str or None
            Warning text for GRO topology inputs, otherwise ``None``.
        """
        topology_path = Path(info.topology_file)
        suffix_is_gro = topology_path.suffix.lower() == ".gro"
        format_is_gro = topology_format == "gro"
        if not suffix_is_gro and not format_is_gro:
            return None
        return GRO_CHAIN_ID_WARNING_TEMPLATE.format(path=topology_path)