"""Job execution primitives for MDAnalysis-compatible analyses."""
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from typing import Any
from polyzymd.analyses.mda.base import MDAnalysisExtensionError, MDARunKwargs
from polyzymd.analyses.mda.frame_selection import FrameSelection
# Keyword names that control MDAnalysis backend/run behaviour rather than frame
# selection. ``MDAFunctionAdapter.run`` rejects any of these when they are
# passed to it directly.
MDA_BACKEND_RUN_CONTROL_KEYS = frozenset(
    (
        "backend",
        "n_workers",
        "n_parts",
        "unsupported_backend",
        "verbose",
        "progressbar_kwargs",
    )
)
class MDAAnalysisJobError(MDAnalysisExtensionError):
    """Signal that an MDAnalysis job broke its runtime contract.

    The job-execution primitives in this module raise it, for example when
    results are requested before a job has run or when ``analysis.run()``
    fails.
    """
@dataclass(frozen=True)
class MDABackendPolicy:
    """Validated options for MDAnalysis' internal backend, scoped to one job.

    A policy created without arguments contributes no backend-related keyword
    arguments, which keeps PolyzyMD-level parallelism as the default. Worker,
    part, and unsupported-backend options are only accepted together with an
    explicit ``backend`` so nested-parallelism requests are never ambiguous.

    Attributes
    ----------
    backend : Any
        Backend forwarded as ``backend``; ``None`` leaves it out.
    n_workers : int or None
        Positive worker count; requires ``backend``.
    n_parts : int or None
        Positive part count; requires ``backend``.
    unsupported_backend : bool or None
        Forwarded as ``unsupported_backend``; requires ``backend``.
    verbose : bool or None
        Forwarded as ``verbose`` when not ``None``.
    progressbar_kwargs : Mapping[str, Any] or None
        Progress-bar options; stored as a private ``dict`` copy when given.
    """

    backend: Any = None
    n_workers: int | None = None
    n_parts: int | None = None
    unsupported_backend: bool | None = None
    verbose: bool | None = None
    progressbar_kwargs: Mapping[str, Any] | None = None

    def __post_init__(self) -> None:
        """Check counts and backend opt-in, then snapshot progress-bar options.

        Raises
        ------
        ValueError
            If ``n_workers`` or ``n_parts`` is not a positive integer, or if
            ``n_workers``, ``n_parts``, or ``unsupported_backend`` is supplied
            while ``backend`` is ``None``.
        """
        for count_name in ("n_workers", "n_parts"):
            self._validate_positive_count(getattr(self, count_name), field_name=count_name)

        if self.backend is None:
            backend_only_options = (self.n_workers, self.n_parts, self.unsupported_backend)
            if any(option is not None for option in backend_only_options):
                raise ValueError(
                    "n_workers, n_parts, and unsupported_backend require an explicit backend"
                )

        if self.progressbar_kwargs is not None:
            # The dataclass is frozen, so go through object.__setattr__ to
            # store a copy that is independent of the caller's mapping.
            object.__setattr__(self, "progressbar_kwargs", dict(self.progressbar_kwargs))

    def run_kwargs(self) -> MDARunKwargs:
        """Build the keyword arguments this policy adds to ``AnalysisBase.run``.

        Returns
        -------
        MDARunKwargs
            One entry per option that is not ``None``. ``progressbar_kwargs``
            is copied so callers cannot mutate the policy through the result.
        """
        forwarded = MDARunKwargs()
        if self.backend is not None:
            forwarded["backend"] = self.backend
        if self.n_workers is not None:
            forwarded["n_workers"] = self.n_workers
        if self.n_parts is not None:
            forwarded["n_parts"] = self.n_parts
        if self.unsupported_backend is not None:
            forwarded["unsupported_backend"] = self.unsupported_backend
        if self.verbose is not None:
            forwarded["verbose"] = self.verbose
        if self.progressbar_kwargs is not None:
            forwarded["progressbar_kwargs"] = dict(self.progressbar_kwargs)
        return forwarded

    def is_default(self) -> bool:
        """Tell whether this policy adds nothing to ``run()``.

        Returns
        -------
        bool
            ``True`` when ``run_kwargs()`` is empty, i.e. every option is
            ``None``.
        """
        forwarded = self.run_kwargs()
        return len(forwarded) == 0

    @staticmethod
    def _validate_positive_count(value: int | None, *, field_name: str) -> None:
        """Ensure an optional count is a positive integer.

        Parameters
        ----------
        value : int or None
            Candidate count; ``None`` is accepted without further checks.
        field_name : str
            Field name inserted into the error message.

        Raises
        ------
        ValueError
            If ``value`` is not an ``int``, is a ``bool``, or is below 1.
        """
        if value is None:
            return
        # bool is a subclass of int, so it has to be excluded explicitly.
        is_positive_int = isinstance(value, int) and not isinstance(value, bool) and value >= 1
        if not is_positive_int:
            raise ValueError(f"{field_name} must be a positive integer")
@dataclass(frozen=True)
class MDAUniversePolicy:
    """Provenance and descriptive metadata attached to one job's universe.

    The policy never loads a universe. It only carries what the layer that
    resolved or supplied the universe chose to record about it.

    Attributes
    ----------
    condition_label : str or None
        Optional condition label.
    replicate : int or None
        Optional replicate identifier.
    provenance : Any
        Arbitrary provenance object; serialized through its ``as_dict``
        attribute when it has one.
    metadata : Mapping[str, Any]
        Extra key/value metadata, stored as a private ``dict`` copy.
    """

    condition_label: str | None = None
    replicate: int | None = None
    provenance: Any = None
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Replace ``metadata`` with a private ``dict`` copy.

        Copying detaches the policy from the caller's mapping, so later
        changes to that mapping are not reflected here.
        """
        # Frozen dataclass: assignment has to go through object.__setattr__.
        object.__setattr__(self, "metadata", dict(self.metadata))

    def as_dict(self) -> dict[str, Any]:
        """Export the policy as a plain dictionary.

        Returns
        -------
        dict[str, Any]
            ``condition_label``, ``replicate``, ``provenance`` (the result of
            ``provenance.as_dict()`` when that attribute exists, otherwise the
            object itself), and a copy of ``metadata``.
        """
        serialized_provenance = (
            self.provenance.as_dict() if hasattr(self.provenance, "as_dict") else self.provenance
        )
        return dict(
            condition_label=self.condition_label,
            replicate=self.replicate,
            provenance=serialized_provenance,
            metadata=dict(self.metadata),
        )
@dataclass(frozen=True)
class MDAJobResult:
    """Immutable record of one completed MDAnalysis-compatible job.

    Attributes
    ----------
    name : str
        Name of the job that produced this result.
    analysis : Any
        Analysis object the results were read from.
    results : Any
        The ``results`` attribute of ``analysis``.
    run_kwargs : Mapping[str, Any]
        Keyword arguments that were passed to ``analysis.run``.
    frame_selection : FrameSelection
        Frame selection the job was configured with.
    backend_policy : MDABackendPolicy
        Backend policy the job was configured with.
    universe_policy : MDAUniversePolicy
        Universe provenance policy the job was configured with.
    """

    name: str
    analysis: Any
    results: Any
    run_kwargs: Mapping[str, Any]
    frame_selection: FrameSelection
    backend_policy: MDABackendPolicy
    universe_policy: MDAUniversePolicy
class MDAFunctionAdapter:
    """Give a plain function the MDAnalysis ``run()``/``results`` shape.

    ``run()`` calls the wrapped function exactly once as
    ``function(universe, **frame_kwargs, **function_kwargs)``. The adapter has
    no trajectory loop of its own; a function that needs to iterate over
    frames should use MDAnalysis primitives internally.
    """

    def __init__(
        self,
        function: Callable[..., Any],
        universe: Any,
        *,
        function_kwargs: Mapping[str, Any] | None = None,
    ) -> None:
        """Store the function, its universe, and its extra keyword arguments.

        Parameters
        ----------
        function : Callable[..., Any]
            Callable invoked with the universe, the frame kwargs, and the
            function kwargs.
        universe : Any
            Already-loaded universe (or universe-like object) provided by the
            caller.
        function_kwargs : Mapping[str, Any] or None, optional
            Extra keyword arguments passed after the frame kwargs. A copy is
            stored; ``None`` or an empty mapping yields ``{}``.

        Raises
        ------
        TypeError
            If ``function`` is not callable.
        """
        if not callable(function):
            raise TypeError("function must be callable")
        extra_kwargs = function_kwargs if function_kwargs else {}
        self.function = function
        self.universe = universe
        self.function_kwargs = dict(extra_kwargs)
        self.results: dict[str, Any] | Any = {}

    def run(self, **frame_kwargs: Any) -> MDAFunctionAdapter:
        """Call the wrapped function once and keep its normalized output.

        Parameters
        ----------
        **frame_kwargs : Any
            Frame-selection keyword arguments, as produced by
            ``FrameSelection``.

        Returns
        -------
        MDAFunctionAdapter
            This adapter, with ``results`` populated.

        Raises
        ------
        MDAAnalysisJobError
            If any MDAnalysis backend or run-control keyword argument is
            passed directly to the adapter.
        """
        self._validate_frame_kwargs(frame_kwargs)
        outcome = self.function(self.universe, **frame_kwargs, **self.function_kwargs)
        self.results = self._normalize_result(outcome)
        return self

    @staticmethod
    def _validate_frame_kwargs(frame_kwargs: Mapping[str, Any]) -> None:
        """Reject backend/run-control names among the ``run()`` kwargs.

        Parameters
        ----------
        frame_kwargs : Mapping[str, Any]
            Keyword arguments that were passed to ``run()``.

        Raises
        ------
        MDAAnalysisJobError
            If any key belongs to ``MDA_BACKEND_RUN_CONTROL_KEYS``; the message
            lists the offending keys in sorted order.
        """
        offending = sorted(key for key in frame_kwargs if key in MDA_BACKEND_RUN_CONTROL_KEYS)
        if not offending:
            return
        raise MDAAnalysisJobError(
            "MDAFunctionAdapter.run() accepts only frame-selection kwargs; "
            f"received backend/run-control kwargs: {', '.join(offending)}"
        )

    @staticmethod
    def _normalize_result(value: Any) -> dict[str, Any]:
        """Turn the function's return value into a ``dict``.

        Parameters
        ----------
        value : Any
            Whatever the wrapped function returned.

        Returns
        -------
        dict[str, Any]
            A ``dict`` copy when ``value`` is a mapping; otherwise
            ``{"value": value}``.
        """
        return dict(value) if isinstance(value, Mapping) else {"value": value}
@dataclass
class MDAAnalysisJob:
    """Execute exactly one MDAnalysis ``AnalysisBase``-compatible job.

    A job receives either a ready analysis object or a zero-argument factory
    that constructs one. Execution merges ``FrameSelection.run_kwargs()`` with
    the backend policy's keyword arguments and calls ``analysis.run(**kwargs)``
    once.

    Attributes
    ----------
    name : str
        Non-empty job name used in result metadata and error messages.
    frame_selection : FrameSelection
        Source of the frame-selection keyword arguments.
    analysis : Any
        Ready analysis object exposing ``run(**kwargs)``. Mutually exclusive
        with ``analysis_factory``.
    analysis_factory : Callable[[], Any] or None
        Zero-argument callable that builds the analysis at run time. Mutually
        exclusive with ``analysis``.
    backend_policy : MDABackendPolicy
        Backend options merged into the ``run()`` keyword arguments.
    universe_policy : MDAUniversePolicy
        Provenance metadata copied onto the job result.
    result : MDAJobResult or None
        Result stored by the most recent ``run()`` that completed; ``None``
        until then. Not an ``__init__`` parameter.
    """

    name: str
    frame_selection: FrameSelection = field(default_factory=FrameSelection)
    analysis: Any = None
    analysis_factory: Callable[[], Any] | None = None
    backend_policy: MDABackendPolicy = field(default_factory=MDABackendPolicy)
    universe_policy: MDAUniversePolicy = field(default_factory=MDAUniversePolicy)
    result: MDAJobResult | None = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Validate constructor misuse before execution.

        Raises
        ------
        ValueError
            Raised when ``name`` is not a non-empty string, when both or
            neither of ``analysis`` and ``analysis_factory`` are supplied, or
            when a function adapter is combined with a non-default backend
            policy.
        TypeError
            Raised when a policy collaborator has the wrong type, when
            ``analysis`` lacks a callable ``run``, or when ``analysis_factory``
            is not callable.
        """
        if not isinstance(self.name, str) or not self.name:
            raise ValueError("name must be a non-empty string")
        self._validate_collaborators()
        has_analysis = self.analysis is not None
        has_factory = self.analysis_factory is not None
        # Equal flags mean "both given" or "neither given"; both are invalid.
        if has_analysis == has_factory:
            raise ValueError("Provide exactly one of analysis or analysis_factory")
        if has_analysis and not callable(getattr(self.analysis, "run", None)):
            raise TypeError("analysis must provide a callable run(**kwargs) method")
        if has_factory and not callable(self.analysis_factory):
            raise TypeError("analysis_factory must be callable")
        if isinstance(self.analysis, MDAFunctionAdapter):
            self._validate_function_backend_policy()

    @classmethod
    def from_function(
        cls,
        name: str,
        function: Callable[..., Any],
        universe: Any,
        *,
        frame_selection: FrameSelection | None = None,
        backend_policy: MDABackendPolicy | None = None,
        universe_policy: MDAUniversePolicy | None = None,
        function_kwargs: Mapping[str, Any] | None = None,
    ) -> MDAAnalysisJob:
        """Create a job from a function and already-loaded universe.

        Parameters
        ----------
        name : str
            Job name used in result metadata and error messages.
        function : Callable[..., Any]
            Function called once by ``MDAFunctionAdapter.run``.
        universe : Any
            Already-loaded universe or universe-like object.
        frame_selection : FrameSelection or None, optional
            Frame selection to forward to the function adapter. ``None``
            selects a default ``FrameSelection()``.
        backend_policy : MDABackendPolicy or None, optional
            Backend policy for this job. ``None`` selects a default
            ``MDABackendPolicy()``; any non-default policy is rejected.
        universe_policy : MDAUniversePolicy or None, optional
            Lightweight universe provenance policy. ``None`` selects a default
            ``MDAUniversePolicy()``.
        function_kwargs : Mapping[str, Any] or None, optional
            Additional keyword arguments forwarded to the function.

        Returns
        -------
        MDAAnalysisJob
            Job wrapping the function adapter.

        Raises
        ------
        ValueError
            Raised when ``backend_policy`` is not a default policy, or when
            the job constructor rejects ``name``.
        TypeError
            Raised when ``function`` is not callable or when a supplied policy
            object is not an instance of the expected policy class.
        """
        # Substitute defaults only for ``None``. A truthiness fallback
        # (``x or Default()``) would silently replace a falsy argument of the
        # wrong type and so skip the constructor's type validation.
        resolved_backend_policy = MDABackendPolicy() if backend_policy is None else backend_policy
        # Reject non-default policies before the adapter is built. Objects of
        # the wrong type skip this check so that the constructor reports them
        # with its TypeError instead of an AttributeError from ``is_default``.
        if (
            isinstance(resolved_backend_policy, MDABackendPolicy)
            and not resolved_backend_policy.is_default()
        ):
            raise ValueError("Function-adapter jobs do not accept backend policy kwargs")
        resolved_frame_selection = FrameSelection() if frame_selection is None else frame_selection
        resolved_universe_policy = (
            MDAUniversePolicy() if universe_policy is None else universe_policy
        )
        return cls(
            name=name,
            analysis=MDAFunctionAdapter(
                function,
                universe,
                function_kwargs=function_kwargs,
            ),
            frame_selection=resolved_frame_selection,
            backend_policy=resolved_backend_policy,
            universe_policy=resolved_universe_policy,
        )

    @property
    def results(self) -> Any:
        """Return completed job results.

        Returns
        -------
        Any
            Results object from the completed analysis.

        Raises
        ------
        MDAAnalysisJobError
            Raised when the job has not run yet.
        """
        if self.result is None:
            raise MDAAnalysisJobError(f"MDAAnalysisJob '{self.name}' has not run yet")
        return self.result.results

    def run(self) -> MDAJobResult:
        """Execute the wrapped analysis and store the completed job result.

        Returns
        -------
        MDAJobResult
            Completed job result reference; also stored on ``self.result``.

        Raises
        ------
        MDAAnalysisJobError
            Raised when factory construction, analysis execution, or result
            collection violates the runtime job contract.
        """
        analysis = self._resolve_analysis()
        run_kwargs = self._build_run_kwargs(analysis)
        try:
            run_return = analysis.run(**run_kwargs)
        except Exception as exc:
            # Boundary wrapper: any failure inside the analysis is re-raised as
            # the job-level error, with the original exception chained.
            raise MDAAnalysisJobError(
                f"MDAAnalysisJob '{self.name}' failed during analysis.run(): {exc}"
            ) from exc
        # Use the object returned by run() when it carries ``results``;
        # otherwise read the results from the analysis that was run.
        completed_analysis = run_return if hasattr(run_return, "results") else analysis
        if not hasattr(completed_analysis, "results"):
            raise MDAAnalysisJobError(
                f"MDAAnalysisJob '{self.name}' completed without a results attribute"
            )
        self.result = MDAJobResult(
            name=self.name,
            analysis=completed_analysis,
            results=completed_analysis.results,
            run_kwargs=run_kwargs,
            frame_selection=self.frame_selection,
            backend_policy=self.backend_policy,
            universe_policy=self.universe_policy,
        )
        return self.result

    def execute(self) -> MDAJobResult:
        """Execute the job; alias for ``run()``.

        Returns
        -------
        MDAJobResult
            Completed job result reference.
        """
        return self.run()

    def _resolve_analysis(self) -> Any:
        """Return the analysis object for this run.

        The stored ``analysis`` is returned when present; otherwise
        ``analysis_factory`` is called to build one.

        Returns
        -------
        Any
            Analysis object with a callable ``run`` method.

        Raises
        ------
        MDAAnalysisJobError
            Raised when the factory fails or returns an object without a
            callable ``run``.
        """
        if self.analysis is not None:
            return self.analysis
        try:
            analysis = self.analysis_factory()
        except Exception as exc:
            raise MDAAnalysisJobError(
                f"MDAAnalysisJob '{self.name}' failed to construct analysis: {exc}"
            ) from exc
        if not callable(getattr(analysis, "run", None)):
            raise MDAAnalysisJobError(
                f"MDAAnalysisJob '{self.name}' factory returned an object without run(**kwargs)"
            )
        return analysis

    def _build_run_kwargs(self, analysis: Any) -> dict[str, Any]:
        """Merge frame selection and backend policy keyword arguments.

        Parameters
        ----------
        analysis : Any
            Analysis object selected for this execution.

        Returns
        -------
        dict[str, Any]
            Keyword arguments forwarded to ``analysis.run``. Backend policy
            entries are merged last, so they win on a key collision.

        Raises
        ------
        MDAAnalysisJobError, ValueError
            Propagated from ``_validate_function_backend_policy`` when
            ``analysis`` is a function adapter and the backend policy is not
            the default.
        """
        # A factory-built adapter is only visible at run time, so the policy
        # has to be re-checked here as well as in ``__post_init__``.
        if isinstance(analysis, MDAFunctionAdapter):
            self._validate_function_backend_policy()
        return {
            **self.frame_selection.run_kwargs(),
            **self.backend_policy.run_kwargs(),
        }

    def _validate_collaborators(self) -> None:
        """Validate policy collaborators supplied to the job.

        Raises
        ------
        TypeError
            Raised when a collaborator does not implement the expected concrete
            extension-layer policy type.
        """
        if not isinstance(self.frame_selection, FrameSelection):
            raise TypeError("frame_selection must be a FrameSelection instance")
        if not isinstance(self.backend_policy, MDABackendPolicy):
            raise TypeError("backend_policy must be an MDABackendPolicy instance")
        if not isinstance(self.universe_policy, MDAUniversePolicy):
            raise TypeError("universe_policy must be an MDAUniversePolicy instance")

    def _validate_function_backend_policy(self) -> None:
        """Reject backend policy kwargs for function-adapter jobs.

        Raises
        ------
        MDAAnalysisJobError
            Raised when a factory-backed function adapter is paired with a
            non-default backend policy at execution time.
        ValueError
            Raised when a direct function-adapter job is constructed with a
            non-default backend policy.
        """
        if self.backend_policy.is_default():
            return
        message = "Function-adapter jobs do not accept backend policy kwargs"
        # No stored analysis means the adapter came from the factory, so the
        # mismatch is reported as a job error rather than a ValueError.
        if self.analysis is None:
            raise MDAAnalysisJobError(message)
        raise ValueError(message)