"""Shared SASA computation utilities.
This module provides reusable helpers for SASA analyses that need
selection validation, atom-level and residue-level SASA traces, and
raw artifact persistence.
"""
from __future__ import annotations
import json
import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from numpy.typing import NDArray
from polyzymd.analyses.shared.diagnostics import warn_if_multi_chain_selection
# Module logger; used for chunk-progress info and degenerate-selection warnings.
LOGGER = logging.getLogger(__name__)
# Conversion factor from nm² (the units MDTraj's shrake_rupley returns here)
# to Å²: 1 nm² = 100 Å².
NM2_TO_A2 = 100.0
[docs]
@dataclass(frozen=True)
class SASAComputationResult:
    """Container for raw SASA computation outputs.

    Instances are returned by ``compute_sasa`` and are written/read by
    ``save_sasa_artifacts`` / ``load_sasa_artifacts``.

    ``frozen=True`` only prevents rebinding attributes; the NumPy arrays
    themselves remain mutable in place.

    NOTE(review): the dataclass-generated ``__eq__`` compares the array fields
    element-wise, so ``==`` between two distinct instances with multi-element
    arrays raises ``ValueError`` (ambiguous truth value); compare fields with
    ``np.array_equal`` instead.

    Parameters
    ----------
    atom_sasa_a2 : NDArray[np.float64]
        Per-atom SASA trace in Å², shape ``(n_frames, n_target_atoms)``.
    residue_sasa_a2 : NDArray[np.float64]
        Per-residue SASA trace in Å², shape ``(n_frames, n_target_residues)``.
    total_sasa_a2 : NDArray[np.float64]
        Total target SASA in Å², shape ``(n_frames,)``.
    frames : NDArray[np.int64]
        0-indexed frame indices used for analysis (after striding).
    time_ns : NDArray[np.float64]
        Time axis in ns corresponding to ``frames``.
    target_atom_indices : NDArray[np.int64]
        Universe-global atom indices for target atoms.
    context_atom_indices : NDArray[np.int64]
        Universe-global atom indices for context atoms.
    residue_keys : list[str]
        Residue identity keys in ``chainID:resid:resname`` format, one per
        column of ``residue_sasa_a2``.
    residue_chainids : list[str]
        Chain IDs for each residue key.
    residue_resids : list[int]
        Residue IDs for each residue key.
    residue_resnames : list[str]
        Residue names for each residue key.
    """

    # Per-frame SASA traces (Å²).
    atom_sasa_a2: NDArray[np.float64]
    residue_sasa_a2: NDArray[np.float64]
    total_sasa_a2: NDArray[np.float64]
    # Frame axis: trajectory frame indices and their times.
    frames: NDArray[np.int64]
    time_ns: NDArray[np.float64]
    # Universe-global atom indices of the two selections.
    target_atom_indices: NDArray[np.int64]
    context_atom_indices: NDArray[np.int64]
    # Residue identity, parallel to the columns of residue_sasa_a2.
    residue_keys: list[str]
    residue_chainids: list[str]
    residue_resids: list[int]
    residue_resnames: list[str]
[docs]
def resolve_selection_indices(
    universe: Any,
    selection: str,
    *,
    role: str,
    run_label: str,
) -> tuple[Any, NDArray[np.int64]]:
    """Select atoms from ``universe`` and return them with their global indices.

    The selected group is passed to ``warn_if_multi_chain_selection``. For the
    ``"context"`` role, the warning context carries an extra note because
    spanning several chains is usually deliberate there.

    Parameters
    ----------
    universe : Any
        MDAnalysis universe.
    selection : str
        MDAnalysis selection string.
    role : str
        Selection role (e.g. ``"target"`` or ``"context"``) used in the
        warning context.
    run_label : str
        Human-readable run label.

    Returns
    -------
    tuple[Any, NDArray[np.int64]]
        The selected atom group and its universe-global atom indices as int64.
    """
    selected = universe.select_atoms(selection)

    suffix = (
        " (multi-chain context selections are often intentional for SASA)"
        if role == "context"
        else ""
    )
    warn_if_multi_chain_selection(
        selected,
        selection,
        context=f"for SASA {role} selection in run '{run_label}'{suffix}",
    )

    global_indices = np.asarray(selected.indices, dtype=np.int64)
    return selected, global_indices
[docs]
def validate_target_subset(
    target_indices: NDArray[np.int64],
    context_indices: NDArray[np.int64],
    *,
    run_label: str,
    target_selection: str,
    context_selection: str,
) -> None:
    """Ensure every target atom index also occurs among the context indices.

    An empty target is always accepted. Duplicate target indices that are
    absent from the context are each counted in the error message.

    Parameters
    ----------
    target_indices : NDArray[np.int64]
        Universe-global indices of the target atoms.
    context_indices : NDArray[np.int64]
        Universe-global indices of the context atoms.
    run_label : str
        Run label used in the error message.
    target_selection : str
        Target selection string (reported in the error).
    context_selection : str
        Context selection string (reported in the error).

    Raises
    ------
    ValueError
        If any target atom is absent from the context selection.
    """
    if target_indices.size == 0:
        return

    absent_mask = np.isin(target_indices, context_indices, invert=True)
    n_absent = int(np.count_nonzero(absent_mask))
    if n_absent == 0:
        return

    raise ValueError(
        f"SASA run '{run_label}' requires target_selection atoms to be a subset of "
        f"context_selection atoms. Missing {n_absent} target atoms in context. "
        f"target_selection={target_selection!r}, context_selection={context_selection!r}"
    )
[docs]
def compute_sasa(
    universe: Any,
    *,
    run_label: str,
    target_selection: str,
    context_selection: str,
    probe_radius_nm: float,
    n_sphere_points: int,
    start_frame: int,
    stop_frame: int,
    timestep_ps: float,
    chunk_size: int = 100,
    stride: int = 1,
) -> SASAComputationResult:
    """Compute target SASA over a selected context.

    Shrake-Rupley is run on all *context* atoms, so context atoms occlude
    one another, but only *target* atoms (which must be a subset of the
    context) are reported. Coordinates are copied from ``universe`` one chunk
    of at most ``chunk_size`` analyzed frames at a time.

    Parameters
    ----------
    universe : Any
        MDAnalysis universe.
    run_label : str
        Human-readable run label used in log and error messages.
    target_selection : str
        Selection of atoms whose SASA is reported.
    context_selection : str
        Selection of atoms considered during SASA computation.
    probe_radius_nm : float
        Probe radius in nm.
    n_sphere_points : int
        Number of sphere points for Shrake-Rupley.
    start_frame : int
        First frame index (inclusive).
    stop_frame : int
        Last frame index (exclusive).
    timestep_ps : float
        Time per trajectory frame in ps
        (``time_ns = frame_index * timestep_ps / 1000``).
    chunk_size : int, optional
        Number of analyzed frames processed per Shrake-Rupley chunk,
        by default 100.
    stride : int, optional
        Frame stride applied before chunking (1 = analyze every frame),
        by default 1.

    Returns
    -------
    SASAComputationResult
        Raw atom-level, residue-level, and total SASA traces in Å². If the
        target or context selection matches no atoms, the per-atom and total
        traces are NaN-filled and no residues are reported.

    Raises
    ------
    ValueError
        If frame bounds or stride/chunk parameters are invalid, or if any
        target atom is missing from the context selection.

    Notes
    -----
    Reading positions seeks ``universe.trajectory``. On the non-degenerate
    path, the trajectory is left at the last analyzed frame.
    """
    import mdtraj as md

    if chunk_size <= 0:
        raise ValueError("chunk_size must be >= 1")
    if stride <= 0:
        raise ValueError("stride must be >= 1")

    trajectory_length = len(universe.trajectory)
    if start_frame < 0:
        raise ValueError(f"start_frame must be >= 0, got {start_frame}")
    if stop_frame > trajectory_length:
        raise ValueError(
            f"stop_frame must be <= trajectory length ({trajectory_length}), got {stop_frame}"
        )
    if start_frame >= stop_frame:
        raise ValueError(
            "start_frame must be < stop_frame, got "
            f"start_frame={start_frame}, stop_frame={stop_frame}"
        )

    target_atoms, target_indices = resolve_selection_indices(
        universe,
        target_selection,
        role="target",
        run_label=run_label,
    )
    context_atoms, context_indices = resolve_selection_indices(
        universe,
        context_selection,
        role="context",
        run_label=run_label,
    )

    # Stride is applied to absolute frame indices, so ``frames``/``time_ns``
    # always refer to positions in the full trajectory.
    frames = np.arange(start_frame, stop_frame, dtype=np.int64)
    analyzed_frames = frames[::stride]
    n_frames = int(analyzed_frames.size)
    time_ns = (analyzed_frames.astype(np.float64) * timestep_ps) / 1000.0

    if context_indices.size == 0:
        LOGGER.warning(
            "Run '%s' context selection matched zero atoms (%r); returning NaN SASA metrics",
            run_label,
            context_selection,
        )
    if target_indices.size == 0:
        LOGGER.warning(
            "Run '%s' target selection matched zero atoms (%r); returning NaN SASA metrics",
            run_label,
            target_selection,
        )
    if target_indices.size == 0 or context_indices.size == 0 or n_frames == 0:
        # Degenerate selection: fill the traces with NaN, as the warnings
        # promise. np.empty here would expose uninitialized memory whenever
        # the target is non-empty but the context is empty.
        return SASAComputationResult(
            atom_sasa_a2=np.full((n_frames, target_indices.size), np.nan, dtype=np.float64),
            residue_sasa_a2=np.empty((n_frames, 0), dtype=np.float64),
            total_sasa_a2=np.full(n_frames, np.nan, dtype=np.float64),
            frames=analyzed_frames,
            time_ns=time_ns,
            target_atom_indices=target_indices,
            context_atom_indices=context_indices,
            residue_keys=[],
            residue_chainids=[],
            residue_resids=[],
            residue_resnames=[],
        )

    validate_target_subset(
        target_indices,
        context_indices,
        run_label=run_label,
        target_selection=target_selection,
        context_selection=context_selection,
    )

    # Map universe-global indices to positions within the context selection.
    # Assumes the PDB round-trip below preserves context atom order, so these
    # positions index the MDTraj per-atom output -- TODO confirm for exotic
    # topologies.
    context_index_to_local = {int(idx): i for i, idx in enumerate(context_indices.tolist())}
    target_local_indices = np.asarray(
        [context_index_to_local[int(idx)] for idx in target_indices.tolist()],
        dtype=np.int64,
    )

    # Group target atoms by (chainID, resid, resname) in first-seen order.
    residue_to_indices: dict[tuple[str, int, str], list[int]] = {}
    for atom_local, atom in zip(target_local_indices.tolist(), target_atoms):
        key = (str(atom.chainID), int(atom.resid), str(atom.resname))
        residue_to_indices.setdefault(key, []).append(int(atom_local))
    residue_items = list(residue_to_indices.items())
    residue_keys = [f"{chain}:{resid}:{resname}" for (chain, resid, resname), _ in residue_items]
    residue_chainids = [chain for (chain, _, _), _ in residue_items]
    residue_resids = [resid for (_, resid, _), _ in residue_items]
    residue_resnames = [resname for (_, _, resname), _ in residue_items]

    # Build the MDTraj topology for the context atoms via a PDB round-trip.
    # A private temporary directory is used instead of an open
    # NamedTemporaryFile, because Windows refuses to reopen a still-open
    # NamedTemporaryFile by name.
    with tempfile.TemporaryDirectory() as tmp_dir:
        template_path = Path(tmp_dir) / "context_template.pdb"
        context_atoms.write(str(template_path))
        template = md.load(str(template_path))

    atom_sasa_target_a2 = np.empty((n_frames, target_local_indices.size), dtype=np.float64)
    residue_sasa_a2 = np.empty((n_frames, len(residue_items)), dtype=np.float64)
    total_sasa_a2 = np.empty(n_frames, dtype=np.float64)

    n_chunks = (n_frames + chunk_size - 1) // chunk_size
    for chunk_idx, chunk_start in enumerate(range(0, n_frames, chunk_size)):
        chunk_end = min(chunk_start + chunk_size, n_frames)
        LOGGER.info(
            "Computing SASA chunk %d/%d (frames %d-%d)...",
            chunk_idx + 1,
            n_chunks,
            int(analyzed_frames[chunk_start]),
            int(analyzed_frames[chunk_end - 1]),
        )
        chunk_frames = analyzed_frames[chunk_start:chunk_end]

        # MDAnalysis positions are in Å; MDTraj expects nm.
        xyz_nm = np.empty((chunk_frames.size, len(context_atoms), 3), dtype=np.float32)
        for out_idx, frame_idx in enumerate(chunk_frames.tolist()):
            universe.trajectory[frame_idx]  # seek: updates context_atoms.positions
            xyz_nm[out_idx] = context_atoms.positions.astype(np.float32) / 10.0

        mdtraj_traj = md.Trajectory(xyz=xyz_nm, topology=template.topology)
        atom_sasa_nm2 = np.asarray(
            md.shrake_rupley(
                mdtraj_traj,
                mode="atom",
                probe_radius=probe_radius_nm,
                n_sphere_points=n_sphere_points,
            ),
            dtype=np.float64,
        )

        atom_sasa_target_chunk = atom_sasa_nm2[:, target_local_indices] * NM2_TO_A2
        atom_sasa_target_a2[chunk_start:chunk_end, :] = atom_sasa_target_chunk
        for idx, (_, atom_locals) in enumerate(residue_items):
            residue_sasa_a2[chunk_start:chunk_end, idx] = (
                np.sum(atom_sasa_nm2[:, atom_locals], axis=1) * NM2_TO_A2
            )
        total_sasa_a2[chunk_start:chunk_end] = np.sum(atom_sasa_target_chunk, axis=1)

    return SASAComputationResult(
        atom_sasa_a2=atom_sasa_target_a2,
        residue_sasa_a2=residue_sasa_a2,
        total_sasa_a2=total_sasa_a2,
        frames=analyzed_frames,
        time_ns=time_ns,
        target_atom_indices=target_indices,
        context_atom_indices=context_indices,
        residue_keys=residue_keys,
        residue_chainids=residue_chainids,
        residue_resids=residue_resids,
        residue_resnames=residue_resnames,
    )
[docs]
def save_sasa_artifacts(
    npz_path: Path,
    metadata_path: Path,
    result: SASAComputationResult,
    *,
    run_label: str,
    target_selection: str,
    context_selection: str,
    probe_radius_nm: float,
    n_sphere_points: int,
    equilibration: str,
) -> None:
    """Save raw SASA arrays plus JSON sidecar metadata.

    The metadata is built and JSON-serialized before anything is written. An
    error while building it therefore no longer leaves an NPZ without its
    sidecar.

    The archive is written through an open binary handle, so it lands exactly
    at ``npz_path``. When given a path, ``np.savez_compressed`` appends
    ``.npz`` if that suffix is missing, and ``np.load(npz_path)`` would then
    fail to find the file.

    Parameters
    ----------
    npz_path : Path
        Output path for compressed NumPy archive.
    metadata_path : Path
        Output path for JSON metadata sidecar.
    result : SASAComputationResult
        Raw SASA arrays to persist.
    run_label : str
        Run label.
    target_selection : str
        Target selection string.
    context_selection : str
        Context selection string.
    probe_radius_nm : float
        Probe radius in nm.
    n_sphere_points : int
        Number of sphere points.
    equilibration : str
        Equilibration string.
    """
    # Build and serialize metadata first, so failures happen before any I/O.
    metadata = build_sasa_artifact_metadata(
        result,
        run_label=run_label,
        target_selection=target_selection,
        context_selection=context_selection,
        probe_radius_nm=probe_radius_nm,
        n_sphere_points=n_sphere_points,
        equilibration=equilibration,
    )
    metadata_text = json.dumps(metadata, indent=2)

    npz_path.parent.mkdir(parents=True, exist_ok=True)
    metadata_path.parent.mkdir(parents=True, exist_ok=True)

    # Writing via a file handle prevents numpy from rewriting the filename.
    with npz_path.open("wb") as npz_file:
        np.savez_compressed(
            npz_file,
            atom_sasa_a2=result.atom_sasa_a2,
            residue_sasa_a2=result.residue_sasa_a2,
            total_sasa_a2=result.total_sasa_a2,
            frames=result.frames,
            time_ns=result.time_ns,
            target_atom_indices=result.target_atom_indices,
            context_atom_indices=result.context_atom_indices,
            residue_keys=np.asarray(result.residue_keys, dtype=str),
            residue_chainids=np.asarray(result.residue_chainids, dtype=str),
            residue_resids=np.asarray(result.residue_resids, dtype=np.int64),
            residue_resnames=np.asarray(result.residue_resnames, dtype=str),
        )
    metadata_path.write_text(metadata_text, encoding="utf-8")
[docs]
def load_sasa_artifacts(
    npz_path: Path,
    metadata_path: Path,
) -> tuple[SASAComputationResult, dict[str, Any]]:
    """Read a SASA NPZ archive and its JSON sidecar back from disk.

    Older archives that lack ``target_atom_indices`` or
    ``context_atom_indices`` yield empty int64 arrays for those fields. All
    other arrays are required.

    Parameters
    ----------
    npz_path : Path
        Path to compressed NumPy archive.
    metadata_path : Path
        Path to JSON metadata sidecar.

    Returns
    -------
    tuple[SASAComputationResult, dict[str, Any]]
        Reconstructed raw result and parsed metadata dictionary.
    """

    def _optional_indices(archive: Any, name: str) -> NDArray[np.int64]:
        # Index arrays were added later; absent keys become empty arrays.
        if name not in archive:
            return np.empty((0,), dtype=np.int64)
        return np.asarray(archive[name], dtype=np.int64)

    with np.load(npz_path) as archive:
        arrays: dict[str, Any] = {
            name: np.asarray(archive[name], dtype=np.float64)
            for name in ("atom_sasa_a2", "residue_sasa_a2", "total_sasa_a2")
        }
        arrays["frames"] = np.asarray(archive["frames"], dtype=np.int64)
        arrays["time_ns"] = np.asarray(archive["time_ns"], dtype=np.float64)
        for name in ("target_atom_indices", "context_atom_indices"):
            arrays[name] = _optional_indices(archive, name)

        result = SASAComputationResult(
            **arrays,
            residue_keys=list(map(str, archive["residue_keys"].tolist())),
            residue_chainids=list(map(str, archive["residue_chainids"].tolist())),
            residue_resids=list(map(int, archive["residue_resids"].tolist())),
            residue_resnames=list(map(str, archive["residue_resnames"].tolist())),
        )

    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    return result, metadata
# Schema identifier stored in artifact contracts. Metadata whose non-null
# "artifact_schema" differs from this is treated as incompatible.
SASA_ARTIFACT_SCHEMA_NAME = "polyzymd.sasa_artifact"
# Current artifact schema version; metadata declaring a newer version is rejected.
SASA_ARTIFACT_SCHEMA_VERSION = 1
# Version of the compatibility-hash payload. It is included in the hashed
# payload, and metadata declaring a newer value is rejected.
SASA_ARTIFACT_COMPATIBILITY_VERSION = 1
# Conversion factor from Å² to nm² (1 Å² = 0.01 nm²).
A2_TO_NM2 = 0.01
# Absolute tolerance (nm) when comparing stored and requested probe radii.
SASA_COMPAT_PROBE_RADIUS_ABS_TOL = 1e-6
[docs]
@dataclass(frozen=True)
class SASAArtifactContract:
    """Canonical contract metadata for a persisted SASA artifact.

    Instances are produced by ``build_sasa_artifact_contract``. It fills the
    schema/engine fields from module constants and derives
    ``compatibility_hash`` via ``compute_sasa_artifact_compatibility_hash``.

    Parameters
    ----------
    schema_name : str
        Artifact schema identifier.
    schema_version : int
        Artifact schema version.
    compatibility_version : int
        Compatibility hash payload version.
    producer : str
        Producer module name.
    engine : str
        SASA engine identifier.
    mode : str
        SASA mode. Canonical value is ``"atom"``.
    units : str
        SASA units. Canonical value is ``"A^2"``.
    run_label : str
        Human-readable run label.
    target_selection : str
        Target atom selection used for SASA reporting.
    context_selection : str
        Context atom selection used during SASA computation.
    probe_radius_nm : float
        Probe radius in nm.
    n_sphere_points : int
        Number of Shrake-Rupley sphere points.
    equilibration : str
        Equilibration descriptor.
    compatibility_hash : str
        Deterministic compatibility hash (first 16 hex chars of a SHA-256).
    """

    # Schema identity and versions.
    schema_name: str
    schema_version: int
    compatibility_version: int
    # Provenance of the computation.
    producer: str
    engine: str
    mode: str
    units: str
    # Run-specific inputs.
    run_label: str
    target_selection: str
    context_selection: str
    probe_radius_nm: float
    n_sphere_points: int
    equilibration: str
    # Digest over the compatibility-relevant inputs above.
    compatibility_hash: str
[docs]
@dataclass(frozen=True)
class SASAArtifactCompatibilityQuery:
    """Compatibility query metadata for sibling artifact lookup.

    ``probe_radius_nm``, ``n_sphere_points`` and ``equilibration`` are hard
    requirements. The selections only feed the advisory hash comparison. When
    just one of the two selections is given, it is used for both.

    Parameters
    ----------
    probe_radius_nm : float
        Probe radius expected by the consumer, compared with absolute
        tolerance ``SASA_COMPAT_PROBE_RADIUS_ABS_TOL``.
    n_sphere_points : int
        Sphere point count expected by the consumer.
    equilibration : str
        Equilibration label expected by the consumer (compared after strip).
    selection : str | None, optional
        Target selection string for advisory hash comparison.
    context_selection : str | None, optional
        Context selection string for advisory hash comparison.
    """

    # Required parameters that must match the stored metadata.
    probe_radius_nm: float
    n_sphere_points: int
    equilibration: str
    # Optional selections; only used for the advisory selection-hash check.
    selection: str | None = None
    context_selection: str | None = None
[docs]
@dataclass(frozen=True)
class SASAArtifactCompatibility:
    """Compatibility decision for a single SASA artifact metadata payload.

    Parameters
    ----------
    is_compatible : bool
        True when metadata-level required fields are compatible.
    is_legacy : bool
        True when the ``artifact_schema_version`` field is absent.
    schema_version : int | None
        Parsed schema version if present and parseable.
    selection_hash_matches : bool | None
        Advisory selection hash result when query selections are supplied;
        None when the query supplied no selection at all.
    mismatched_fields : tuple[str, ...]
        Names of fields that made the artifact incompatible, in check order.
    """

    # Definitive decision and how it was reached.
    is_compatible: bool
    is_legacy: bool
    schema_version: int | None
    # Advisory signal; does not affect is_compatible.
    selection_hash_matches: bool | None
    mismatched_fields: tuple[str, ...]
[docs]
@dataclass(frozen=True)
class SASASiblingArtifactMatch:
    """A sibling SASA artifact candidate and its compatibility outcome.

    Produced by ``find_sibling_sasa_artifacts`` for each compatible artifact.

    Parameters
    ----------
    sibling_analysis_dir : Path
        Directory containing sibling SASA artifacts for the replicate.
    npz_path : Path
        Path to the SASA NPZ payload.
    metadata_path : Path
        Path to the JSON metadata sidecar.
    metadata : dict[str, Any]
        Parsed metadata dictionary.
    compatibility : SASAArtifactCompatibility
        Compatibility decision for this artifact.
    """

    # Location of the artifact pair.
    sibling_analysis_dir: Path
    npz_path: Path
    metadata_path: Path
    # Parsed sidecar and the decision derived from it.
    metadata: dict[str, Any]
    compatibility: SASAArtifactCompatibility
[docs]
def compute_sasa_artifact_compatibility_hash(
    *,
    probe_radius_nm: float,
    n_sphere_points: int,
    selection: str,
    equilibration: str,
    context_selection: str | None = None,
) -> str:
    """Return a short, deterministic digest of the SASA reuse-relevant inputs.

    The inputs are normalized first: strings are stripped, the probe radius is
    rounded to 6 decimals, and the point count is cast to int. They are then
    serialized as compact, key-sorted JSON and hashed with SHA-256.

    Parameters
    ----------
    probe_radius_nm : float
        Probe radius in nm.
    n_sphere_points : int
        Number of Shrake-Rupley sphere points.
    selection : str
        Target selection string.
    equilibration : str
        Equilibration label.
    context_selection : str | None, optional
        Context selection string. Falls back to ``selection`` when it is
        ``None`` or empty.

    Returns
    -------
    str
        First 16 characters of the SHA-256 compatibility digest.
    """
    import hashlib

    effective_context = context_selection if context_selection else selection
    canonical_json = json.dumps(
        {
            "compatibility_version": SASA_ARTIFACT_COMPATIBILITY_VERSION,
            "context_selection": effective_context.strip(),
            "equilibration": equilibration.strip(),
            "mode": "atom",
            "n_sphere_points": int(n_sphere_points),
            "probe_radius_nm": round(float(probe_radius_nm), 6),
            "selection": selection.strip(),
            "units": "A^2",
        },
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256(canonical_json.encode("utf-8")).hexdigest()
    return digest[:16]
[docs]
def build_sasa_artifact_contract(
    *,
    run_label: str,
    target_selection: str,
    context_selection: str,
    probe_radius_nm: float,
    n_sphere_points: int,
    equilibration: str,
) -> SASAArtifactContract:
    """Assemble the canonical contract describing a SASA artifact.

    Schema, producer and engine fields come from module constants. The
    compatibility hash is derived from the selections and computation
    parameters.

    Parameters
    ----------
    run_label : str
        Human-readable run label.
    target_selection : str
        Target atom selection used for SASA reporting.
    context_selection : str
        Context atom selection used during SASA computation.
    probe_radius_nm : float
        Probe radius in nm.
    n_sphere_points : int
        Number of Shrake-Rupley sphere points.
    equilibration : str
        Equilibration descriptor.

    Returns
    -------
    SASAArtifactContract
        Fully populated canonical metadata contract.
    """
    # Parameters shared verbatim by the hash and the contract itself.
    computation_params = {
        "probe_radius_nm": probe_radius_nm,
        "n_sphere_points": n_sphere_points,
        "equilibration": equilibration,
    }
    digest = compute_sasa_artifact_compatibility_hash(
        selection=target_selection,
        context_selection=context_selection,
        **computation_params,
    )
    return SASAArtifactContract(
        schema_name=SASA_ARTIFACT_SCHEMA_NAME,
        schema_version=SASA_ARTIFACT_SCHEMA_VERSION,
        compatibility_version=SASA_ARTIFACT_COMPATIBILITY_VERSION,
        producer="polyzymd.analyses.shared.sasa",
        engine="mdtraj.shrake_rupley",
        mode="atom",
        units="A^2",
        run_label=run_label,
        target_selection=target_selection,
        context_selection=context_selection,
        compatibility_hash=digest,
        **computation_params,
    )
[docs]
def check_sasa_artifact_compatibility(
    metadata: dict[str, Any],
    query: SASAArtifactCompatibilityQuery,
) -> SASAArtifactCompatibility:
    """Decide whether a SASA metadata sidecar can be reused for ``query``.

    The required fields decide ``is_compatible``: schema name and versions,
    probe radius, sphere points, equilibration, units, and SASA mode. The
    selection-hash comparison is advisory only and never makes an artifact
    incompatible.

    Parameters
    ----------
    metadata : dict[str, Any]
        Parsed SASA metadata sidecar.
    query : SASAArtifactCompatibilityQuery
        Consumer compatibility query.

    Returns
    -------
    SASAArtifactCompatibility
        Compatibility decision containing definitive and advisory signals;
        ``mismatched_fields`` lists offending keys in the order checked.
    """
    problems: list[str] = []

    def _read_version(key: str, newest_supported: int) -> int | None:
        # Absent -> None with no complaint. Unparsable -> None and flagged.
        # Newer than supported -> flagged, but the parsed value is returned.
        raw = metadata.get(key)
        if raw is None:
            return None
        try:
            version = int(raw)
        except (TypeError, ValueError):
            problems.append(key)
            return None
        if version > newest_supported:
            problems.append(key)
        return version

    declared_schema = metadata.get("artifact_schema")
    if declared_schema is not None and str(declared_schema) != SASA_ARTIFACT_SCHEMA_NAME:
        problems.append("artifact_schema")

    # Sidecars written before schema versioning carry no version key.
    is_legacy = metadata.get("artifact_schema_version") is None
    schema_version = _read_version("artifact_schema_version", SASA_ARTIFACT_SCHEMA_VERSION)
    _read_version("artifact_compatibility_version", SASA_ARTIFACT_COMPATIBILITY_VERSION)

    try:
        stored_probe = float(metadata.get("probe_radius_nm"))
    except (TypeError, ValueError):
        stored_probe = float("nan")
    if not np.isfinite(stored_probe) or (
        abs(stored_probe - float(query.probe_radius_nm)) > SASA_COMPAT_PROBE_RADIUS_ABS_TOL
    ):
        problems.append("probe_radius_nm")

    try:
        stored_points = int(metadata.get("n_sphere_points"))
    except (TypeError, ValueError):
        stored_points = -1
    if stored_points != int(query.n_sphere_points):
        problems.append("n_sphere_points")

    if str(metadata.get("equilibration", "")).strip() != query.equilibration.strip():
        problems.append("equilibration")

    # Missing units/mode default to the canonical values.
    for key, canonical in (("units", "A^2"), ("sasa_mode", "atom")):
        if str(metadata.get(key, canonical)) != canonical:
            problems.append(key)

    hash_matches: bool | None = None
    if query.selection is not None or query.context_selection is not None:
        if query.selection is not None:
            target_sel = query.selection
        else:
            target_sel = query.context_selection or ""
        if query.context_selection is not None:
            context_sel = query.context_selection
        else:
            context_sel = target_sel
        expected_hash = compute_sasa_artifact_compatibility_hash(
            probe_radius_nm=query.probe_radius_nm,
            n_sphere_points=query.n_sphere_points,
            selection=target_sel,
            context_selection=context_sel,
            equilibration=query.equilibration,
        )
        hash_matches = str(metadata.get("compatibility_hash")) == expected_hash

    return SASAArtifactCompatibility(
        is_compatible=not problems,
        is_legacy=is_legacy,
        schema_version=schema_version,
        selection_hash_matches=hash_matches,
        mismatched_fields=tuple(problems),
    )
[docs]
def adapt_canonical_sasa_to_exposure(
    result: SASAComputationResult,
    *,
    exposure_threshold: float,
) -> Any:
    """Convert a shared SASA result into the exposure-analysis trajectory format.

    Residue SASA is converted from Å² to float32 nm² and normalised by each
    residue's maximum ASA from ``MAX_ASA_TABLE``, whose values are scaled by
    ``A2_TO_NM2`` here. Residue names missing from the table fall back to
    200.0, and non-positive maxima are replaced by 1.0 before dividing.

    Parameters
    ----------
    result : SASAComputationResult
        Canonical shared SASA computation result in Å².
    exposure_threshold : float
        Relative SASA threshold used by exposure consumers.

    Returns
    -------
    Any
        ``SASATrajectoryResult`` instance for exposure analysis.
    """
    from polyzymd.analyses.exposure._sasa_trajectory import SASATrajectoryResult
    from polyzymd.analyses.shared.aa_classification import MAX_ASA_TABLE, get_aa_class

    absolute_nm2 = result.residue_sasa_a2.astype(np.float32) * A2_TO_NM2
    residue_ids = np.asarray(result.residue_resids, dtype=np.int32)
    names = [str(raw).upper() for raw in result.residue_resnames]
    classes = [get_aa_class(name) for name in names]
    reference_nm2 = np.asarray(
        [MAX_ASA_TABLE.get(name, 200.0) * A2_TO_NM2 for name in names],
        dtype=np.float32,
    )

    # Guard against zero/negative table entries before dividing.
    denominator = np.where(reference_nm2 > 0.0, reference_nm2, 1.0).astype(np.float32)
    relative_sasa = absolute_nm2 / denominator[None, :]

    frame_count = int(absolute_nm2.shape[0])
    residue_count = int(absolute_nm2.shape[1]) if absolute_nm2.ndim == 2 else 0

    return SASATrajectoryResult(
        sasa_per_frame=absolute_nm2,
        relative_sasa_per_frame=relative_sasa.astype(np.float32),
        resids=residue_ids,
        resnames=names,
        aa_classes=classes,
        max_sasa_nm2=reference_nm2,
        n_frames=frame_count,
        n_residues=residue_count,
        exposure_threshold=exposure_threshold,
    )
[docs]
def find_sibling_sasa_artifacts(
    replicate_analysis_dir: Path,
    query: SASAArtifactCompatibilityQuery,
) -> list[SASASiblingArtifactMatch]:
    """Find compatible sibling SASA artifacts for a replicate analysis directory.

    The sibling directory is ``<replicate_analysis_dir>/../../sasa/<name>``.
    For example, ``analysis/<cond>/exposure/run_1`` maps to
    ``analysis/<cond>/sasa/run_1``.

    Each ``sasa_*.npz`` file there is considered only if a ``.json`` sidecar
    with the same stem exists, is valid UTF-8 JSON, and decodes to an object.
    Unreadable or malformed sidecars are skipped rather than raising, so one
    bad artifact cannot break the lookup.

    Parameters
    ----------
    replicate_analysis_dir : Path
        Replicate analysis directory, for example ``analysis/<cond>/exposure/run_1``.
    query : SASAArtifactCompatibilityQuery
        Compatibility query used to pre-filter candidates.

    Returns
    -------
    list[SASASiblingArtifactMatch]
        Compatible sibling artifacts sorted by preference. Non-legacy artifacts
        come first, then by selection-hash result (match, unknown, mismatch),
        then by NPZ filename.
    """
    sibling_dir = replicate_analysis_dir.parent.parent / "sasa" / replicate_analysis_dir.name
    if not sibling_dir.is_dir():
        return []

    matches: list[SASASiblingArtifactMatch] = []
    for npz_path in sibling_dir.glob("sasa_*.npz"):
        metadata_path = npz_path.with_suffix(".json")
        if not npz_path.is_file() or not metadata_path.is_file():
            continue
        try:
            metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised by read_text on non-UTF-8 bytes.
            continue
        if not isinstance(metadata, dict):
            # Valid JSON that is not an object cannot be a metadata sidecar.
            continue
        compatibility = check_sasa_artifact_compatibility(metadata, query)
        if not compatibility.is_compatible:
            continue
        matches.append(
            SASASiblingArtifactMatch(
                sibling_analysis_dir=sibling_dir,
                npz_path=npz_path,
                metadata_path=metadata_path,
                metadata=metadata,
                compatibility=compatibility,
            )
        )

    def _selection_hash_rank(value: bool | None) -> int:
        # Confirmed match first, unknown next, known mismatch last.
        if value is True:
            return 0
        if value is None:
            return 1
        return 2

    matches.sort(
        key=lambda item: (
            int(item.compatibility.is_legacy),
            _selection_hash_rank(item.compatibility.selection_hash_matches),
            item.npz_path.name,
        )
    )
    return matches