"""Base classes for comparison analysis.

This module provides abstract base classes that consolidate common patterns
across all comparator types, following the Template Method design pattern.

Classes
-------
BaseConditionSummary
    Abstract base for condition-level summary statistics.
BaseComparisonResult
    Abstract base for complete comparison results with save/load.
PairwiseComparison
    Shared model for statistical comparison between two conditions.
ANOVASummary
    Shared model for ANOVA results.
BaseComparator
    Abstract base implementing the Template Method pattern for comparisons.

Design Principles
-----------------
1. Open-Closed Principle: New comparators extend base classes without modifying them.
2. Template Method: `compare()` defines the algorithm skeleton; subclasses fill in specifics.
3. DRY: Statistical tests, pairwise logic, and serialization are implemented once.
"""
from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from itertools import combinations
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, TypeVar

from pydantic import BaseModel, Field

from polyzymd import __version__
from polyzymd.analysis.core.metric_type import MetricType
from polyzymd.compare.statistics import (
    cohens_d,
    independent_ttest,
    one_way_anova,
    percent_change,
)

if TYPE_CHECKING:
    from polyzymd.compare.config import ComparisonConfig, ConditionConfig
# Module-level logger, registered under the fixed name "polyzymd.compare"
# rather than this module's ``__name__``.
# NOTE(review): presumably so every comparator module logs through one shared
# logger -- confirm against the other modules in the package.
logger = logging.getLogger("polyzymd.compare")
# ============================================================================
# Shared Result Models (DRY - used by all comparators)
# ============================================================================
class PairwiseComparison(BaseModel):
    """Statistical comparison between two conditions.

    This is the standard pairwise comparison record shared by all comparator
    types. Comparators that need additional fields (e.g., multiple metrics)
    should subclass this model.

    Attributes
    ----------
    condition_a : str
        Label of the first condition (typically control/reference).
    condition_b : str
        Label of the second condition (typically treatment).
    metric : str
        Name of the metric being compared. Defaults to ``"default"``.
    t_statistic : float
        T-test statistic.
    p_value : float
        P-value of the t-test (documented as two-tailed).
    cohens_d : float
        Effect size (Cohen's d).
    effect_size_interpretation : str
        Qualitative effect-size label, e.g. "negligible", "small", "medium",
        or "large".
    direction : str
        Interpretation of the change (e.g., "stabilizing", "improving").
    significant : bool
        Significance flag reported by the t-test. The threshold is decided by
        the statistics helper that produced it (documented as p < 0.05).
    percent_change : float
        Percent change from ``condition_a`` to ``condition_b``.
    """

    condition_a: str
    condition_b: str
    metric: str = "default"
    t_statistic: float
    p_value: float
    cohens_d: float
    effect_size_interpretation: str
    direction: str
    significant: bool
    percent_change: float
class ANOVASummary(BaseModel):
    """One-way ANOVA result summary.

    Attributes
    ----------
    metric : str
        Name of the metric tested (e.g., "rmsf", "coverage"). Defaults to
        ``"default"``.
    f_statistic : float
        F-statistic from the ANOVA.
    p_value : float
        P-value for the test.
    significant : bool
        Significance flag reported by the ANOVA helper (documented as
        p < 0.05).
    """

    metric: str = "default"
    f_statistic: float
    p_value: float
    significant: bool
# ============================================================================
# Abstract Base Classes for Results
# ============================================================================
class BaseConditionSummary(BaseModel, ABC):
    """Abstract base class for condition-level summary statistics.

    All condition summaries share these common fields. Subclasses add
    analysis-specific fields (e.g., mean_rmsf, coverage_mean) and must
    implement the two abstract properties below.

    Attributes
    ----------
    label : str
        Display name for this condition.
    config_path : str
        Path to the simulation config file.
    n_replicates : int
        Number of replicates included.
    replicate_values : list[float]
        Per-replicate values of the primary metric (for statistical tests).
    """

    label: str
    config_path: str
    n_replicates: int
    replicate_values: list[float]

    @property
    @abstractmethod
    def primary_metric_value(self) -> float:
        """Return the primary metric value for ranking/comparison.

        Concrete comparators decide how this value is consumed; the base
        comparator reaches summary values through its own hook methods
        rather than reading this property directly.
        """
        ...

    @property
    @abstractmethod
    def primary_metric_sem(self) -> float:
        """Return the standard error of the mean (SEM) of the primary metric."""
        ...
# Type variables used to parameterize the generic result and comparator classes.
# TConditionSummary is bounded by BaseConditionSummary (any condition summary
# subtype); TPairwiseComparison is bounded by PairwiseComparison (the shared
# model or a subclass that adds fields).
TConditionSummary = TypeVar("TConditionSummary", bound=BaseConditionSummary)
TPairwiseComparison = TypeVar("TPairwiseComparison", bound=PairwiseComparison)
class BaseComparisonResult(BaseModel, ABC, Generic[TConditionSummary, TPairwiseComparison]):
    """Abstract base class for comparison results.

    Provides common serialization (save/load) and accessor methods.
    Subclasses define analysis-specific fields.

    Attributes
    ----------
    metric : str
        The primary metric being compared (e.g., "rmsf",
        "simultaneous_contact_fraction").
    name : str
        Name of the comparison project.
    control_label : str, optional
        Label of the control condition.
    conditions : list[TConditionSummary]
        Summary for each condition.
    pairwise_comparisons : list[TPairwiseComparison]
        Statistical comparisons (all vs control, or all pairs).
    anova : ANOVASummary or list[ANOVASummary], optional
        ANOVA result(s); a list when a comparator tests several metrics.
        ``None`` when no ANOVA was computed.
    ranking : list[str]
        Labels sorted by primary metric.
    equilibration_time : str
        Equilibration time used.
    created_at : datetime
        When the analysis was run.
    polyzymd_version : str
        Version of polyzymd used.
    """

    # Class variable - subclasses should override
    comparison_type: ClassVar[str] = "base"

    metric: str
    name: str
    control_label: str | None = None
    conditions: list[Any]  # Will be overridden in subclasses with specific type
    pairwise_comparisons: list[Any]  # Will be overridden in subclasses
    anova: ANOVASummary | list[ANOVASummary] | None = None
    ranking: list[str]
    equilibration_time: str
    created_at: datetime
    polyzymd_version: str = __version__

    def save(self, path: Path | str) -> Path:
        """Save result to a JSON file.

        Missing parent directories are created. The file is always written
        as UTF-8 so that non-ASCII text (e.g. condition labels) round-trips
        regardless of the platform's locale encoding.

        Parameters
        ----------
        path : Path or str
            Output path.

        Returns
        -------
        Path
            Path to the saved file.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: the locale default is not UTF-8 everywhere and
        # would make saved results non-portable between machines.
        path.write_text(self.model_dump_json(indent=2), encoding="utf-8")
        return path

    @classmethod
    def load(cls, path: Path | str) -> Self:
        """Load result from a JSON file.

        The file is read as UTF-8, matching what ``save()`` writes.

        Parameters
        ----------
        path : Path or str
            Path to JSON file.

        Returns
        -------
        Self
            Loaded result.
        """
        return cls.model_validate_json(Path(path).read_text(encoding="utf-8"))

    def get_condition(self, label: str) -> Any:
        """Get a condition by label.

        Parameters
        ----------
        label : str
            Condition label.

        Returns
        -------
        BaseConditionSummary
            The first condition whose ``label`` matches.

        Raises
        ------
        KeyError
            If condition not found.
        """
        for cond in self.conditions:
            if cond.label == label:
                return cond
        raise KeyError(f"Condition '{label}' not found")

    def get_comparison(self, label: str) -> Any | None:
        """Get the pairwise comparison in which ``label`` is the second condition.

        When comparisons were computed against a control, this is the
        "treatment vs control" comparison for ``label``. Only ``condition_b``
        is matched, and the first match wins.

        Parameters
        ----------
        label : str
            Treatment condition label.

        Returns
        -------
        PairwiseComparison or None
            The comparison, or None if not found.
        """
        for comp in self.pairwise_comparisons:
            if comp.condition_b == label:
                return comp
        return None
# ============================================================================
# Base Comparator (Template Method Pattern)
# ============================================================================
# Type variables for the generic comparator.
# TAnalysisSettings and TConditionData are unbounded: each comparator chooses
# its own settings object and raw per-condition data type. TResult must be a
# BaseComparisonResult subtype.
# NOTE(review): TComparisonSettings is not referenced anywhere else in this
# module -- presumably kept for importers; confirm before removing.
TAnalysisSettings = TypeVar("TAnalysisSettings")
TComparisonSettings = TypeVar("TComparisonSettings")
TConditionData = TypeVar("TConditionData")
TResult = TypeVar("TResult", bound=BaseComparisonResult)
class BaseComparator(ABC, Generic[TAnalysisSettings, TConditionData, TConditionSummary, TResult]):
    """Abstract base class for all comparators using Template Method pattern.

    The `compare()` method defines the comparison algorithm skeleton:

    1. Filter conditions (optional hook; default keeps all)
    2. Load/compute analysis for each remaining condition
    3. Build condition summaries
    4. Determine the effective control
    5. Compute pairwise statistical comparisons
    6. Compute ANOVA (if 3+ conditions)
    7. Rank conditions
    8. Build and return result

    Subclasses implement the abstract methods to customize each step.

    Parameters
    ----------
    config : ComparisonConfig
        Comparison configuration defining conditions.
    analysis_settings : TAnalysisSettings
        Analysis-specific settings.
    equilibration : str, optional
        Equilibration time override.

    Type Parameters
    ---------------
    TAnalysisSettings
        Type of analysis settings (e.g., RMSFAnalysisSettings).
    TConditionData
        Type of raw data loaded for each condition.
    TConditionSummary
        Type of condition summary (e.g., RMSFConditionSummary).
    TResult
        Type of comparison result (e.g., RMSFComparisonResult).
    """

    # Class variable - subclasses should override
    comparison_type: ClassVar[str] = "base"

    def __init__(
        self,
        config: ComparisonConfig,
        analysis_settings: TAnalysisSettings,
        equilibration: str | None = None,
    ) -> None:
        """Store the configuration and resolve the equilibration time.

        Parameters
        ----------
        config : ComparisonConfig
            Comparison configuration defining conditions.
        analysis_settings : TAnalysisSettings
            Analysis-specific settings.
        equilibration : str, optional
            Equilibration time override. Any falsy value (``None`` or an
            empty string) falls back to ``config.defaults.equilibration_time``.
        """
        self.config = config
        self.analysis_settings = analysis_settings
        self.equilibration = equilibration or config.defaults.equilibration_time

    @classmethod
    @abstractmethod
    def comparison_type_name(cls) -> str:
        """Return the comparison type identifier (e.g., "rmsf", "contacts").

        Returns
        -------
        str
            Type identifier used in registry and CLI.
        """
        ...

    @property
    @abstractmethod
    def metric_type(self) -> MetricType:
        """Declare whether this comparator's metric is mean or variance-based.

        This determines how autocorrelation is handled in the underlying analysis:

        - **MEAN_BASED**: Use all frames for computation, correct uncertainty
          using N_eff (effective sample size). Examples: average distance,
          contact fraction, catalytic triad proximity.
        - **VARIANCE_BASED**: Subsample to independent frames separated by 2τ
          (correlation time) to avoid bias in variance estimates. Examples:
          RMSF, fluctuation metrics.

        Contributors implementing new comparators MUST declare the appropriate
        metric type to ensure correct statistical treatment per LiveCoMS
        best practices (Grossfield et al., 2018).

        Returns
        -------
        MetricType
            The metric type for this comparator.

        References
        ----------
        - Grossfield et al. (2018) LiveCoMS 1:5067 (Best Practices for Uncertainty)
        - GitHub: dmzuckerman/Sampling-Uncertainty
        """
        ...

    def compare(self, recompute: bool = False) -> TResult:
        """Run comparison across all conditions (Template Method).

        This method defines the algorithm skeleton. Subclasses customize
        behavior by implementing the abstract hook methods.

        Parameters
        ----------
        recompute : bool, optional
            If True, force recompute even if cached results exist.

        Returns
        -------
        TResult
            Complete comparison results with statistics and rankings.
        """
        logger.info(
            "Starting %s comparison: %s", self.comparison_type_name(), self.config.name
        )
        logger.info("Conditions: %s", len(self.config.conditions))
        logger.info("Equilibration: %s", self.equilibration)

        # Step 1: Filter conditions (optional hook - default returns all)
        valid_conditions, excluded_conditions = self._filter_conditions()
        if excluded_conditions:
            logger.warning(
                "Excluding %s condition(s): %s",
                len(excluded_conditions),
                [c.label for c in excluded_conditions],
            )

        # Step 2: Load or compute analysis for each condition.
        # All conditions are loaded before any summary is built.
        condition_data: list[tuple[ConditionConfig, TConditionData]] = [
            (cond, self._load_or_compute(cond, recompute)) for cond in valid_conditions
        ]

        # Step 3: Build condition summaries
        summaries: list[TConditionSummary] = [
            self._build_condition_summary(cond, data) for cond, data in condition_data
        ]

        # Step 4: Determine effective control
        effective_control = self._get_effective_control(excluded_conditions)

        # Step 5: Compute pairwise comparisons
        comparisons = self._compute_pairwise_comparisons(summaries, effective_control)

        # Step 6: ANOVA if 3+ conditions
        anova = self._compute_anova(summaries) if len(summaries) >= 3 else None

        # Step 7: Rank conditions
        ranking = self._compute_ranking(summaries)

        # Step 8: Build result
        return self._build_result(
            summaries=summaries,
            comparisons=comparisons,
            anova=anova,
            ranking=ranking,
            effective_control=effective_control,
            excluded_conditions=excluded_conditions,
        )

    # ========================================================================
    # Abstract Methods (must be implemented by subclasses)
    # ========================================================================

    @abstractmethod
    def _load_or_compute(
        self,
        cond: ConditionConfig,
        recompute: bool,
    ) -> TConditionData:
        """Load existing results or compute analysis for a condition.

        Parameters
        ----------
        cond : ConditionConfig
            Condition to analyze.
        recompute : bool
            Force recompute even if cached.

        Returns
        -------
        TConditionData
            Raw analysis data for this condition.
        """
        ...

    @abstractmethod
    def _build_condition_summary(
        self,
        cond: ConditionConfig,
        data: TConditionData,
    ) -> TConditionSummary:
        """Build a condition summary from raw data.

        Parameters
        ----------
        cond : ConditionConfig
            Condition configuration.
        data : TConditionData
            Raw analysis data.

        Returns
        -------
        TConditionSummary
            Structured condition summary.
        """
        ...

    def _build_result(
        self,
        summaries: list[TConditionSummary],
        comparisons: list[Any],
        anova: ANOVASummary | list[ANOVASummary] | None,
        ranking: list[str],
        effective_control: str | None,
        excluded_conditions: list[ConditionConfig],
    ) -> TResult:
        """Build the final comparison result.

        Subclasses that use the base ``compare()`` template must implement
        this method. Subclasses that override ``compare()`` entirely do
        not need to.

        Parameters
        ----------
        summaries : list[TConditionSummary]
            Condition summaries.
        comparisons : list
            Pairwise comparison results.
        anova : ANOVASummary or list or None
            ANOVA result(s).
        ranking : list[str]
            Ranked condition labels.
        effective_control : str or None
            Effective control label.
        excluded_conditions : list[ConditionConfig]
            Conditions that were excluded.

        Returns
        -------
        TResult
            Complete comparison result.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _build_result() or override compare() entirely."
        )

    def _get_replicate_values(self, summary: TConditionSummary) -> list[float]:
        """Extract per-replicate values for statistical tests.

        Subclasses that use the base ``_compute_pairwise_comparisons()``
        must implement this. Subclasses with custom pairwise logic do
        not need to.

        Parameters
        ----------
        summary : TConditionSummary
            Condition summary.

        Returns
        -------
        list[float]
            Per-replicate values of the primary metric.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_replicate_values() "
            "or override _compute_pairwise_comparisons()."
        )

    def _get_mean_value(self, summary: TConditionSummary) -> float:
        """Get the mean value of the primary metric.

        Subclasses that use the base ``_compute_pairwise_comparisons()``
        must implement this. Subclasses with custom pairwise logic do
        not need to.

        Parameters
        ----------
        summary : TConditionSummary
            Condition summary.

        Returns
        -------
        float
            Mean value.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_mean_value() "
            "or override _compute_pairwise_comparisons()."
        )

    @property
    def _direction_labels(self) -> tuple[str, str, str]:
        """Labels for (negative_change, unchanged, positive_change).

        Override this property for simple 3-way direction labeling where
        any negative percent change maps to one label and any positive maps
        to another. The default implementation raises NotImplementedError;
        subclasses must define this OR override ``_interpret_direction()``
        directly for custom semantics (e.g., threshold-based dead zones).

        Returns
        -------
        tuple[str, str, str]
            ``(negative_label, unchanged_label, positive_label)``

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must define _direction_labels or "
            "override _interpret_direction()"
        )

    def _interpret_direction(self, percent_change: float) -> str:
        """Interpret the direction of change for this metric.

        The default implementation uses ``_direction_labels`` with a zero
        threshold. Override this method for custom semantics such as
        threshold-based dead zones or inverted polarity logic.

        Parameters
        ----------
        percent_change : float
            Percent change from control to treatment.

        Returns
        -------
        str
            Direction interpretation (e.g., "stabilizing", "improving").
        """
        neg, unchanged, pos = self._direction_labels
        if percent_change < 0:
            return neg
        if percent_change > 0:
            return pos
        # Exactly zero, or a value that compares neither below nor above zero.
        return unchanged

    def _rank_summaries(self, summaries: list[TConditionSummary]) -> list[TConditionSummary]:
        """Sort summaries by the primary metric.

        Subclasses that use the base ``compare()`` template must implement
        this. Subclasses that override ``compare()`` entirely and compute
        their own rankings do not need to.

        Parameters
        ----------
        summaries : list[TConditionSummary]
            Condition summaries to rank.

        Returns
        -------
        list[TConditionSummary]
            Sorted summaries (best first).

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _rank_summaries() "
            "or override compare() entirely."
        )

    # ========================================================================
    # Hook Methods (can be overridden by subclasses)
    # ========================================================================

    def _filter_conditions(
        self,
    ) -> tuple[list[ConditionConfig], list[ConditionConfig]]:
        """Filter conditions before analysis.

        Override this to exclude certain conditions (e.g., no-polymer conditions
        for contacts analysis). The default keeps every configured condition.

        Returns
        -------
        tuple[list[ConditionConfig], list[ConditionConfig]]
            (valid_conditions, excluded_conditions)
        """
        # Return a copy so callers cannot mutate the configuration's own list.
        return list(self.config.conditions), []

    def _get_effective_control(
        self,
        excluded_conditions: list[ConditionConfig],
    ) -> str | None:
        """Determine the effective control label.

        Returns None when no control is configured or when the configured
        control was excluded (a warning is logged in the latter case).

        Parameters
        ----------
        excluded_conditions : list[ConditionConfig]
            Conditions that were excluded.

        Returns
        -------
        str or None
            Effective control label.
        """
        control = self.config.control
        if not control:
            return None

        excluded_labels = {c.label for c in excluded_conditions}
        if control in excluded_labels:
            logger.warning(
                "Control '%s' was excluded. Comparisons will be pairwise.", control
            )
            return None
        return control

    def _use_rmsf_mode_for_cohens_d(self) -> bool:
        """Whether to use RMSF-specific Cohen's d interpretation.

        Override in RMSF comparator to return True.

        Returns
        -------
        bool
            True if negative d should be "stabilizing".
        """
        return False

    def _resolve_condition_output_dir(self, label: str, analysis_subdir: str) -> Path | None:
        """Resolve a condition-specific output directory for analysis results.

        When running in comparison mode (config loaded from a YAML file with
        ``source_path`` set), returns a condition-specific path under the
        comparison project directory. This prevents cache collisions when
        multiple conditions share the same ``projects_directory``.

        When ``source_path`` is ``None`` (standalone / programmatic usage),
        returns ``None``, which tells downstream calculators to use their
        default output directory (``projects_directory``).

        Parameters
        ----------
        label : str
            Condition label (e.g. ``"SBMA-EGMA 25%"``).
        analysis_subdir : str
            Analysis-type subdirectory name (e.g. ``"rmsf"``,
            ``"catalytic_triad"``, ``"contacts"``, ``"distances"``).

        Returns
        -------
        Path or None
            ``<config dir>/analysis/<sanitized label>/<analysis_subdir>``, or
            ``None`` for default behaviour.
        """
        if self.config.source_path is None:
            return None

        # NOTE(review): imported lazily -- presumably to avoid a circular
        # import with the comparators package; confirm before hoisting.
        from polyzymd.compare.comparators._utils import sanitize_label

        comparison_dir = self.config.source_path.parent
        return comparison_dir / "analysis" / sanitize_label(label) / analysis_subdir

    # ========================================================================
    # Shared Implementation Methods (DRY)
    # ========================================================================

    def _compute_pairwise_comparisons(
        self,
        summaries: list[TConditionSummary],
        effective_control: str | None,
    ) -> list[PairwiseComparison]:
        """Compute pairwise statistical comparisons.

        If a control is specified, compares all conditions vs control.
        Otherwise, compares all pairs.

        Parameters
        ----------
        summaries : list[TConditionSummary]
            Condition summaries.
        effective_control : str or None
            Control condition label.

        Returns
        -------
        list[PairwiseComparison]
            Pairwise comparison results.

        Raises
        ------
        ValueError
            If ``effective_control`` is given but no summary carries that
            label (e.g. the configured control does not match any condition).
        """
        if effective_control:
            # Compare all vs control. A default is passed to next() so that a
            # missing control surfaces as a clear error instead of a bare
            # StopIteration.
            control = next((s for s in summaries if s.label == effective_control), None)
            if control is None:
                raise ValueError(
                    f"Control condition '{effective_control}' not found among analyzed "
                    f"conditions: {[s.label for s in summaries]}"
                )
            return [
                self._compare_pair(control, treatment)
                for treatment in summaries
                if treatment.label != effective_control
            ]

        # Compare all pairs; each unordered pair once, in list order.
        return [
            self._compare_pair(cond_a, cond_b)
            for cond_a, cond_b in combinations(summaries, 2)
        ]

    def _compare_pair(
        self,
        cond_a: TConditionSummary,
        cond_b: TConditionSummary,
    ) -> PairwiseComparison:
        """Compare two conditions statistically.

        Override this method for comparators that need custom comparison
        logic (e.g., multiple metrics like contacts).

        Parameters
        ----------
        cond_a : TConditionSummary
            First condition (typically control).
        cond_b : TConditionSummary
            Second condition (typically treatment).

        Returns
        -------
        PairwiseComparison
            Statistical comparison result.
        """
        values_a = self._get_replicate_values(cond_a)
        values_b = self._get_replicate_values(cond_b)

        # T-test and effect size operate on the per-replicate values.
        ttest = independent_ttest(values_a, values_b)
        effect = cohens_d(values_a, values_b, rmsf_mode=self._use_rmsf_mode_for_cohens_d())

        # Percent change is computed from the condition means, a -> b.
        pct = percent_change(self._get_mean_value(cond_a), self._get_mean_value(cond_b))

        return PairwiseComparison(
            condition_a=cond_a.label,
            condition_b=cond_b.label,
            metric=self.comparison_type_name(),
            t_statistic=ttest.t_statistic,
            p_value=ttest.p_value,
            cohens_d=effect.cohens_d,
            effect_size_interpretation=effect.interpretation,
            direction=self._interpret_direction(pct),
            significant=ttest.significant,
            percent_change=pct,
        )

    def _compute_anova(
        self,
        summaries: list[TConditionSummary],
    ) -> ANOVASummary:
        """Compute one-way ANOVA across all conditions.

        Override this for comparators that test multiple metrics.

        Parameters
        ----------
        summaries : list[TConditionSummary]
            Condition summaries.

        Returns
        -------
        ANOVASummary
            ANOVA result.
        """
        # One group of per-replicate values per condition.
        groups = [self._get_replicate_values(s) for s in summaries]
        result = one_way_anova(*groups)
        return ANOVASummary(
            metric=self.comparison_type_name(),
            f_statistic=result.f_statistic,
            p_value=result.p_value,
            significant=result.significant,
        )

    def _compute_ranking(self, summaries: list[TConditionSummary]) -> list[str]:
        """Compute ranking of conditions.

        Parameters
        ----------
        summaries : list[TConditionSummary]
            Condition summaries.

        Returns
        -------
        list[str]
            Labels in ranked order (best first).
        """
        return [s.label for s in self._rank_summaries(summaries)]