Source code for polyzymd.analyses.mda.lifecycle

"""Lifecycle bridge between PolyzyMD replicates and MDAnalysis jobs."""

from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from polyzymd.analyses.exceptions import PluginContractError
from polyzymd.analyses.mda.artifacts import ReplicateArtifact, raw_mdanalysis_results_path
from polyzymd.analyses.mda.frame_selection import FrameSelection
from polyzymd.analyses.mda.job import MDAAnalysisJob, MDAJobResult, MDAUniversePolicy
from polyzymd.analyses.mda.plugin import MDACollectorContext
from polyzymd.analyses.mda.store import ArtifactStore, ArtifactStoreError

if TYPE_CHECKING:
    from polyzymd.analyses._framework.contexts import ReplicateContext
    from polyzymd.analyses.mda.job import MDABackendPolicy

logger = logging.getLogger("polyzymd.analyses")


@dataclass(frozen=True)
class MDAReplicateJobContext:
    """Immutable per-replicate bundle handed to ``Analysis.build_mda_jobs()``.

    The instance wraps the framework replicate context and carries the
    MDAnalysis-specific objects prepared for that replicate. The properties
    below are read-only shortcuts that forward to ``replicate_context``.
    """

    # Framework-provided context that the forwarding properties read from.
    replicate_context: ReplicateContext
    # Universe object loaded for this replicate.
    universe: Any
    # Frame selection resolved for this replicate.
    frame_selection: FrameSelection
    # Universe policy recorded for this replicate.
    universe_policy: MDAUniversePolicy
    # Store used for this replicate's artifacts.
    artifact_store: ArtifactStore

    @property
    def output_dir(self) -> Path:
        """Directory owned by this replicate analysis run.

        Returns
        -------
        Path
            ``output_dir`` of the wrapped replicate context.
        """
        framework_ctx = self.replicate_context
        return framework_ctx.output_dir

    @property
    def replicate(self) -> int:
        """One-indexed replicate ID.

        Returns
        -------
        int
            ``replicate`` of the wrapped replicate context.
        """
        framework_ctx = self.replicate_context
        return framework_ctx.replicate

    @property
    def settings(self) -> BaseModel:
        """Resolved plugin settings.

        Returns
        -------
        BaseModel
            ``settings`` of the wrapped replicate context, i.e. the settings
            model supplied by the public lifecycle.
        """
        framework_ctx = self.replicate_context
        return framework_ctx.settings

    @property
    def backend_policy(self) -> MDABackendPolicy:
        """MDAnalysis backend policy to use when constructing jobs.

        Returns
        -------
        MDABackendPolicy
            ``backend_policy`` of the wrapped replicate context: the policy
            resolved from comparison configuration, or the serial default.
        """
        framework_ctx = self.replicate_context
        return framework_ctx.backend_policy
def build_mda_replicate_job_context(
    analysis: Any, ctx: ReplicateContext, replicate: int
) -> MDAReplicateJobContext:
    """Assemble the MDAnalysis job context for a single replicate.

    Parameters
    ----------
    analysis : Any
        Analysis instance requesting an MDAnalysis job context. It supplies
        the loader, universe-provider, and artifact-store factories as well
        as ``get_trajectory_window()``.
    ctx : ReplicateContext
        Framework-provided replicate context.
    replicate : int
        One-indexed replicate ID.

    Returns
    -------
    MDAReplicateJobContext
        Context holding the loaded universe, the resolved frame selection,
        the universe policy, and the artifact store.
    """
    # One loader is created here and shared by the provider and the
    # trajectory-window lookup.
    make_loader = analysis._trajectory_loader_factory()
    loader = make_loader(ctx.sim_config)

    provider = _build_universe_provider(analysis, ctx, loader)
    universe = provider.load_universe(replicate)

    window = analysis.get_trajectory_window(ctx, replicate, loader, universe)
    # Windows without a ``warning_message`` attribute are treated as clean.
    if getattr(window, "warning_message", None):
        logger.warning(
            "%s: %s [condition=%s, replicate=%d]",
            analysis.name,
            window.warning_message,
            ctx.condition.label,
            replicate,
        )
    frame_selection = FrameSelection.from_trajectory_window(window)

    provenance = _provenance_for(provider, replicate)
    policy_fields = {
        "condition_label": ctx.condition.label,
        "replicate": replicate,
        "provenance": provenance,
        "metadata": {"equilibration": ctx.equilibration},
    }
    universe_policy = MDAUniversePolicy(**policy_fields)

    make_store = analysis._mda_artifact_store_factory()
    artifact_store = make_store(ctx.output_dir)

    return MDAReplicateJobContext(
        replicate_context=ctx,
        universe=universe,
        frame_selection=frame_selection,
        universe_policy=universe_policy,
        artifact_store=artifact_store,
    )
def run_mda_replicate_jobs(
    analysis: Any, ctx: ReplicateContext, replicate: int
) -> ReplicateArtifact | None:
    """Execute one replicate's MDAnalysis jobs and collect a strict artifact.

    Parameters
    ----------
    analysis : Any
        Analysis instance with a ``build_mda_jobs()`` hook.
    ctx : ReplicateContext
        Framework-provided replicate context.
    replicate : int
        One-indexed replicate ID.

    Returns
    -------
    ReplicateArtifact or None
        The collected replicate artifact, or ``None`` when
        ``build_mda_jobs()`` returns ``None`` to decline the MDA path.
    """
    job_ctx = build_mda_replicate_job_context(analysis, ctx, replicate)

    requested = analysis.build_mda_jobs(job_ctx)
    if requested is None:
        # The hook opted out; nothing is validated, run, or collected.
        return None

    # Validation finishes for the whole batch before the first job runs.
    runnable = _validate_jobs(requested, analysis_name=analysis.name)
    finished = []
    for job in runnable:
        finished.append(job.run())

    return _artifact_from_completed_jobs(analysis, job_ctx, finished)
def _build_universe_provider(analysis: Any, ctx: ReplicateContext, loader: Any) -> Any:
    """Instantiate the universe provider through the analysis injection hook.

    Parameters
    ----------
    analysis : Any
        Analysis instance supplying the provider factory.
    ctx : ReplicateContext
        Framework-provided replicate context.
    loader : Any
        Shared trajectory loader already allocated for this replicate.

    Returns
    -------
    Any
        Universe provider compatible with ``UniverseProvider``.
    """
    factory = analysis._mda_universe_provider_factory()
    # Use the ``from_config`` alternate constructor when the factory has one;
    # otherwise call the factory object directly.
    construct = factory.from_config if hasattr(factory, "from_config") else factory
    return construct(ctx.sim_config, loader=loader)


def _provenance_for(provider: Any, replicate: int) -> Any:
    """Fetch provenance from the provider if it offers an accessor for it.

    Parameters
    ----------
    provider : Any
        Universe provider used to load the replicate universe.
    replicate : int
        One-indexed replicate ID.

    Returns
    -------
    Any
        Provider-specific provenance object, or ``None`` when the provider
        exposes neither ``provenance_for`` nor ``get_provenance``.
    """
    # ``provenance_for`` is tried first; ``get_provenance`` is the fallback.
    for accessor_name in ("provenance_for", "get_provenance"):
        if hasattr(provider, accessor_name):
            return getattr(provider, accessor_name)(replicate)
    return None


def _validate_jobs(
    jobs: Sequence[MDAAnalysisJob], *, analysis_name: str
) -> tuple[MDAAnalysisJob, ...]:
    """Check the value returned by ``build_mda_jobs()``.

    Parameters
    ----------
    jobs : sequence of MDAAnalysisJob
        Jobs returned by ``build_mda_jobs()``.
    analysis_name : str
        Analysis name for diagnostics.

    Returns
    -------
    tuple[MDAAnalysisJob, ...]
        Concrete job tuple ready for execution.

    Raises
    ------
    PluginContractError
        If ``jobs`` is a ``str``/``bytes`` or not a ``Sequence``, if it is
        empty, or if any element is not an ``MDAAnalysisJob``.
    """
    # ``str`` and ``bytes`` are Sequences too, so they are excluded by name.
    is_job_sequence = isinstance(jobs, Sequence) and not isinstance(jobs, (str, bytes))
    if not is_job_sequence:
        raise PluginContractError(
            f"{analysis_name}.build_mda_jobs() must return a sequence of MDAAnalysisJob objects"
        )

    job_tuple = tuple(jobs)
    if len(job_tuple) == 0:
        raise PluginContractError(f"{analysis_name}.build_mda_jobs() returned no jobs")

    # Only the first offending element is reported.
    for candidate in job_tuple:
        if not isinstance(candidate, MDAAnalysisJob):
            raise PluginContractError(
                f"{analysis_name}.build_mda_jobs() returned {type(candidate).__name__}; "
                "expected MDAAnalysisJob"
            )
    return job_tuple


def _artifact_from_completed_jobs(
    analysis: Any,
    ctx: MDAReplicateJobContext,
    completed_jobs: Sequence[MDAJobResult],
) -> ReplicateArtifact:
    """Feed completed job results to the analysis collector and validate its output.

    Parameters
    ----------
    analysis : Any
        Analysis instance that owns the jobs.
    ctx : MDAReplicateJobContext
        MDA replicate context used for execution.
    completed_jobs : sequence of MDAJobResult
        Completed job result references.

    Returns
    -------
    ReplicateArtifact
        Strict JSON-compatible artifact envelope.

    Raises
    ------
    PluginContractError
        If ``build_mda_collector()`` does not return a callable, or if the
        collected artifact fails validation.
    """
    collect_ctx = _build_collector_context(analysis, ctx)
    collect = analysis.build_mda_collector(collect_ctx)
    if not callable(collect):
        raise PluginContractError(
            f"{analysis.name}.build_mda_collector() must return a callable collector, "
            f"got {type(collect).__name__}"
        )

    # The collector always receives the results as a tuple.
    results = tuple(completed_jobs)
    collected = collect(collect_ctx, results)
    _validate_collected_artifact(collected, collect_ctx)
    return collected


def _build_collector_context(analysis: Any, ctx: MDAReplicateJobContext) -> MDACollectorContext:
    """Derive the collector context from an MDA replicate context.

    Parameters
    ----------
    analysis : Any
        Analysis instance that owns the jobs.
    ctx : MDAReplicateJobContext
        MDA replicate context used for execution.

    Returns
    -------
    MDACollectorContext
        Context passed to the artifact collector.
    """
    fingerprint = analysis.aggregate_settings_fingerprint(ctx.settings)
    collector_fields = {
        "analysis_name": analysis.name,
        "replicate_context": ctx.replicate_context,
        "frame_selection": ctx.frame_selection,
        "universe_policy": ctx.universe_policy,
        "artifact_store": ctx.artifact_store,
        "settings_fingerprint": fingerprint,
        "warnings": _collector_warnings(ctx),
    }
    return MDACollectorContext(**collector_fields)


def _collector_warnings(ctx: MDAReplicateJobContext) -> tuple[str, ...]:
    """Gather the warning messages that are known before collection starts.

    Parameters
    ----------
    ctx : MDAReplicateJobContext
        MDA replicate context used for execution.

    Returns
    -------
    tuple of str
        The frame-selection warning (if any) followed by the stringified
        provider warnings found under ``provenance["warnings"]`` of the
        universe policy dictionary.
    """
    messages: list[str] = []
    if ctx.frame_selection.warning_message:
        messages.append(ctx.frame_selection.warning_message)

    # Provider warnings are used only when every level has the expected
    # container type; any other shape contributes nothing.
    policy = ctx.universe_policy.as_dict()
    if isinstance(policy, dict):
        provenance = policy.get("provenance")
        if isinstance(provenance, dict):
            provider_warnings = provenance.get("warnings", [])
            if isinstance(provider_warnings, list):
                messages += [str(item) for item in provider_warnings]
    return tuple(messages)


def _validate_collected_artifact(
    artifact: Any,
    ctx: MDACollectorContext,
) -> None:
    """Check collector output before the lifecycle persists it.

    Parameters
    ----------
    artifact : Any
        Collector output to validate.
    ctx : MDACollectorContext
        Collector context with expected identity and artifact store.

    Raises
    ------
    PluginContractError
        If ``artifact`` is not a ``ReplicateArtifact``, if its analysis name,
        condition label, or replicate differ from ``ctx``, if
        ``raw_mdanalysis_results_path()`` reports a path, or if the artifact
        store rejects one of its sidecars.
    """
    if not isinstance(artifact, ReplicateArtifact):
        raise PluginContractError(
            f"{ctx.analysis_name}.build_mda_collector() returned {type(artifact).__name__}; "
            "expected ReplicateArtifact"
        )

    # The artifact must carry the same identity as the context it was built for.
    identity_fields = ("analysis_name", "condition_label", "replicate")
    expected = {field: getattr(ctx, field) for field in identity_fields}
    actual = {field: getattr(artifact, field) for field in identity_fields}
    if actual != expected:
        raise PluginContractError(
            f"{ctx.analysis_name}.build_mda_collector() returned artifact identity "
            f"{actual!r}; expected {expected!r}"
        )

    # A non-None path means raw MDAnalysis Results were found in the artifact.
    offending_path = raw_mdanalysis_results_path(artifact)
    if offending_path is not None:
        raise PluginContractError(
            f"{ctx.analysis_name}.build_mda_collector() returned raw MDAnalysis Results "
            f"at {offending_path}; map Results to JSON primitives or sidecars before persistence"
        )

    store = ctx.artifact_store
    for sidecar in artifact.sidecars:
        try:
            store.validate_sidecar(sidecar)
        except ArtifactStoreError as exc:
            # Re-raise store failures as a plugin contract violation.
            raise PluginContractError(
                f"{ctx.analysis_name}.build_mda_collector() returned invalid sidecar "
                f"{sidecar.path!r}: {exc}"
            ) from exc