Source code for polyzymd.analyses.mda.frame_selection

"""Frame-selection helpers for MDAnalysis extension-layer jobs."""

from __future__ import annotations

from collections.abc import Iterable, Sized
from dataclasses import dataclass, field
from operator import index
from typing import TYPE_CHECKING, Any

from polyzymd.analyses.mda.base import MDARunKwargs
from polyzymd.analyses.shared.window import resolve_trajectory_window

if TYPE_CHECKING:
    from polyzymd.analyses.shared.window import TrajectoryWindow


def _is_scalar_frame_selector(frames: Any) -> bool:
    """Return whether ``frames`` is a scalar selector value.

    Parameters
    ----------
    frames : Any
        Candidate explicit frame selector.

    Returns
    -------
    bool
        ``True`` when the selector is a scalar that MDAnalysis should not
        receive through the ``frames`` keyword.
    """

    return isinstance(frames, (bool, int, float, complex))


def _normalize_scalar_value(value: Any) -> Any:
    """Return NumPy-like scalar values as Python scalars.

    The helper intentionally avoids importing NumPy so lightweight analysis
    modules can normalize scalar provenance in environments without the full
    simulation stack.

    Parameters
    ----------
    value : Any
        Candidate value to normalize.

    Returns
    -------
    Any
        Python ``bool``, ``int``, or ``float`` when ``value`` is a NumPy-like
        scalar value, otherwise the original value.
    """

    shape = getattr(value, "shape", None)
    if shape not in (None, ()):  # Avoid turning array selectors into scalar selectors
        return value

    item = getattr(value, "item", None)
    if callable(item):
        try:
            native_value = item()
        except (TypeError, ValueError):
            native_value = value
        if isinstance(native_value, (bool, int, float)):
            return native_value

    return value


def _is_boolean_frame_value(frame: Any) -> bool:
    """Return whether a frame selector value is a boolean mask entry.

    Parameters
    ----------
    frame : Any
        Candidate element from an explicit ``frames`` selector.

    Returns
    -------
    bool
        ``True`` for Python booleans and NumPy-style boolean scalar values.
    """

    if isinstance(frame, bool):
        return True
    try:
        import numpy as np
    except ImportError:
        return False
    return isinstance(frame, np.bool_)


def _is_integer_frame_value(frame: Any) -> bool:
    """Report whether a selector element is usable as an integer frame index.

    Boolean mask entries are excluded on purpose, even though Python ``bool``
    is an ``int`` subclass, so masks and index lists are never confused.

    Parameters
    ----------
    frame : Any
        Candidate element from an explicit ``frames`` selector.

    Returns
    -------
    bool
        ``True`` when ``operator.index(frame)`` succeeds and ``frame`` is not
        a boolean mask value, ``False`` otherwise.
    """

    if not _is_boolean_frame_value(frame):
        try:
            index(frame)
        except TypeError:
            pass
        else:
            return True
    return False


def _coerce_frames(frames: Any) -> tuple[Any, ...]:
    """Validate and freeze an explicit frame selector.

    Only the container is checked here. Individual entries are not
    type-checked by this helper.

    Parameters
    ----------
    frames : Any
        Candidate frame sequence or boolean mask.

    Returns
    -------
    tuple[Any, ...]
        Immutable, non-empty frame selector suitable for forwarding to
        MDAnalysis.

    Raises
    ------
    ValueError
        Raised when ``frames`` is empty, a numeric scalar, string-like
        (``str``, ``bytes``, or ``bytearray``), not a sized iterable, or
        fails to iterate.
    """

    if _is_scalar_frame_selector(frames):
        raise ValueError("frames must be a non-empty sequence or boolean mask, not a scalar")
    # ``bytearray`` iterates as small ints exactly like ``bytes``. Without this
    # check, a raw byte buffer would silently be accepted as frame indices.
    if isinstance(frames, (str, bytes, bytearray)):
        raise ValueError("frames must be a non-empty sequence or boolean mask, not a string")
    if not isinstance(frames, Sized) or not isinstance(frames, Iterable):
        raise ValueError("frames must be a sized iterable sequence or boolean mask")

    try:
        frozen_frames = tuple(frames)
    except TypeError as exc:
        # e.g. 0-d arrays advertise ``__len__``/``__iter__`` but refuse iteration.
        raise ValueError("frames must be a sized iterable sequence or boolean mask") from exc
    if not frozen_frames:
        raise ValueError("frames must contain at least one entry")
    return frozen_frames


def _normalize_frame_selector_values(frames: Any) -> list[int | bool]:
    """Convert every explicit selector element into a plain Python scalar.

    Parameters
    ----------
    frames : Any
        Explicit frame index sequence or boolean mask.

    Returns
    -------
    list of int or bool
        Boolean-like elements become Python ``bool``. Every other element is
        converted with ``operator.index`` to a Python ``int``.

    Raises
    ------
    ValueError
        Raised when an element is neither boolean-like nor integer-like.
    """

    def _to_python_scalar(entry: Any) -> int | bool:
        # Booleans are checked first so mask entries never degrade to 0/1.
        if _is_boolean_frame_value(entry):
            return bool(entry)
        try:
            return index(entry)
        except TypeError as exc:
            raise ValueError(
                "frames must contain only integer frame indices or boolean mask values"
            ) from exc

    return [_to_python_scalar(entry) for entry in frames]


def _selected_count_from_frames(frames: tuple[Any, ...], n_frames_total: int | None) -> int:
    """Count how many frames an explicit selector picks.

    Parameters
    ----------
    frames : tuple[Any, ...]
        Frozen frame sequence or boolean mask.
    n_frames_total : int | None
        Total trajectory length when known.

    Returns
    -------
    int
        The number of entries for integer indices (duplicates included), or
        the number of truthy entries for a boolean mask.

    Raises
    ------
    ValueError
        Raised when the entries mix indices and mask values, or when an index
        falls outside a known trajectory range. Also raised when a boolean
        mask length disagrees with a known trajectory length or the mask
        selects nothing.
    """

    if all(_is_integer_frame_value(entry) for entry in frames):
        # Index mode: bounds can only be checked when the length is known.
        if n_frames_total is not None:
            for entry in frames:
                position = index(entry)
                if not 0 <= position < n_frames_total:
                    raise ValueError(
                        f"Explicit frame index {position} is outside the trajectory range "
                        f"[0, {n_frames_total})"
                    )
        return len(frames)

    if not all(_is_boolean_frame_value(entry) for entry in frames):
        raise ValueError("frames must contain only integer frame indices or boolean mask values")

    # Mask mode: a known trajectory length must match the mask one-to-one.
    mask_length = len(frames)
    if n_frames_total is not None and mask_length != n_frames_total:
        raise ValueError(
            "Boolean frame mask length "
            f"{mask_length} does not match trajectory length {n_frames_total}"
        )

    selected = sum(map(bool, frames))
    if not selected:
        raise ValueError("frames must select at least one frame")
    return selected


def _selected_count_from_slice(
    *,
    start: int | None,
    stop: int | None,
    step: int | None,
    n_frames_total: int | None,
) -> int | None:
    """Return the selected frame count for start/stop/step selectors.

    Parameters
    ----------
    start : int | None
        Inclusive start frame.
    stop : int | None
        Exclusive stop frame.
    step : int | None
        Frame stride.
    n_frames_total : int | None
        Total trajectory length when known.

    Returns
    -------
    int | None
        Number of selected frames when it can be resolved, otherwise ``None``.
    """

    resolved_start = 0 if start is None else start
    resolved_stop = n_frames_total if stop is None else stop
    if resolved_stop is None:
        return None
    resolved_step = 1 if step is None else step
    return len(range(resolved_start, resolved_stop, resolved_step))


@dataclass(frozen=True)
class FrameSelection:
    """Validated MDAnalysis ``run()`` frame-selection arguments.

    ``FrameSelection`` is the import-light bridge between PolyzyMD's
    equilibration/window semantics and MDAnalysis ``AnalysisBase.run``
    keyword arguments. Explicit ``frames`` selectors are mutually exclusive
    with ``start``, ``stop``, and ``step`` because MDAnalysis treats them as
    separate selection modes.

    Parameters
    ----------
    start : int | None, optional
        Inclusive start frame passed to ``run(start=...)``.
    stop : int | None, optional
        Exclusive stop frame passed to ``run(stop=...)``.
    step : int | None, optional
        Frame stride passed to ``run(step=...)``.
    frames : Any, optional
        Explicit frame indices or boolean mask passed to ``run(frames=...)``.
        Stored as a tuple of Python ``int``/``bool`` values after validation.
    equilibration : str | None, optional
        Original equilibration setting used to derive this selection.
    equilibration_start : int | None, optional
        Start frame implied by equilibration alone.
    equilibration_ps : float | None, optional
        Equilibration time converted to picoseconds.
    timestep_ps : float | None, optional
        Trajectory timestep in picoseconds.
    first_frame_time_ps : float | None, optional
        Absolute MDAnalysis timestamp of loaded frame 0 in picoseconds, when
        available.
    selected_start_time_ps : float | None, optional
        Timestamp of the selected start frame in the active time reference.
    equilibration_time_reference : str | None, optional
        Time reference used to interpret ``equilibration``.
    n_frames_total : int | None, optional
        Total trajectory frame count when known.
    warning_message : str | None, optional
        Non-fatal warning from equilibration/window validation.

    Attributes
    ----------
    n_frames_selected : int | None
        Number of frames picked by the selection, computed during validation.
        ``None`` in slice mode when neither ``stop`` nor ``n_frames_total`` is
        known.
    """

    start: int | None = None
    stop: int | None = None
    step: int | None = None
    frames: Any = None
    equilibration: str | None = None
    equilibration_start: int | None = None
    equilibration_ps: float | None = None
    timestep_ps: float | None = None
    first_frame_time_ps: float | None = None
    selected_start_time_ps: float | None = None
    equilibration_time_reference: str | None = None
    n_frames_total: int | None = None
    warning_message: str | None = None
    # Derived in ``__post_init__``; not accepted by the generated ``__init__``.
    n_frames_selected: int | None = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Validate mutually exclusive MDAnalysis frame-selection modes.

        Raises
        ------
        ValueError
            Raised when the selector would produce invalid MDAnalysis ``run``
            keyword arguments or an empty known frame selection. This includes
            non-integer (or boolean) ``start``, ``stop``, ``step``, or
            ``n_frames_total`` values.
        """
        self._normalize_scalar_fields()
        self._validate_total_frame_count()
        if self.frames is not None:
            self._validate_explicit_frames_mode()
            return
        self._validate_slice_mode()

    def run_kwargs(self) -> MDARunKwargs:
        """Return keyword arguments for ``MDAnalysis`` ``AnalysisBase.run``.

        Returns
        -------
        MDARunKwargs
            ``frames`` alone (as a list) when explicit frames are set,
            otherwise the non-``None`` ``start``, ``stop``, and ``step``
            values.
        """
        if self.frames is not None:
            return MDARunKwargs(frames=_normalize_frame_selector_values(self.frames))
        kwargs = MDARunKwargs()
        if self.start is not None:
            kwargs["start"] = _normalize_scalar_value(self.start)
        if self.stop is not None:
            kwargs["stop"] = _normalize_scalar_value(self.stop)
        if self.step is not None:
            kwargs["step"] = _normalize_scalar_value(self.step)
        return kwargs

    @classmethod
    def from_trajectory_window(cls, window: TrajectoryWindow) -> FrameSelection:
        """Build a frame selection from a resolved PolyzyMD trajectory window.

        Parameters
        ----------
        window : TrajectoryWindow
            Existing validated shared trajectory window.

        Returns
        -------
        FrameSelection
            Selection that forwards the same ``start``, ``stop``, and ``step``
            values to MDAnalysis while preserving window provenance.
        """
        return cls(
            start=window.start,
            stop=window.stop,
            step=window.step,
            equilibration=getattr(window, "equilibration", None),
            equilibration_start=window.equilibration_start,
            equilibration_ps=window.equilibration_ps,
            timestep_ps=window.timestep_ps,
            first_frame_time_ps=window.first_frame_time_ps,
            selected_start_time_ps=window.selected_start_time_ps,
            equilibration_time_reference=window.equilibration_time_reference,
            n_frames_total=window.n_frames_total,
            warning_message=window.warning_message,
        )

    @classmethod
    def from_equilibration(
        cls,
        *,
        equilibration: str,
        n_frames_total: int,
        timestep_ps: float,
        start: int | None = None,
        stop: int | None = None,
        step: int = 1,
        min_frames: int = 1,
        first_frame_time_ps: float | None = None,
    ) -> FrameSelection:
        """Resolve PolyzyMD equilibration/window settings to a frame selection.

        When a finite first-frame timestamp is given, equilibration is
        interpreted in absolute MDAnalysis trajectory time. When it is
        missing, the loaded-frame-relative origin is used instead.

        Parameters
        ----------
        equilibration : str
            Equilibration time string such as ``"10ns"``.
        n_frames_total : int
            Total number of trajectory frames.
        timestep_ps : float
            Trajectory timestep in picoseconds.
        start : int | None, optional
            Absolute start frame. When ``None``, equilibration determines the
            start frame.
        stop : int | None, optional
            Absolute exclusive stop frame. When ``None``, the trajectory end is
            used.
        step : int, optional
            Frame stride, by default 1.
        min_frames : int, optional
            Minimum required number of selected frames, by default 1.
        first_frame_time_ps : float | None, optional
            Absolute MDAnalysis timestamp of loaded frame 0 in picoseconds.

        Returns
        -------
        FrameSelection
            Validated selection suitable for MDAnalysis ``run()``.
        """
        window = resolve_trajectory_window(
            equilibration=equilibration,
            n_frames_total=n_frames_total,
            timestep_ps=timestep_ps,
            start=start,
            stop=stop,
            step=step,
            min_frames=min_frames,
            first_frame_time_ps=first_frame_time_ps,
        )
        selection = cls.from_trajectory_window(window)
        # ``from_trajectory_window`` falls back to ``None`` when the window has
        # no ``equilibration`` attribute, so record the caller's string here.
        object.__setattr__(selection, "equilibration", equilibration)
        return selection

    def _normalize_scalar_fields(self) -> None:
        """Normalize NumPy-like scalar selector and provenance fields.

        Returns
        -------
        None
            The dataclass fields are updated in place through frozen
            dataclass attribute assignment.
        """
        for field_name in (
            "start",
            "stop",
            "step",
            "equilibration_start",
            "equilibration_ps",
            "timestep_ps",
            "first_frame_time_ps",
            "selected_start_time_ps",
            "n_frames_total",
        ):
            object.__setattr__(
                self,
                field_name,
                _normalize_scalar_value(getattr(self, field_name)),
            )

    def _require_integer_field(self, field_name: str) -> None:
        """Require an optional frame field to hold an integer, not a bool.

        Accepted values are stored back as Python ``int`` via
        ``operator.index``. ``None`` is left untouched.

        Parameters
        ----------
        field_name : str
            Name of the dataclass field to check.

        Raises
        ------
        ValueError
            Raised when the field is set to a non-integer (e.g. ``1.5``) or a
            boolean value. Such values would otherwise fail later with a
            ``TypeError`` from ``range`` or reach MDAnalysis unvalidated.
        """
        value = getattr(self, field_name)
        if value is None:
            return
        if not _is_integer_frame_value(value):
            raise ValueError(f"{field_name} must be an integer, got {value!r}")
        object.__setattr__(self, field_name, index(value))

    def _validate_total_frame_count(self) -> None:
        """Validate the optional total frame count.

        Raises
        ------
        ValueError
            Raised when the known trajectory length is not an integer or is
            smaller than 1.
        """
        if self.n_frames_total is None:
            return
        self._require_integer_field("n_frames_total")
        if self.n_frames_total < 1:
            raise ValueError("n_frames_total must be >= 1 when provided")

    def _validate_explicit_frames_mode(self) -> None:
        """Validate explicit ``frames`` selectors.

        Raises
        ------
        ValueError
            Raised when ``frames`` is mixed with slice arguments or selects no
            frames.
        """
        if self.start is not None or self.stop is not None or self.step is not None:
            raise ValueError("frames cannot be combined with start, stop, or step")
        frozen_frames = _coerce_frames(self.frames)
        n_frames_selected = _selected_count_from_frames(frozen_frames, self.n_frames_total)
        object.__setattr__(self, "frames", tuple(_normalize_frame_selector_values(frozen_frames)))
        object.__setattr__(self, "n_frames_selected", n_frames_selected)

    def _validate_slice_mode(self) -> None:
        """Validate ``start``/``stop``/``step`` selectors.

        Raises
        ------
        ValueError
            Raised when the slice arguments are non-integer, out of range, or
            select no known frames.
        """
        for field_name in ("start", "stop", "step"):
            self._require_integer_field(field_name)

        if self.step is not None and self.step < 1:
            raise ValueError(f"step must be >= 1, got {self.step}")
        if self.start is not None and self.start < 0:
            raise ValueError(f"start must be >= 0, got {self.start}")
        if self.stop is not None and self.start is not None and self.stop <= self.start:
            raise ValueError(f"stop={self.stop} must be greater than start={self.start}")
        if self.n_frames_total is not None:
            if self.start is not None and self.start >= self.n_frames_total:
                raise ValueError(
                    f"start={self.start} is outside the trajectory range [0, {self.n_frames_total})"
                )
            if self.stop is not None and self.stop > self.n_frames_total:
                raise ValueError(
                    f"stop={self.stop} exceeds trajectory length {self.n_frames_total}"
                )

        n_frames_selected = _selected_count_from_slice(
            start=self.start,
            stop=self.stop,
            step=self.step,
            n_frames_total=self.n_frames_total,
        )
        if n_frames_selected == 0:
            raise ValueError("Frame selection must select at least one frame")
        object.__setattr__(self, "n_frames_selected", n_frames_selected)