# Source code for polyzymd.compare.plotters.polymer_affinity

"""Polymer affinity score plotters for comparison workflow.

This module provides registered plotters for the polymer affinity score:

- AffinityStackedBarPlotter: Total affinity score per condition, with
  stacked segments showing each polymer type's contribution.
- AffinityGroupBarPlotter: Per-group breakdown comparing conditions,
  one figure per polymer type.

Both plotters load a ``PolymerAffinityScoreResult`` JSON saved by the
``polyzymd compare polymer-affinity`` command (in ``results/`` adjacent to
``comparison.yaml``).

Physics interpretation
----------------------
Score < 0  →  net favorable polymer-protein affinity
Score > 0  →  net unfavorable (avoidance dominates)
Score = 0  →  contacts match the surface-availability reference

Units are always kT (dimensionless, in units of k_bT).

Sign convention
---------------
More negative = stronger polymer-protein interaction. Diverging colormap
is not used here (unlike BFE heatmaps) because the primary display is
bar charts where sign is visually obvious.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

from polyzymd.compare.plotter import BasePlotter, PlotterRegistry

if TYPE_CHECKING:
    from polyzymd.compare.config import ComparisonConfig
    from polyzymd.compare.results.polymer_affinity import PolymerAffinityScoreResult

logger = logging.getLogger(__name__)


def _find_affinity_result(
    data: dict[str, Any], labels: Sequence[str]
) -> "PolymerAffinityScoreResult | None":
    """Locate and load the most recent polymer affinity result JSON.

    The lookup proceeds in two stages:

    1. If ``data["__meta__"]["results_dir"]`` is present, that directory is
       searched first.
    2. Otherwise (or if stage 1 yields nothing), each condition's ``config``
       path is used to guess a ``results/`` directory located next to the
       config file or one level above it.

    Within a directory, files matching ``polymer_affinity_comparison_*.json``
    or ``affinity_comparison_*.json`` are collected and only the one with the
    latest modification time is loaded.

    Parameters
    ----------
    data : dict
        Mapping of condition_label -> condition data dict, plus an optional
        ``"__meta__"`` key with ``results_dir``.
    labels : sequence of str
        Condition labels in display order.

    Returns
    -------
    PolymerAffinityScoreResult or None
        Loaded result, or None if no file was found or loading failed.
    """
    from polyzymd.compare.results.polymer_affinity import PolymerAffinityScoreResult

    filename_patterns = (
        "polymer_affinity_comparison_*.json",
        "affinity_comparison_*.json",
    )

    def _load_newest(folder: Path) -> "PolymerAffinityScoreResult | None":
        # Return the newest matching result in ``folder``, or None.
        if not folder.is_dir():
            return None
        matches: list[Path] = [
            hit for pattern in filename_patterns for hit in folder.glob(pattern)
        ]
        if not matches:
            return None
        newest = max(matches, key=lambda p: p.stat().st_mtime)
        try:
            loaded = PolymerAffinityScoreResult.load(newest)
            logger.debug(f"Loaded affinity result from {newest}")
            return loaded
        except Exception as e:
            # Best effort: an unreadable file is reported, not fatal.
            logger.warning(f"Failed to load affinity result {newest}: {e}")
            return None

    # Stage 1: directory named explicitly by the orchestrator metadata.
    meta = data.get("__meta__")
    explicit_dir = meta.get("results_dir") if meta is not None else None
    if explicit_dir is not None:
        found = _load_newest(Path(explicit_dir))
        if found is not None:
            return found
        logger.debug(f"No affinity result JSON in {explicit_dir} — falling back to heuristic")

    # Stage 2: derive candidate results/ directories from condition configs.
    search_dirs: list[Path] = []
    for label in labels:
        cond_data = data.get(label)
        condition = cond_data.get("condition") if cond_data is not None else None
        config_attr = getattr(condition, "config", None) if condition is not None else None
        if config_attr is None:
            continue
        config_file = Path(config_attr)
        for base in (config_file.parent, config_file.parent.parent):
            guess = base / "results"
            if guess.is_dir() and guess not in search_dirs:
                search_dirs.append(guess)

    for folder in search_dirs:
        found = _load_newest(folder)
        if found is not None:
            return found

    logger.info("No polymer affinity result JSON found — skipping affinity plots")
    return None


# ---------------------------------------------------------------------------
# Stacked bar plotter — total score per condition
# ---------------------------------------------------------------------------


@PlotterRegistry.register("affinity_stacked_bars")
class AffinityStackedBarPlotter(BasePlotter):
    """Stacked bar chart of total affinity score per condition.

    Each bar represents one condition's total affinity score, with
    segments colored by polymer type contribution. This gives a quick
    overview of which polymer types contribute most to the total
    interaction strength.

    Loads ``PolymerAffinityScoreResult`` from ``results/`` adjacent to
    ``comparison.yaml``.
    """

    @classmethod
    def plot_type(cls) -> str:
        """Return the registry key of this plotter."""
        return "affinity_stacked_bars"

    def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool:
        """Return True for ``polymer_affinity`` analyses when stacked bars are enabled.

        ``comparison_config`` is accepted for interface compatibility and is
        not consulted; the decision depends on ``analysis_type`` and the
        ``generate_stacked_bars`` plot setting only.
        """
        if analysis_type != "polymer_affinity":
            return False
        return self.settings.polymer_affinity.generate_stacked_bars

    def plot(
        self,
        data: dict[str, Any],
        labels: Sequence[str],
        output_dir: Path,
        **kwargs: Any,
    ) -> list[Path]:
        """Generate stacked bar chart of affinity scores by condition.

        Parameters
        ----------
        data : dict
            Condition data dict from orchestrator.
        labels : sequence of str
            Condition labels in display order.
        output_dir : Path
            Directory to save plot files.

        Returns
        -------
        list[Path]
            Paths to generated plot files. Empty when no affinity result
            could be loaded or the result lists no polymer types.
        """
        import matplotlib.pyplot as plt

        t = self.theme
        result = _find_affinity_result(data, labels)
        if result is None:
            return []

        affinity_settings = self.settings.polymer_affinity

        # Order conditions by the caller's labels; if none of them match,
        # fall back to every condition stored in the result.
        cond_labels = [c.label for c in result.conditions]
        display_labels = [lbl for lbl in labels if lbl in cond_labels]
        if not display_labels:
            display_labels = cond_labels

        all_polymer_types = result.polymer_types
        if not all_polymer_types:
            logger.info("No polymer types in affinity result — skipping stacked bars")
            return []

        # Resolve each displayed condition once instead of once per polymer
        # type (assumes ``get_condition`` is a pure lookup).
        conditions = [result.get_condition(lbl) for lbl in display_labels]

        n_conds = len(display_labels)
        colors = self._get_colors(len(all_polymer_types))

        fig, ax = plt.subplots(figsize=affinity_settings.figsize_stacked, dpi=self.settings.dpi)

        x = np.arange(n_conds)
        # Separate running baselines so negative segments stack downward and
        # positive segments stack upward from zero.
        bottoms_neg = np.zeros(n_conds)
        bottoms_pos = np.zeros(n_conds)

        for poly_idx, poly_type in enumerate(all_polymer_types):
            values = []
            for cond in conditions:
                score = 0.0
                if cond is not None:
                    # First score entry for this polymer type, if any.
                    match = next(
                        (s for s in cond.polymer_type_scores if s.polymer_type == poly_type),
                        None,
                    )
                    if match is not None:
                        score = match.total_score
                values.append(score)

            vals = np.array(values)
            neg_vals = np.where(vals < 0, vals, 0)
            pos_vals = np.where(vals >= 0, vals, 0)
            has_negative = bool(np.any(neg_vals != 0))

            if has_negative:
                ax.bar(
                    x,
                    neg_vals,
                    bottom=bottoms_neg,
                    color=colors[poly_idx],
                    label=poly_type,
                    alpha=t.bar_alpha,
                    edgecolor="white",
                    linewidth=t.bar_linewidth,
                )
                bottoms_neg += neg_vals

            if np.any(pos_vals != 0):
                ax.bar(
                    x,
                    pos_vals,
                    bottom=bottoms_pos,
                    color=colors[poly_idx],
                    # Label only once per polymer type: the negative segment
                    # already carries the label when it was drawn.
                    label=None if has_negative else poly_type,
                    alpha=t.bar_alpha,
                    edgecolor="white",
                    linewidth=t.bar_linewidth,
                )
                bottoms_pos += pos_vals

        # Add total score markers with uncertainty
        if affinity_settings.show_error_bars:
            totals = []
            errors = []
            for cond in conditions:
                if cond is not None:
                    totals.append(cond.total_score)
                    # Missing (None) or zero uncertainty is drawn as 0.
                    errors.append(cond.total_score_uncertainty or 0.0)
                else:
                    totals.append(0.0)
                    errors.append(0.0)
            ax.errorbar(
                x,
                totals,
                yerr=errors,
                fmt="k_",
                capsize=t.bar_capsize,
                capthick=1.5,
                linewidth=0,
                elinewidth=1.5,
                label="Total ± SEM",
                zorder=10,
            )

        ax.axhline(y=0, color="black", linewidth=1.0, linestyle="-")
        ax.set_xticks(x)
        ax.set_xticklabels(display_labels, rotation=35, ha="right", fontsize=t.tick_fontsize)

        # Temperature suffix is shown only when every condition in the
        # result shares a single temperature.
        temp_str = ""
        if result.conditions:
            temps = {c.temperature_K for c in result.conditions}
            if len(temps) == 1:
                temp_str = f" ({next(iter(temps)):.0f} K)"

        self._apply_axis_style(
            ax,
            title=f"Polymer Affinity Score by Condition{temp_str}",
            ylabel=r"Affinity Score ($k_\mathrm{b}T$)",
        )

        # De-duplicate legend entries, keeping the first handle per label.
        handles, legend_labels = ax.get_legend_handles_labels()
        seen: set[str] = set()
        unique_handles = []
        unique_labels = []
        for handle, lbl in zip(handles, legend_labels):
            if lbl in seen:
                continue
            seen.add(lbl)
            unique_handles.append(handle)
            unique_labels.append(lbl)

        self._apply_legend(
            ax,
            fontsize=t.small_fontsize,
            handles=unique_handles,
            labels=unique_labels,
            framealpha=0.7,
        )

        plt.tight_layout()
        output_path = self._get_output_path(output_dir, "affinity_stacked_bars")
        return [
            self._save_figure(
                fig,
                output_path,
                experimental_features=("polymer_affinity",),
            )
        ]


# ---------------------------------------------------------------------------
# Per-group bar plotter — breakdown by AA group
# ---------------------------------------------------------------------------


@PlotterRegistry.register("affinity_group_bars")
class AffinityGroupBarPlotter(BasePlotter):
    """Grouped bar chart of per-group affinity score contributions.

    Creates one figure per polymer type with:

    - Groups on x-axis: protein groups (AA classes)
    - Bars within each group: one per condition
    - Error bars: SEM on per-group affinity score
    - Reference line at score = 0

    Loads ``PolymerAffinityScoreResult`` from ``results/``.
    """

    @classmethod
    def plot_type(cls) -> str:
        """Return the registry key of this plotter."""
        return "affinity_group_bars"

    def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool:
        """Return True for ``polymer_affinity`` analyses when group bars are enabled.

        ``comparison_config`` is accepted for interface compatibility and is
        not consulted; the decision depends on ``analysis_type`` and the
        ``generate_group_bars`` plot setting only.
        """
        if analysis_type != "polymer_affinity":
            return False
        return self.settings.polymer_affinity.generate_group_bars

    def plot(
        self,
        data: dict[str, Any],
        labels: Sequence[str],
        output_dir: Path,
        **kwargs: Any,
    ) -> list[Path]:
        """Generate grouped bar charts of per-group affinity scores.

        Parameters
        ----------
        data : dict
            Condition data dict from orchestrator.
        labels : sequence of str
            Condition labels in display order.
        output_dir : Path
            Directory to save plot files.

        Returns
        -------
        list[Path]
            Paths to generated plot files, one per polymer type. Empty when
            no result could be loaded, no condition has any affinity score,
            or the result lists no polymer types / protein groups.
        """
        import matplotlib.pyplot as plt

        from polyzymd.analysis.common.aa_classification import CANONICAL_AA_CLASS_ORDER

        t = self.theme
        result = _find_affinity_result(data, labels)
        if result is None:
            return []

        affinity_settings = self.settings.polymer_affinity

        # Order conditions by the caller's labels; if none of them match,
        # fall back to every condition stored in the result.
        cond_labels = [c.label for c in result.conditions]
        display_labels = [lbl for lbl in labels if lbl in cond_labels]
        if not display_labels:
            display_labels = cond_labels

        # Resolve each condition once (assumes ``get_condition`` is a pure
        # lookup) and reuse it for filtering and for every polymer type.
        cond_by_label = {lbl: result.get_condition(lbl) for lbl in display_labels}

        # Keep only conditions that carry at least one affinity score; a
        # label that resolves to None is skipped rather than raising.
        valid_labels = [
            lbl
            for lbl in display_labels
            if cond_by_label[lbl] is not None
            and any(e.affinity_score is not None for e in cond_by_label[lbl].entries)
        ]

        if not valid_labels:
            logger.info("No conditions with affinity score data — skipping group bars")
            return []

        polymer_types = result.polymer_types
        protein_groups = result.protein_groups

        if not polymer_types or not protein_groups:
            return []

        # Canonical AA classes first (in canonical order), then any remaining
        # groups alphabetically.
        ordered_groups = [g for g in CANONICAL_AA_CLASS_ORDER if g in protein_groups]
        for g in sorted(protein_groups):
            if g not in ordered_groups:
                ordered_groups.append(g)

        n_conds = len(valid_labels)
        n_groups = len(ordered_groups)
        colors = self._get_colors(n_conds)
        n_poly = len(polymer_types)

        # Temperature suffix is shown only when every condition in the
        # result shares a single temperature.
        temp_str = ""
        if result.conditions:
            temps = {c.temperature_K for c in result.conditions}
            if len(temps) == 1:
                temp_str = f" ({next(iter(temps)):.0f} K)"

        output_paths: list[Path] = []

        for poly_type in polymer_types:
            fig, ax = plt.subplots(
                figsize=affinity_settings.figsize_group_bars, dpi=self.settings.dpi
            )

            x = np.arange(n_groups)
            series: list[tuple[str, list[float], list[float]]] = []

            for cond_label in valid_labels:
                cond = cond_by_label[cond_label]
                means: list[float] = []
                sems: list[float] = []

                for group in ordered_groups:
                    # First entry matching this (polymer type, group) pair.
                    entry = None
                    if cond is not None:
                        entry = next(
                            (
                                e
                                for e in cond.entries
                                if e.polymer_type == poly_type and e.protein_group == group
                            ),
                            None,
                        )

                    if entry is None or entry.affinity_score is None:
                        # No data for this cell: draw a zero-height bar.
                        means.append(0.0)
                        sems.append(0.0)
                        continue

                    means.append(entry.affinity_score)
                    # Prefer replicate-based SEM; otherwise use the stored
                    # uncertainty, and fall back to 0 when neither exists.
                    per_rep = entry.affinity_score_per_replicate
                    if len(per_rep) >= 2:
                        sem = float(np.std(per_rep, ddof=1) / np.sqrt(len(per_rep)))
                    elif entry.affinity_score_uncertainty is not None:
                        sem = entry.affinity_score_uncertainty
                    else:
                        sem = 0.0
                    sems.append(sem)

                series.append((cond_label, means, sems))

            self._grouped_bars(
                ax,
                x,
                series,
                colors,
                show_error=affinity_settings.show_error_bars,
                reference_label="Score = 0 (neutral)",
                bar_edgecolor="none",
            )

            # The polymer type appears in the title only when there is more
            # than one, matching the file-name logic below.
            poly_label = f": {poly_type}" if n_poly > 1 else ""
            self._apply_axis_style(
                ax,
                title=f"Per-Group Affinity Score{poly_label}{temp_str}",
                xlabel="Amino Acid Group",
                ylabel=r"Affinity Score ($k_\mathrm{b}T$)",
            )
            ax.set_xticks(x)
            ax.set_xticklabels(ordered_groups, rotation=35, ha="right", fontsize=t.tick_fontsize)

            self._apply_legend(
                ax,
                fontsize=t.small_fontsize,
                framealpha=0.7,
            )

            # Guide lines at ±1 kT
            ax.axhline(y=1.0, color="gray", linestyle=":", linewidth=1.0, alpha=0.6)
            ax.axhline(y=-1.0, color="gray", linestyle=":", linewidth=1.0, alpha=0.6)

            plt.tight_layout()

            stem = (
                f"affinity_group_bars_{poly_type.lower()}" if n_poly > 1 else "affinity_group_bars"
            )
            output_path = self._get_output_path(output_dir, stem)
            output_paths.append(
                self._save_figure(
                    fig,
                    output_path,
                    experimental_features=("polymer_affinity",),
                )
            )

        return output_paths