Source code for polyzymd.analyses.mda.plugin

"""Collector interfaces for MDAnalysis job outputs."""

from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from operator import index as operator_index
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol

from pydantic import BaseModel

from polyzymd.analyses.exceptions import PluginContractError
from polyzymd.analyses.mda.artifacts import (
    ReplicateArtifact,
    raw_mdanalysis_results_path,
)
from polyzymd.analyses.mda.frame_selection import FrameSelection
from polyzymd.analyses.mda.job import MDAJobResult, MDAUniversePolicy
from polyzymd.analyses.mda.store import ArtifactStore

if TYPE_CHECKING:
    from polyzymd.analyses._framework.contexts import ReplicateContext


@dataclass(frozen=True)
class MDACollectorContext:
    """Per-replicate context handed to an MDAnalysis artifact collector.

    Bundles the framework identity and provenance of a single replicate so
    that a collector can translate raw job results into PolyzyMD-owned
    artifacts without reaching back into the orchestrator.

    Attributes
    ----------
    analysis_name : str
        Name of the analysis the collector is working for.
    replicate_context : ReplicateContext
        Framework replicate context; every property on this class delegates
        to it.
    frame_selection : FrameSelection
        Frame selection associated with this replicate.
    universe_policy : MDAUniversePolicy
        Universe policy associated with this replicate.
    artifact_store : ArtifactStore
        Artifact store made available to the collector.
    settings_fingerprint : str or None
        Optional settings fingerprint (``None`` when not supplied).
    warnings : sequence of str
        Warning messages; coerced to a tuple of strings after construction.
    """

    analysis_name: str
    replicate_context: ReplicateContext
    frame_selection: FrameSelection
    universe_policy: MDAUniversePolicy
    artifact_store: ArtifactStore
    settings_fingerprint: str | None = None
    warnings: Sequence[str] = ()

    def __post_init__(self) -> None:
        """Normalize ``warnings`` into an immutable tuple of strings."""
        frozen_warnings = tuple(map(str, self.warnings))
        # The dataclass is frozen, so regular attribute assignment is blocked;
        # go through ``object.__setattr__`` to store the normalized value.
        object.__setattr__(self, "warnings", frozen_warnings)

    @property
    def condition_label(self) -> str:
        """Label of the simulation condition.

        Returns
        -------
        str
            ``label`` of the condition held by the replicate context.
        """
        condition = self.replicate_context.condition
        return condition.label

    @property
    def replicate(self) -> int:
        """Replicate ID (one-indexed).

        Returns
        -------
        int
            Replicate ID taken from the replicate context.
        """
        context = self.replicate_context
        return context.replicate

    @property
    def output_dir(self) -> Path:
        """Output directory of this replicate.

        Returns
        -------
        Path
            Directory owned by this replicate analysis run.
        """
        context = self.replicate_context
        return context.output_dir

    @property
    def result_path(self) -> Path:
        """Canonical result path of this replicate.

        Returns
        -------
        Path
            Canonical artifact JSON path for this replicate.
        """
        context = self.replicate_context
        return context.result_path

    @property
    def settings(self) -> BaseModel:
        """Resolved plugin settings.

        Returns
        -------
        BaseModel
            Settings model carried by the replicate context.
        """
        context = self.replicate_context
        return context.settings
class MDAArtifactCollector(Protocol):
    """Structural interface for turning finished MDAnalysis jobs into an artifact.

    Any callable with the ``__call__`` signature below satisfies this
    protocol.
    """

    def __call__(
        self,
        ctx: MDACollectorContext,
        completed_jobs: Sequence[MDAJobResult],
    ) -> ReplicateArtifact:
        """Fold the completed jobs of one replicate into a single artifact.

        Parameters
        ----------
        ctx : MDACollectorContext
            Collector context supplied by the framework for one replicate.
        completed_jobs : sequence of MDAJobResult
            MDAnalysis-compatible jobs that have finished running.

        Returns
        -------
        ReplicateArtifact
            PolyzyMD-owned artifact describing this replicate.
        """
        ...
class StrictJSONMDAResultCollector:
    """Default collector for jobs whose results are already strict JSON values.

    Keeps the P2-001 simple job behavior while refusing raw MDAnalysis
    ``Results`` containers and any other non-JSON value. Analyses that produce
    rich ``Results`` objects should provide their own collector that maps the
    data to primitive payloads and sidecars.
    """

    def __call__(
        self,
        ctx: MDACollectorContext,
        completed_jobs: Sequence[MDAJobResult],
    ) -> ReplicateArtifact:
        """Assemble a strict JSON artifact from the completed jobs.

        Parameters
        ----------
        ctx : MDACollectorContext
            Collector context supplied by the framework for one replicate.
        completed_jobs : sequence of MDAJobResult
            MDAnalysis-compatible jobs that have finished running.

        Returns
        -------
        ReplicateArtifact
            Replicate artifact holding only JSON-compatible values.

        Raises
        ------
        PluginContractError
            Propagated from the payload helpers when a job result or policy
            value is not strictly JSON-compatible.
        """
        analysis = ctx.analysis_name

        serialized_jobs = []
        for completed in completed_jobs:
            serialized_jobs.append(_job_result_payload(completed, analysis_name=analysis))

        universe_policy_payload = strict_json_payload(
            ctx.universe_policy.as_dict(),
            analysis_name=analysis,
        )

        # The fingerprint is only recorded when the context actually carries one.
        artifact_metadata: dict[str, Any] = {"result_kind": "mda_replicate_jobs"}
        fingerprint = ctx.settings_fingerprint
        if fingerprint is not None:
            artifact_metadata["settings_fingerprint"] = fingerprint

        return ReplicateArtifact(
            analysis_name=analysis,
            condition_label=ctx.condition_label,
            replicate=ctx.replicate,
            payload={"jobs": serialized_jobs, "n_jobs": len(serialized_jobs)},
            provenance={
                "source": "mda_job_lifecycle",
                "frame_selection": frame_selection_payload(ctx.frame_selection),
                "universe_policy": universe_policy_payload,
            },
            metadata=artifact_metadata,
            warnings=list(ctx.warnings),
        )
def frame_selection_payload(frame_selection: FrameSelection) -> dict[str, Any]:
    """Flatten frame-selection provenance into a dictionary of primitives.

    Parameters
    ----------
    frame_selection : FrameSelection
        Frame selection used for a job or replicate context.

    Returns
    -------
    dict[str, Any]
        Frame-selection metadata keyed by attribute name. The ``frames``
        entry is passed through ``_frame_selector_payload``; every other
        attribute is copied as-is.

    Raises
    ------
    PluginContractError
        Propagated from ``_frame_selector_payload`` when an explicit frame
        selector is neither an integer index nor a boolean.
    """
    # Order matters: it fixes the key order of the emitted dictionary.
    exported_fields = (
        "start",
        "stop",
        "step",
        "frames",
        "equilibration",
        "equilibration_start",
        "equilibration_ps",
        "timestep_ps",
        "first_frame_time_ps",
        "selected_start_time_ps",
        "equilibration_time_reference",
        "n_frames_total",
        "n_frames_selected",
        "warning_message",
    )
    serialized: dict[str, Any] = {}
    for field_name in exported_fields:
        field_value = getattr(frame_selection, field_name)
        if field_name == "frames":
            # Explicit selectors may hold non-native scalars; normalize them.
            field_value = _frame_selector_payload(field_value)
        serialized[field_name] = field_value
    return serialized
def _frame_selector_payload(frames: Any) -> list[int | bool] | None: """Serialize explicit frame selectors to JSON-safe Python scalars.""" if frames is None: return None payload: list[int | bool] = [] for frame in frames: if _is_boolean_frame_value(frame): payload.append(bool(frame)) else: try: payload.append(operator_index(frame)) except TypeError as exc: raise PluginContractError( "frame_selection.build_mda_jobs() produced non-integer explicit frame " f"selector {frame!r}; use integer indices or a boolean mask" ) from exc return payload def _is_boolean_frame_value(frame: Any) -> bool: """Return whether a frame selector value is boolean-like.""" if isinstance(frame, bool): return True frame_type = type(frame) return ( frame_type.__name__ == "bool_" and frame_type.__module__.split(".", maxsplit=1)[0] == "numpy" )
def strict_json_payload(value: Any, *, analysis_name: str) -> Any:
    """Recursively reduce a value to strict JSON-compatible primitives.

    Parameters
    ----------
    value : Any
        Candidate payload returned by an MDA job.
    analysis_name : str
        Analysis name used as the prefix of diagnostic messages.

    Returns
    -------
    Any
        ``None``, ``str``, ``bool``, ``int`` and finite ``float`` values are
        returned unchanged; ``Path`` becomes ``str``; pydantic models are
        dumped in JSON mode and re-checked; mappings become ``dict`` and
        non-string sequences become ``list``, both converted recursively.

    Raises
    ------
    PluginContractError
        If the value contains raw MDAnalysis ``Results``, a non-finite float,
        a non-string mapping key, or any other unsupported type.
    """

    def contract_violation(detail: str) -> PluginContractError:
        # Every diagnostic shares the same prefix and remediation hint.
        return PluginContractError(
            f"{analysis_name}.build_mda_jobs() produced {detail}; "
            "implement build_mda_collector() to map Results to JSON primitives or sidecars"
        )

    results_location = raw_mdanalysis_results_path(value)
    if results_location is not None:
        raise contract_violation(f"raw MDAnalysis Results at {results_location}")

    # Scalars that JSON can represent directly.
    if value is None or isinstance(value, (str, bool, int)):
        return value

    if isinstance(value, float):
        if math.isfinite(value):
            return value
        # NaN and infinities are not valid strict JSON.
        raise contract_violation(f"non-finite float result {value!r}")

    if isinstance(value, Path):
        return str(value)

    if isinstance(value, BaseModel):
        dumped = value.model_dump(mode="json")
        return strict_json_payload(dumped, analysis_name=analysis_name)

    if isinstance(value, Mapping):
        converted = {}
        for entry_key, entry_value in value.items():
            if not isinstance(entry_key, str):
                raise contract_violation(f"non-string mapping key {entry_key!r}")
            converted[entry_key] = strict_json_payload(entry_value, analysis_name=analysis_name)
        return converted

    # Text and byte strings are sequences too, but must not be exploded
    # element by element.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        elements = []
        for element in value:
            elements.append(strict_json_payload(element, analysis_name=analysis_name))
        return elements

    raise contract_violation(f"non-JSON-serializable {type(value).__name__} results")
def _job_result_payload(job: MDAJobResult, *, analysis_name: str) -> dict[str, Any]:
    """Turn one completed job result into a JSON-compatible dictionary.

    Parameters
    ----------
    job : MDAJobResult
        Completed job result reference.
    analysis_name : str
        Analysis name; combined with the job name to label diagnostics.

    Returns
    -------
    dict[str, Any]
        Job payload with the keys ``name``, ``results``, ``run_kwargs``,
        ``frame_selection``, ``backend_policy`` and ``universe_policy``.

    Raises
    ------
    PluginContractError
        Propagated from the payload helpers when any part of the job is not
        strictly JSON-compatible.
    """
    job_name = job.name
    # Diagnostics are scoped to "<analysis>.<job>" so failures name the job.
    diagnostic_scope = f"{analysis_name}.{job_name}"

    def as_json(candidate: Any) -> Any:
        return strict_json_payload(candidate, analysis_name=diagnostic_scope)

    job_payload: dict[str, Any] = {"name": job_name}
    job_payload["results"] = as_json(job.results)
    job_payload["run_kwargs"] = as_json(dict(job.run_kwargs))
    job_payload["frame_selection"] = frame_selection_payload(job.frame_selection)
    job_payload["backend_policy"] = as_json(job.backend_policy.run_kwargs())
    job_payload["universe_policy"] = as_json(job.universe_policy.as_dict())
    return job_payload