Source code for polyzymd.analyses.mda.comparison

"""Comparison engine for MDAnalysis condition artifacts."""

from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any

from polyzymd.analyses._framework.comparison_models import ComparisonResult, MetricValue
from polyzymd.analyses.mda.aggregation import AggregatedMetric
from polyzymd.analyses.mda.artifacts import ComparisonArtifact, ConditionArtifact
from polyzymd.analyses.mda.base import MDAnalysisExtensionError


class MDAComparisonError(MDAnalysisExtensionError):
    """Signal that a set of condition artifacts cannot be compared.

    Raised by the comparison engine in this module when artifact identity,
    replicate bookkeeping, or metric payloads fail validation.
    """
def _repeated_items(items: Sequence[Any]) -> list[Any]:
    """Return the sorted items that occur more than once in ``items`` (single pass)."""
    seen: set[Any] = set()
    repeated: set[Any] = set()
    for item in items:
        if item in seen:
            repeated.add(item)
        else:
            seen.add(item)
    return sorted(repeated)


@dataclass(frozen=True)
class MDAComparisonContext:
    """Identity and statistical controls for condition-artifact comparison.

    Notes
    -----
    ``expected_condition_labels`` is stored as a tuple and
    ``expected_replicates_by_condition`` as a ``dict`` of integer tuples after
    construction; see ``__post_init__``.
    """

    analysis_name: str
    project_name: str
    expected_condition_labels: Sequence[str] | None = None
    expected_replicates_by_condition: Mapping[str, Sequence[int]] | None = None
    control_label: str | None = None
    effective_control: str | None = None
    equilibration: str = "0ns"
    settings_fingerprint: str | None = None
    min_replicates: int = 1
    fdr_alpha: float = 0.05
    ttest_method: str = "student"
    posthoc_method: str = "ttest_bh"

    def __post_init__(self) -> None:
        """Normalize expected identity inputs and reject ambiguous values.

        Raises
        ------
        MDAComparisonError
            If ``min_replicates`` is below 1, expected condition labels repeat,
            a replicate-mapping key is not a string, replicate IDs cannot be
            converted to integers, or replicate IDs repeat within a condition.
        """
        if self.min_replicates < 1:
            raise MDAComparisonError("min_replicates must be at least 1")

        if self.expected_condition_labels is not None:
            labels = tuple(self.expected_condition_labels)
            duplicates = _repeated_items(labels)
            if duplicates:
                raise MDAComparisonError(
                    f"Expected condition labels must be unique; duplicates: {', '.join(duplicates)}"
                )
            # The dataclass is frozen, so normalized values are written via object.__setattr__.
            object.__setattr__(self, "expected_condition_labels", labels)

        if self.expected_replicates_by_condition is not None:
            normalized: dict[str, tuple[int, ...]] = {}
            for label, replicates in self.expected_replicates_by_condition.items():
                if not isinstance(label, str):
                    raise MDAComparisonError(
                        f"Expected replicate mapping keys must be condition labels, got {label!r}"
                    )
                try:
                    replicate_tuple = tuple(int(replicate) for replicate in replicates)
                except (TypeError, ValueError) as exc:
                    # Report malformed replicate IDs with the module's own error type
                    # instead of leaking the raw conversion failure.
                    raise MDAComparisonError(
                        f"Expected replicates for condition {label!r} must be integer IDs, "
                        f"got {replicates!r}"
                    ) from exc
                duplicate_replicates = _repeated_items(replicate_tuple)
                if duplicate_replicates:
                    raise MDAComparisonError(
                        f"Expected replicates for condition {label!r} must be unique; "
                        f"duplicates: {duplicate_replicates}"
                    )
                normalized[label] = replicate_tuple
            object.__setattr__(self, "expected_replicates_by_condition", normalized)
@dataclass(frozen=True)
class _DerivedMetricStatistics:
    """Summary statistics recomputed from replicate-level metric values.

    Instances are immutable; the fields hold the replicate values together
    with the mean, standard deviation, and standard error derived from them.
    """

    # Replicate-level values the statistics below were computed from.
    values: list[float]
    # Mean of ``values``.
    mean: float
    # Standard deviation of ``values``.
    std: float
    # Standard error of the mean of ``values``.
    sem: float
def compare_condition_artifacts(
    artifacts: Sequence[ConditionArtifact],
    ctx: MDAComparisonContext,
) -> ComparisonArtifact:
    """Run the scalar comparison pipeline over aggregate condition artifacts.

    The artifacts are validated against ``ctx``, converted into per-condition
    metric values, passed to the shared scalar comparison routine, and the
    result is wrapped in a comparison artifact.

    Parameters
    ----------
    artifacts : sequence of ConditionArtifact
        Condition artifacts produced by MDAnalysis extension-layer aggregation.
    ctx : MDAComparisonContext
        Comparison identity, expected condition labels, and statistical controls.

    Returns
    -------
    ComparisonArtifact
        Stable comparison artifact containing scalar statistics and provenance.
    """
    validated = _validate_condition_artifacts(artifacts, ctx)
    per_condition_metrics, per_metric_metadata = _build_metric_values(validated, ctx)

    # Imported at call time, and only after validation has succeeded.
    from polyzymd.analyses.stats import default_scalar_comparison

    scalar_result = default_scalar_comparison(
        analysis_name=ctx.analysis_name,
        project_name=ctx.project_name,
        metrics_by_condition=per_condition_metrics,
        control_label=ctx.effective_control,
        equilibration=ctx.equilibration,
        fdr_alpha=ctx.fdr_alpha,
        ttest_method=ctx.ttest_method,
        posthoc_method=ctx.posthoc_method,
    )
    return _comparison_result_to_artifact(scalar_result, validated, ctx, per_metric_metadata)
def _validate_condition_artifacts(
    artifacts: Sequence[ConditionArtifact],
    ctx: MDAComparisonContext,
) -> list[ConditionArtifact]:
    """Validate condition-artifact identity before statistical comparison.

    Parameters
    ----------
    artifacts : sequence of ConditionArtifact
        Candidate artifacts.
    ctx : MDAComparisonContext
        Expected analysis and condition identity.

    Returns
    -------
    list of ConditionArtifact
        Artifacts in input order when validation passes.

    Raises
    ------
    MDAComparisonError
        If no artifacts are given, condition labels repeat, labels differ from
        ``ctx.expected_condition_labels``, or any artifact fails identity checks.
    """
    if not artifacts:
        raise MDAComparisonError(f"{ctx.analysis_name}: no condition artifacts to compare")

    normalized = list(artifacts)
    labels = [artifact.condition_label for artifact in normalized]

    # Single pass over the labels instead of calling list.count() per label.
    seen_labels: set[str] = set()
    repeated_labels: set[str] = set()
    for label in labels:
        if label in seen_labels:
            repeated_labels.add(label)
        else:
            seen_labels.add(label)
    duplicates = sorted(repeated_labels)
    if duplicates:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: duplicate condition artifact labels: {', '.join(duplicates)}"
        )

    if ctx.expected_condition_labels is not None:
        expected = list(ctx.expected_condition_labels)
        missing = sorted(set(expected) - set(labels))
        unexpected = sorted(set(labels) - set(expected))
        if missing or unexpected:
            raise MDAComparisonError(
                f"{ctx.analysis_name}: condition artifact labels do not match expected labels; "
                f"missing={missing}, unexpected={unexpected}. Recompute the analysis or clear "
                "stale aggregate result.json files."
            )

    for artifact in normalized:
        _validate_artifact_identity(artifact, ctx)
    return normalized


def _validate_artifact_identity(artifact: ConditionArtifact, ctx: MDAComparisonContext) -> None:
    """Validate one condition artifact against the comparison context.

    Parameters
    ----------
    artifact : ConditionArtifact
        Candidate condition artifact.
    ctx : MDAComparisonContext
        Expected analysis identity.

    Raises
    ------
    MDAComparisonError
        If the analysis name or settings fingerprint differs from ``ctx``, the
        artifact lists duplicate replicates, or replicate validation fails.
    """
    if artifact.analysis_name != ctx.analysis_name:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} analysis mismatch; "
            f"artifact has {artifact.analysis_name!r}"
        )

    # The fingerprint is only enforced when the context supplies one.
    if ctx.settings_fingerprint is not None:
        stored = artifact.metadata.get("settings_fingerprint")
        if stored != ctx.settings_fingerprint:
            raise MDAComparisonError(
                f"{ctx.analysis_name}: condition {artifact.condition_label!r} settings "
                f"fingerprint mismatch; expected {ctx.settings_fingerprint!r}, got {stored!r}. "
                "Recompute the analysis or clear stale aggregate result.json files."
            )

    if len(set(artifact.replicates)) != len(artifact.replicates):
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} has duplicate replicates"
        )
    _validate_artifact_replicates(artifact, ctx)


def _validate_artifact_replicates(
    artifact: ConditionArtifact,
    ctx: MDAComparisonContext,
) -> None:
    """Validate condition artifact replicate identity against active inputs.

    Only replicate IDs that are *not* in the expected set are rejected; an
    artifact holding a subset of the expected replicates passes this check.

    Parameters
    ----------
    artifact : ConditionArtifact
        Candidate condition artifact.
    ctx : MDAComparisonContext
        Expected replicate identity and minimum replicate policy.

    Raises
    ------
    MDAComparisonError
        If the artifact has no replicates, fewer than ``ctx.min_replicates``,
        no entry in ``ctx.expected_replicates_by_condition``, or replicate IDs
        outside the expected set.
    """
    replicate_ids = tuple(int(replicate) for replicate in artifact.replicates)
    if not replicate_ids:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} has no replicates"
        )
    if len(replicate_ids) < ctx.min_replicates:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} has "
            f"{len(replicate_ids)} replicate(s), below required minimum {ctx.min_replicates}. "
            "Recompute the analysis or clear stale aggregate result.json files."
        )

    if ctx.expected_replicates_by_condition is None:
        return

    expected = ctx.expected_replicates_by_condition.get(artifact.condition_label)
    if expected is None:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} has no active "
            "replicate request in the comparison context"
        )

    expected_set = set(expected)
    unexpected = sorted(set(replicate_ids) - expected_set)
    if unexpected:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} contains unexpected "
            f"replicate IDs {unexpected}; active replicates are {list(expected)}. Recompute the "
            "analysis or clear stale aggregate result.json files."
        )


def _build_metric_values(
    artifacts: Sequence[ConditionArtifact],
    ctx: MDAComparisonContext,
) -> tuple[dict[str, dict[str, MetricValue]], dict[str, dict[str, Any]]]:
    """Build ``MetricValue`` inputs from condition artifacts.

    Parameters
    ----------
    artifacts : sequence of ConditionArtifact
        Validated condition artifacts.
    ctx : MDAComparisonContext
        Comparison context used for diagnostics.

    Returns
    -------
    tuple of dict
        Metrics keyed by condition label and metric metadata keyed by metric name.

    Raises
    ------
    MDAComparisonError
        If an artifact lacks a ``payload['metrics']`` mapping, has a non-string
        metric key, declares no metrics, fails metric validation, or the metric
        keys differ between conditions.
    """
    metrics_by_condition: dict[str, dict[str, MetricValue]] = {}
    metric_metadata: dict[str, dict[str, Any]] = {}
    # The first artifact defines the metric-key baseline the others must match.
    expected_metric_keys: set[str] | None = None
    baseline_label: str | None = None
    key_mismatches: list[str] = []

    for artifact in artifacts:
        metric_payload = artifact.payload.get("metrics")
        if not isinstance(metric_payload, Mapping):
            raise MDAComparisonError(
                f"{ctx.analysis_name}: condition {artifact.condition_label!r} is missing "
                "payload['metrics'] mapping"
            )

        condition_metrics: dict[str, MetricValue] = {}
        metric_keys: set[str] = set()
        for metric_key, raw_metric in metric_payload.items():
            if not isinstance(metric_key, str):
                raise MDAComparisonError(
                    f"{ctx.analysis_name}: condition {artifact.condition_label!r} has non-string "
                    f"metric key {metric_key!r}"
                )
            metric_keys.add(metric_key)
            summary = _validated_aggregated_metric(raw_metric, artifact, ctx, metric_key)
            metadata = _metric_metadata_for(artifact, metric_key, raw_metric)
            _merge_metric_metadata(metric_metadata, metric_key, metadata, artifact, ctx)
            condition_metrics[metric_key] = MetricValue(
                name=metric_key,
                mean=summary.mean,
                sem=summary.sem,
                replicate_values=summary.values,
                higher_is_better=metadata.get("higher_is_better", True),
                direction_labels=metadata.get(
                    "direction_labels", ("decreased", "unchanged", "increased")
                ),
            )

        if not condition_metrics:
            raise MDAComparisonError(
                f"{ctx.analysis_name}: condition {artifact.condition_label!r} declares no metrics"
            )

        if expected_metric_keys is None:
            expected_metric_keys = metric_keys
            baseline_label = artifact.condition_label
        elif metric_keys != expected_metric_keys:
            # Collect every mismatch so one error reports all differing conditions.
            missing = sorted(expected_metric_keys - metric_keys)
            extra = sorted(metric_keys - expected_metric_keys)
            key_mismatches.append(f"{artifact.condition_label}: missing={missing}, extra={extra}")
        metrics_by_condition[artifact.condition_label] = condition_metrics

    if key_mismatches:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: inconsistent metric keys across condition artifacts; "
            f"baseline {baseline_label!r} keys={sorted(expected_metric_keys or [])}; "
            f"differences: {'; '.join(key_mismatches)}"
        )
    return metrics_by_condition, metric_metadata


def _validated_aggregated_metric(
    raw_metric: Any,
    artifact: ConditionArtifact,
    ctx: MDAComparisonContext,
    metric_key: str,
) -> _DerivedMetricStatistics:
    """Validate one raw metric payload with the aggregation schema.

    Parameters
    ----------
    raw_metric : Any
        Candidate ``AggregatedMetric`` payload.
    artifact : ConditionArtifact
        Source condition artifact.
    ctx : MDAComparisonContext
        Comparison context used for diagnostics.
    metric_key : str
        Metric key within ``payload['metrics']``.

    Returns
    -------
    _DerivedMetricStatistics
        Metric statistics derived from validated replicate values.

    Raises
    ------
    MDAComparisonError
        If the payload is not a mapping, its values are not a sequence of
        finite scalars, it fails ``AggregatedMetric`` validation, its name or
        replicate counts are inconsistent, or its stored mean/std/sem disagree
        with the values recomputed from the replicate data.
    """
    if not isinstance(raw_metric, Mapping):
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
            f"must be a mapping, got {type(raw_metric).__name__}"
        )

    raw_values = raw_metric.get("values")
    # str/bytes are Sequences too, but never a valid list of replicate values.
    if not isinstance(raw_values, Sequence) or isinstance(raw_values, (str, bytes, bytearray)):
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
            "must provide a sequence of replicate values"
        )
    # Raw values are checked before schema validation runs.
    for index, value in enumerate(raw_values):
        _validate_finite_scalar(
            value,
            ctx=ctx,
            condition_label=artifact.condition_label,
            metric_key=metric_key,
            value_name=f"values[{index}]",
        )

    try:
        summary = AggregatedMetric.model_validate(raw_metric)
    except ValueError as exc:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
            f"does not match AggregatedMetric schema: {exc}"
        ) from exc

    for value_name in ("mean", "sem", "std"):
        _validate_finite_scalar(
            getattr(summary, value_name),
            ctx=ctx,
            condition_label=artifact.condition_label,
            metric_key=metric_key,
            value_name=value_name,
        )

    if summary.name != metric_key:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric key {metric_key!r} "
            f"does not match AggregatedMetric.name {summary.name!r}"
        )
    if summary.n != len(summary.values) or summary.n != len(artifact.replicates):
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
            "has inconsistent replicate counts; expected n == len(values) == len(replicates), "
            f"got n={summary.n}, len(values)={len(summary.values)}, "
            f"len(replicates)={len(artifact.replicates)}"
        )

    # The recomputed statistics are returned; the stored ones are only cross-checked.
    derived = _derive_metric_statistics(summary.values)
    for statistic_name in ("mean", "std", "sem"):
        _validate_stored_statistic_matches(
            getattr(summary, statistic_name),
            getattr(derived, statistic_name),
            statistic_name=statistic_name,
            artifact=artifact,
            ctx=ctx,
            metric_key=metric_key,
        )
    return derived


def _derive_metric_statistics(values: Sequence[float]) -> _DerivedMetricStatistics:
    """Derive comparison statistics from replicate-level values.

    Parameters
    ----------
    values : sequence of float
        Replicate-level metric values from an ``AggregatedMetric`` payload.

    Returns
    -------
    _DerivedMetricStatistics
        Mean, sample standard deviation, and SEM derived from ``values``.

    Raises
    ------
    MDAComparisonError
        If ``values`` is empty.
    """
    derived_values = [float(value) for value in values]
    count = len(derived_values)
    if count == 0:
        # Fail with a descriptive error rather than a bare ZeroDivisionError.
        raise MDAComparisonError("cannot derive statistics from an empty set of replicate values")

    mean = sum(derived_values) / count
    if count == 1:
        # A single replicate has no spread; std and SEM are reported as zero.
        std = 0.0
    else:
        # Sample standard deviation (n - 1 denominator).
        std = math.sqrt(sum((value - mean) ** 2 for value in derived_values) / (count - 1))
    sem = std / math.sqrt(count)
    return _DerivedMetricStatistics(values=derived_values, mean=mean, std=std, sem=sem)


def _validate_stored_statistic_matches(
    stored: float,
    derived: float,
    *,
    statistic_name: str,
    artifact: ConditionArtifact,
    ctx: MDAComparisonContext,
    metric_key: str,
) -> None:
    """Reject stale aggregate statistics that disagree with replicate values.

    Parameters
    ----------
    stored : float
        Statistic stored in the condition artifact.
    derived : float
        Statistic recalculated from ``AggregatedMetric.values``.
    statistic_name : str
        Name used in diagnostics.
    artifact : ConditionArtifact
        Source condition artifact.
    ctx : MDAComparisonContext
        Comparison context used for diagnostics.
    metric_key : str
        Metric key.

    Raises
    ------
    MDAComparisonError
        If ``stored`` and ``derived`` differ beyond ``rel_tol=1e-9`` and
        ``abs_tol=1e-12``.
    """
    if math.isclose(stored, derived, rel_tol=1e-9, abs_tol=1e-12):
        return
    raise MDAComparisonError(
        f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
        f"stored {statistic_name}={stored!r} does not match value-derived "
        f"{statistic_name}={derived!r}; recompute the analysis or clear stale aggregate "
        "result.json files"
    )


def _validate_finite_scalar(
    value: Any,
    *,
    ctx: MDAComparisonContext,
    condition_label: str,
    metric_key: str,
    value_name: str,
) -> None:
    """Validate that a metric payload value is a finite scalar.

    Booleans, non-numeric objects, NaN, infinities, and integers too large to
    be represented as a float are all rejected.

    Parameters
    ----------
    value : Any
        Candidate scalar value.
    ctx : MDAComparisonContext
        Comparison context used for diagnostics.
    condition_label : str
        Source condition label.
    metric_key : str
        Metric key.
    value_name : str
        Field name used in diagnostics.

    Raises
    ------
    MDAComparisonError
        If ``value`` is not a finite ``int`` or ``float``.
    """
    # bool is a subclass of int, so it has to be excluded explicitly.
    is_number = not isinstance(value, bool) and isinstance(value, (int, float))
    try:
        is_finite = is_number and math.isfinite(value)
    except OverflowError:
        # math.isfinite raises for ints outside float range; such values also
        # cannot be converted with float() when statistics are derived.
        is_finite = False
    if not is_finite:
        raise MDAComparisonError(
            f"{ctx.analysis_name}: condition {condition_label!r} metric {metric_key!r} "
            f"has non-finite scalar {value_name}={value!r}"
        )


def _metric_metadata_for(
    artifact: ConditionArtifact,
    metric_key: str,
    raw_metric: Any,
) -> dict[str, Any]:
    """Extract optional comparison metadata for one metric.

    Metadata from ``payload['metric_metadata'][metric_key]`` is read first;
    recognised keys present on the raw metric mapping then override it.

    Parameters
    ----------
    artifact : ConditionArtifact
        Source condition artifact.
    metric_key : str
        Metric key.
    raw_metric : Any
        Raw metric payload mapping.

    Returns
    -------
    dict[str, Any]
        Normalized metadata for ``MetricValue`` construction and artifact output.

    Raises
    ------
    MDAComparisonError
        If ``higher_is_better`` is neither bool nor ``None``, or
        ``direction_labels`` is not a sequence of exactly three strings.
    """
    metadata: dict[str, Any] = {}
    payload_metadata = artifact.payload.get("metric_metadata")
    if isinstance(payload_metadata, Mapping):
        metric_metadata = payload_metadata.get(metric_key)
        if isinstance(metric_metadata, Mapping):
            metadata.update(metric_metadata)

    if isinstance(raw_metric, Mapping):
        for key in ("higher_is_better", "direction_labels", "label", "unit"):
            if key in raw_metric:
                metadata[key] = raw_metric[key]

    if "higher_is_better" in metadata and not isinstance(
        metadata["higher_is_better"], (bool, type(None))
    ):
        raise MDAComparisonError(
            f"{artifact.analysis_name}: metric {metric_key!r} higher_is_better must be bool or None"
        )

    if "direction_labels" in metadata:
        labels = metadata["direction_labels"]
        if (
            not isinstance(labels, Sequence)
            or isinstance(labels, (str, bytes, bytearray))
            or len(labels) != 3
            or any(not isinstance(label, str) for label in labels)
        ):
            raise MDAComparisonError(
                f"{artifact.analysis_name}: metric {metric_key!r} direction_labels must be "
                "three strings"
            )
        metadata["direction_labels"] = tuple(labels)

    # setdefault keeps an explicit ``higher_is_better=None`` untouched.
    metadata.setdefault("higher_is_better", True)
    metadata.setdefault("direction_labels", ("decreased", "unchanged", "increased"))
    return metadata


def _merge_metric_metadata(
    metric_metadata: dict[str, dict[str, Any]],
    metric_key: str,
    metadata: Mapping[str, Any],
    artifact: ConditionArtifact,
    ctx: MDAComparisonContext,
) -> None:
    """Record metric metadata and reject cross-condition disagreements.

    The first condition seen for a metric supplies the recorded metadata; later
    conditions are only checked for agreement on ``higher_is_better`` and
    ``direction_labels``.

    Parameters
    ----------
    metric_metadata : dict[str, dict[str, Any]]
        Accumulated metadata by metric key. Mutated in place.
    metric_key : str
        Metric key being merged.
    metadata : mapping
        Metadata for the current condition.
    artifact : ConditionArtifact
        Source condition artifact.
    ctx : MDAComparisonContext
        Comparison context used for diagnostics.

    Raises
    ------
    MDAComparisonError
        If ``higher_is_better`` or ``direction_labels`` differs from the
        metadata already recorded for ``metric_key``.
    """
    normalized = dict(metadata)
    if metric_key not in metric_metadata:
        metric_metadata[metric_key] = normalized
        return

    previous = metric_metadata[metric_key]
    for key in ("higher_is_better", "direction_labels"):
        if previous.get(key) != normalized.get(key):
            raise MDAComparisonError(
                f"{ctx.analysis_name}: condition {artifact.condition_label!r} metric {metric_key!r} "
                f"metadata mismatch for {key}; expected {previous.get(key)!r}, "
                f"got {normalized.get(key)!r}"
            )


def _comparison_result_to_artifact(
    comparison: ComparisonResult,
    artifacts: Sequence[ConditionArtifact],
    ctx: MDAComparisonContext,
    metric_metadata: Mapping[str, Mapping[str, Any]],
) -> ComparisonArtifact:
    """Convert the scalar comparison model into an MDA comparison artifact.

    Parameters
    ----------
    comparison : ComparisonResult
        Result from the shared scalar statistics pipeline.
    artifacts : sequence of ConditionArtifact
        Source condition artifacts.
    ctx : MDAComparisonContext
        Comparison context.
    metric_metadata : mapping
        Metadata keyed by metric name.

    Returns
    -------
    ComparisonArtifact
        Serialized artifact envelope for cross-condition comparison.
    """
    condition_labels = [artifact.condition_label for artifact in artifacts]
    payload = {
        "condition_summaries": [condition.model_dump() for condition in comparison.conditions],
        "pairwise_comparisons": [
            pairwise.model_dump() for pairwise in comparison.pairwise_comparisons
        ],
        # ``comparison.anova`` may be falsy, in which case an empty list is stored.
        "anova": [result.model_dump() for result in comparison.anova or []],
        "ranking": comparison.ranking,
        "rankings_by_metric": comparison.rankings_by_metric,
        "metric_metadata": _json_metric_metadata(metric_metadata),
        "statistical_parameters": {
            "fdr_alpha": ctx.fdr_alpha,
            "ttest_method": ctx.ttest_method,
            "posthoc_method": ctx.posthoc_method,
            "control_label": ctx.control_label,
            "effective_control": ctx.effective_control,
            "equilibration": ctx.equilibration,
        },
    }
    metadata: dict[str, Any] = {
        "n_conditions": len(condition_labels),
        "metrics": list(metric_metadata.keys()),
        "comparison_result_model": type(comparison).__name__,
    }
    if ctx.settings_fingerprint is not None:
        metadata["settings_fingerprint"] = ctx.settings_fingerprint

    return ComparisonArtifact(
        analysis_name=ctx.analysis_name,
        conditions=condition_labels,
        control_label=ctx.control_label,
        effective_control=ctx.effective_control,
        payload=payload,
        provenance={
            "source": "mda_condition_artifact_comparison",
            "source_condition_labels": condition_labels,
            "source_replicates": {
                artifact.condition_label: list(artifact.replicates) for artifact in artifacts
            },
            "source_condition_artifacts": [
                {
                    "condition_label": artifact.condition_label,
                    "replicates": list(artifact.replicates),
                    "source_replicates": list(artifact.source_replicates),
                    "skipped_replicates": list(artifact.skipped_replicates),
                }
                for artifact in artifacts
            ],
            "skipped_replicates": {
                artifact.condition_label: list(artifact.skipped_replicates)
                for artifact in artifacts
            },
        },
        metadata=metadata,
        warnings=_combined_warnings(artifacts),
    )


def _json_metric_metadata(
    metric_metadata: Mapping[str, Mapping[str, Any]],
) -> dict[str, dict[str, Any]]:
    """Convert metric metadata to JSON-friendly dictionaries.

    Tuple values are converted to lists; every other value is passed through.

    Parameters
    ----------
    metric_metadata : mapping
        Metadata keyed by metric name.

    Returns
    -------
    dict[str, dict[str, Any]]
        JSON-friendly metadata.
    """
    converted: dict[str, dict[str, Any]] = {}
    for metric_key, metadata in metric_metadata.items():
        converted[metric_key] = {
            key: list(value) if isinstance(value, tuple) else value
            for key, value in metadata.items()
        }
    return converted


def _combined_warnings(artifacts: Sequence[ConditionArtifact]) -> list[str]:
    """Return de-duplicated warnings from source condition artifacts.

    Parameters
    ----------
    artifacts : sequence of ConditionArtifact
        Source artifacts.

    Returns
    -------
    list[str]
        Warnings in first-seen order.
    """
    warnings: list[str] = []
    seen: set[str] = set()
    for artifact in artifacts:
        for warning in artifact.warnings:
            if warning in seen:
                continue
            warnings.append(warning)
            seen.add(warning)
    return warnings