# Source code for polyzymd.analyses.mda.frame_selection

"""Frame-selection helpers for MDAnalysis extension-layer jobs."""

from __future__ import annotations

from collections.abc import Iterable, Sized
from dataclasses import dataclass, field
from operator import index
from typing import TYPE_CHECKING, Any

from polyzymd.analyses.mda.base import MDARunKwargs
from polyzymd.analyses.shared.window import resolve_trajectory_window

if TYPE_CHECKING:
    from polyzymd.analyses.shared.window import TrajectoryWindow


def _is_scalar_frame_selector(frames: Any) -> bool:
    """Return whether ``frames`` is a scalar selector value.

    Parameters
    ----------
    frames : Any
        Candidate explicit frame selector.

    Returns
    -------
    bool
        ``True`` when the selector is a scalar that MDAnalysis should not
        receive through the ``frames`` keyword.
    """

    return isinstance(frames, (bool, int, float, complex))


def _is_boolean_frame_value(frame: Any) -> bool:
    """Return whether a frame selector value is a boolean mask entry.

    Parameters
    ----------
    frame : Any
        Candidate element from an explicit ``frames`` selector.

    Returns
    -------
    bool
        ``True`` for Python booleans and NumPy-style boolean scalar values.
    """

    if isinstance(frame, bool):
        return True
    frame_type = type(frame)
    return (
        frame_type.__name__ == "bool_"
        and frame_type.__module__.split(".", maxsplit=1)[0] == "numpy"
    )


def _is_integer_frame_value(frame: Any) -> bool:
    """Report whether a selector element is usable as an integer frame index.

    Parameters
    ----------
    frame : Any
        Element taken from an explicit ``frames`` selector.

    Returns
    -------
    bool
        ``True`` when ``operator.index`` accepts ``frame`` and the value is not
        a boolean mask entry; ``False`` otherwise.
    """

    # Booleans are screened out first so they are never counted as indices.
    if not _is_boolean_frame_value(frame):
        try:
            index(frame)
        except TypeError:
            pass
        else:
            return True
    return False


def _coerce_frames(frames: Any) -> tuple[Any, ...]:
    """Check an explicit frame selector and return it as a tuple.

    Parameters
    ----------
    frames : Any
        Sequence of frame indices or boolean mask supplied by the caller.

    Returns
    -------
    tuple[Any, ...]
        The selector's elements, in iteration order, as an immutable tuple.

    Raises
    ------
    ValueError
        Raised when ``frames`` is a scalar, a ``str``/``bytes`` value, not both
        sized and iterable, fails to iterate, or contains no entries.
    """

    not_sequence_message = "frames must be a sized iterable sequence or boolean mask"

    if _is_scalar_frame_selector(frames):
        raise ValueError("frames must be a non-empty sequence or boolean mask, not a scalar")
    if isinstance(frames, (str, bytes)):
        raise ValueError("frames must be a non-empty sequence or boolean mask, not a string")
    if not (isinstance(frames, Sized) and isinstance(frames, Iterable)):
        raise ValueError(not_sequence_message)

    try:
        entries = tuple(frames)
    except TypeError as exc:
        # Iteration itself failed even though the ABC checks passed.
        raise ValueError(not_sequence_message) from exc
    if not entries:
        raise ValueError("frames must contain at least one entry")
    return entries


def _selected_count_from_frames(frames: tuple[Any, ...], n_frames_total: int | None) -> int:
    """Count the frames chosen by an explicit index list or boolean mask.

    Parameters
    ----------
    frames : tuple[Any, ...]
        Frozen selector: either integer frame indices or boolean mask values.
    n_frames_total : int | None
        Trajectory length, if known; used only to check a mask's length.

    Returns
    -------
    int
        ``len(frames)`` for integer indices, or the number of truthy entries
        for a boolean mask.

    Raises
    ------
    ValueError
        Raised when the entries are neither all integer indices nor all
        boolean values, when a mask length differs from a known trajectory
        length, or when a mask selects no frames.
    """

    every_entry_is_mask_value = all(map(_is_boolean_frame_value, frames))
    every_entry_is_index = all(map(_is_integer_frame_value, frames))

    if every_entry_is_index:
        return len(frames)
    if not every_entry_is_mask_value:
        raise ValueError("frames must contain only integer frame indices or boolean mask values")

    # From here on ``frames`` is a boolean mask.
    mask_length = len(frames)
    if n_frames_total is not None and mask_length != n_frames_total:
        raise ValueError(
            "Boolean frame mask length "
            f"{mask_length} does not match trajectory length {n_frames_total}"
        )

    selected = sum(map(bool, frames))
    if selected == 0:
        raise ValueError("frames must select at least one frame")
    return selected


def _selected_count_from_slice(
    *,
    start: int | None,
    stop: int | None,
    step: int | None,
    n_frames_total: int | None,
) -> int | None:
    """Return the selected frame count for start/stop/step selectors.

    Parameters
    ----------
    start : int | None
        Inclusive start frame.
    stop : int | None
        Exclusive stop frame.
    step : int | None
        Frame stride.
    n_frames_total : int | None
        Total trajectory length when known.

    Returns
    -------
    int | None
        Number of selected frames when it can be resolved, otherwise ``None``.
    """

    resolved_start = 0 if start is None else start
    resolved_stop = n_frames_total if stop is None else stop
    if resolved_stop is None:
        return None
    resolved_step = 1 if step is None else step
    return len(range(resolved_start, resolved_stop, resolved_step))


@dataclass(frozen=True)
class FrameSelection:
    """Validated MDAnalysis ``run()`` frame-selection arguments.

    ``FrameSelection`` is the import-light bridge between PolyzyMD's
    equilibration/window semantics and MDAnalysis ``AnalysisBase.run`` keyword
    arguments. Explicit ``frames`` selectors are mutually exclusive with
    ``start``, ``stop``, and ``step`` because MDAnalysis treats them as separate
    selection modes.

    Parameters
    ----------
    start : int | None, optional
        Inclusive start frame passed to ``run(start=...)``.
    stop : int | None, optional
        Exclusive stop frame passed to ``run(stop=...)``.
    step : int | None, optional
        Frame stride passed to ``run(step=...)``.
    frames : Any, optional
        Explicit frame indices or boolean mask passed to ``run(frames=...)``.
    equilibration : str | None, optional
        Original equilibration setting used to derive this selection.
    equilibration_start : int | None, optional
        Start frame implied by equilibration alone.
    equilibration_ps : float | None, optional
        Equilibration time converted to picoseconds.
    timestep_ps : float | None, optional
        Trajectory timestep in picoseconds.
    first_frame_time_ps : float | None, optional
        Absolute MDAnalysis timestamp of loaded frame 0 in picoseconds, when
        available.
    selected_start_time_ps : float | None, optional
        Timestamp of the selected start frame in the active time reference.
    equilibration_time_reference : str | None, optional
        Time reference used to interpret ``equilibration``.
    n_frames_total : int | None, optional
        Total trajectory frame count when known.
    warning_message : str | None, optional
        Non-fatal warning from equilibration/window validation.

    Attributes
    ----------
    n_frames_selected : int | None
        Number of selected frames, filled in during validation. It is not an
        ``__init__`` argument and stays ``None`` in slice mode when neither
        ``stop`` nor ``n_frames_total`` is known.
    """

    start: int | None = None
    stop: int | None = None
    step: int | None = None
    frames: Any = None
    equilibration: str | None = None
    equilibration_start: int | None = None
    equilibration_ps: float | None = None
    timestep_ps: float | None = None
    first_frame_time_ps: float | None = None
    selected_start_time_ps: float | None = None
    equilibration_time_reference: str | None = None
    n_frames_total: int | None = None
    warning_message: str | None = None
    n_frames_selected: int | None = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Validate mutually exclusive MDAnalysis frame-selection modes.

        Raises
        ------
        ValueError
            Raised when the selector would produce invalid MDAnalysis ``run``
            keyword arguments or an empty known frame selection.
        """

        self._validate_total_frame_count()
        if self.frames is not None:
            self._validate_explicit_frames_mode()
            return
        self._validate_slice_mode()

    def run_kwargs(self) -> MDARunKwargs:
        """Return keyword arguments for ``MDAnalysis`` ``AnalysisBase.run``.

        Returns
        -------
        MDARunKwargs
            ``frames`` alone when explicit frames are set, otherwise the
            non-``None`` ``start``, ``stop``, and ``step`` values.
        """

        if self.frames is not None:
            # Hand out a fresh list so callers never share the stored tuple.
            return MDARunKwargs(frames=list(self.frames))
        kwargs = MDARunKwargs()
        if self.start is not None:
            kwargs["start"] = self.start
        if self.stop is not None:
            kwargs["stop"] = self.stop
        if self.step is not None:
            kwargs["step"] = self.step
        return kwargs

    @classmethod
    def from_trajectory_window(cls, window: TrajectoryWindow) -> FrameSelection:
        """Build a frame selection from a resolved PolyzyMD trajectory window.

        Parameters
        ----------
        window : TrajectoryWindow
            Existing validated shared trajectory window.

        Returns
        -------
        FrameSelection
            Selection that forwards the same ``start``, ``stop``, and ``step``
            values to MDAnalysis while preserving window provenance.
        """

        return cls(
            start=window.start,
            stop=window.stop,
            step=window.step,
            # Read defensively: a missing attribute falls back to ``None``.
            equilibration=getattr(window, "equilibration", None),
            equilibration_start=window.equilibration_start,
            equilibration_ps=window.equilibration_ps,
            timestep_ps=window.timestep_ps,
            first_frame_time_ps=window.first_frame_time_ps,
            selected_start_time_ps=window.selected_start_time_ps,
            equilibration_time_reference=window.equilibration_time_reference,
            n_frames_total=window.n_frames_total,
            warning_message=window.warning_message,
        )

    @classmethod
    def from_equilibration(
        cls,
        *,
        equilibration: str,
        n_frames_total: int,
        timestep_ps: float,
        start: int | None = None,
        stop: int | None = None,
        step: int = 1,
        min_frames: int = 1,
        first_frame_time_ps: float | None = None,
    ) -> FrameSelection:
        """Resolve PolyzyMD equilibration/window settings to a frame selection.

        Finite first-frame timestamps make equilibration absolute in MDAnalysis
        trajectory time; missing timestamps keep the stale
        loaded-frame-relative origin.

        Parameters
        ----------
        equilibration : str
            Equilibration time string such as ``"10ns"``.
        n_frames_total : int
            Total number of trajectory frames.
        timestep_ps : float
            Trajectory timestep in picoseconds.
        start : int | None, optional
            Absolute start frame. When ``None``, equilibration determines the
            start frame.
        stop : int | None, optional
            Absolute exclusive stop frame. When ``None``, the trajectory end is
            used.
        step : int, optional
            Frame stride, by default 1.
        min_frames : int, optional
            Minimum required number of selected frames, by default 1.
        first_frame_time_ps : float | None, optional
            Absolute MDAnalysis timestamp of loaded frame 0 in picoseconds.

        Returns
        -------
        FrameSelection
            Validated selection suitable for MDAnalysis ``run()``.
        """

        window = resolve_trajectory_window(
            equilibration=equilibration,
            n_frames_total=n_frames_total,
            timestep_ps=timestep_ps,
            start=start,
            stop=stop,
            step=step,
            min_frames=min_frames,
            first_frame_time_ps=first_frame_time_ps,
        )
        selection = cls.from_trajectory_window(window)
        # Record the caller's original string regardless of what the window
        # exposes; the dataclass is frozen, so bypass its ``__setattr__``.
        object.__setattr__(selection, "equilibration", equilibration)
        return selection

    def _validate_total_frame_count(self) -> None:
        """Validate the optional total frame count.

        Raises
        ------
        ValueError
            Raised when the known trajectory length is invalid.
        """

        if self.n_frames_total is not None and self.n_frames_total < 1:
            raise ValueError("n_frames_total must be >= 1 when provided")

    def _validate_explicit_frames_mode(self) -> None:
        """Validate explicit ``frames`` selectors.

        Raises
        ------
        ValueError
            Raised when ``frames`` is mixed with slice arguments or selects no
            frames.
        """

        if self.start is not None or self.stop is not None or self.step is not None:
            raise ValueError("frames cannot be combined with start, stop, or step")
        frozen_frames = _coerce_frames(self.frames)
        n_frames_selected = _selected_count_from_frames(frozen_frames, self.n_frames_total)
        # Frozen dataclass: store the normalised tuple and count via
        # ``object.__setattr__``.
        object.__setattr__(self, "frames", frozen_frames)
        object.__setattr__(self, "n_frames_selected", n_frames_selected)

    def _validate_slice_mode(self) -> None:
        """Validate ``start``/``stop``/``step`` selectors.

        Raises
        ------
        ValueError
            Raised when the slice arguments are invalid or select no known
            frames.
        """

        if self.step is not None and self.step < 1:
            raise ValueError(f"step must be >= 1, got {self.step}")
        if self.start is not None and self.start < 0:
            raise ValueError(f"start must be >= 0, got {self.start}")
        if self.stop is not None and self.start is not None and self.stop <= self.start:
            raise ValueError(f"stop={self.stop} must be greater than start={self.start}")
        if self.n_frames_total is not None:
            if self.start is not None and self.start >= self.n_frames_total:
                raise ValueError(
                    f"start={self.start} is outside the trajectory range [0, {self.n_frames_total})"
                )
            if self.stop is not None and self.stop > self.n_frames_total:
                raise ValueError(
                    f"stop={self.stop} exceeds trajectory length {self.n_frames_total}"
                )
        n_frames_selected = _selected_count_from_slice(
            start=self.start,
            stop=self.stop,
            step=self.step,
            n_frames_total=self.n_frames_total,
        )
        # ``None`` (count unknown) is allowed; only a known-empty selection fails.
        if n_frames_selected == 0:
            raise ValueError("Frame selection must select at least one frame")
        object.__setattr__(self, "n_frames_selected", n_frames_selected)