Source code for polyzymd.compare.plotters.binding_free_energy

"""Binding free energy plotters for comparison workflow.

This module provides registered plotters for ΔG_sel (selectivity free energy)
analysis:

- BFEHeatmapPlotter: ΔG_sel heatmap with rows = AA groups, columns = conditions
- BFEBarPlotter: Grouped bar chart of ΔG_sel by AA residue class

Both plotters load a ``BindingFreeEnergyResult`` JSON saved by the
``polyzymd compare binding-free-energy`` command (in ``results/`` adjacent to
``comparison.yaml``) rather than per-condition analysis directories.

**Partition-aware plotting**

Each ``FreeEnergyEntry`` carries a ``partition_name`` field (e.g., "aa_class",
"lid_helices", "whole_lid_domain") that identifies which residue grouping
scheme produced that entry.  Different partitions use different denominators
(each partition's total exposed surface area), so mixing groups from different
partitions on the same figure is scientifically misleading.

Both plotters therefore produce one figure per (partition, polymer_type)
combination.  When only a single partition is present (the common case for
datasets that only use default AA-class grouping), filenames and titles omit
the partition name to preserve backward compatibility.

**Physics interpretation**

| ``ΔG_sel < 0`` →  preferential contact (polymer contacts this group more than
  expected from surface availability alone)
| ``ΔG_sel > 0`` →  contact avoidance (polymer contacts this group less than expected)
| ``ΔG_sel = 0`` →  contacts match surface-availability reference exactly

Diverging colormap (RdBu_r by default) is centered at 0.0:

- Blue (negative)  → preference
- White (zero)     → neutral
- Red  (positive)  → avoidance

Units are whatever was specified in analysis_settings.binding_free_energy.units
(kT by default — dimensionless, in units of k_bT).
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

from polyzymd.analysis.common.aa_classification import CANONICAL_AA_CLASS_ORDER
from polyzymd.compare.plotter import BasePlotter, PlotterRegistry

if TYPE_CHECKING:
    from polyzymd.compare.config import ComparisonConfig
    from polyzymd.compare.results.binding_free_energy import BindingFreeEnergyResult

logger = logging.getLogger(__name__)


def _unit_label_mpl(units: str) -> str:
    """Return matplotlib-compatible unit label with subscript for kT.

    Parameters
    ----------
    units : str
        Energy unit string ("kT", "kcal/mol", or "kJ/mol").

    Returns
    -------
    str
        Label suitable for matplotlib axes/colorbars. For "kT" this returns
        the mathtext ``$k_\\mathrm{b}T$`` to render a subscript "b".
    """
    if units == "kT":
        return r"$k_\mathrm{b}T$"
    return units


def _find_bfe_result(
    data: dict[str, Any], labels: Sequence[str]
) -> "BindingFreeEnergyResult | None":
    """Find and load BindingFreeEnergyResult from the results/ directory.

    The BFE result JSON lives adjacent to ``comparison.yaml`` under
    ``results/``.  Two naming conventions exist:

    - ``binding_free_energy_comparison_{name}.json`` (generic ``run_comparison``)
    - ``bfe_comparison_{name}.json`` (dedicated ``compare binding-free-energy``)

    Both are searched.  The most recently modified match wins.

    The orchestrator provides a ``__meta__`` entry in *data* with the
    ``results_dir`` path (derived from ``comparison.yaml``'s location).
    This is the primary lookup.  If ``__meta__`` is absent we fall back
    to heuristic navigation from condition config paths.

    Parameters
    ----------
    data : dict
        Mapping of condition_label -> condition data dict, plus an optional
        ``"__meta__"`` key with ``results_dir``.
    labels : sequence of str
        Condition labels in display order.

    Returns
    -------
    BindingFreeEnergyResult or None
        Loaded result, or None if not found.
    """
    from polyzymd.compare.results.binding_free_energy import BindingFreeEnergyResult

    # Both naming conventions that may exist on disk
    _BFE_GLOBS = [
        "binding_free_energy_comparison_*.json",
        "bfe_comparison_*.json",
    ]

    def _try_load_from_dir(results_dir: Path) -> "BindingFreeEnergyResult | None":
        """Try loading the most recent BFE result JSON from a directory."""
        if not results_dir.is_dir():
            return None
        bfe_files: list[Path] = []
        for pattern in _BFE_GLOBS:
            bfe_files.extend(results_dir.glob(pattern))
        if not bfe_files:
            return None
        bfe_file = max(bfe_files, key=lambda p: p.stat().st_mtime)
        try:
            result = BindingFreeEnergyResult.load(bfe_file)
            logger.debug(f"Loaded BFE result from {bfe_file}")
            return result
        except Exception as e:
            logger.warning(f"Failed to load BFE result {bfe_file}: {e}")
            return None

    # --- Primary path: use __meta__.results_dir from the orchestrator ---
    meta = data.get("__meta__")
    if meta is not None:
        results_dir = meta.get("results_dir")
        if results_dir is not None:
            result = _try_load_from_dir(Path(results_dir))
            if result is not None:
                return result
            logger.debug(f"No BFE result JSON in {results_dir} — falling back to heuristic")

    # --- Fallback: navigate from condition config paths ---
    candidate_dirs: list[Path] = []

    for label in labels:
        cond_data = data.get(label)
        if cond_data is None:
            continue
        condition = cond_data.get("condition")
        if condition is None:
            continue
        config_path = getattr(condition, "config", None)
        if config_path is None:
            continue
        config_path = Path(config_path)
        # condition config lives in {project_root}/{condition_name}/...
        # Try parent (condition dir) and grandparent (project root)
        for candidate in [config_path.parent, config_path.parent.parent]:
            results_dir = candidate / "results"
            if results_dir.is_dir() and results_dir not in candidate_dirs:
                candidate_dirs.append(results_dir)

    for results_dir in candidate_dirs:
        result = _try_load_from_dir(results_dir)
        if result is not None:
            return result

    logger.info("No BFE result JSON found in any results/ directory - skipping BFE plots")
    return None


def _sorted_groups(groups: list[str]) -> list[str]:
    """Sort AA groups in canonical order, with non-canonical groups appended."""
    ordered = [g for g in CANONICAL_AA_CLASS_ORDER if g in groups]
    for g in sorted(groups):
        if g not in ordered:
            ordered.append(g)
    return ordered


def _get_partitions(result: "BindingFreeEnergyResult") -> dict[str, list[str]]:
    """Build a mapping of partition_name -> sorted list of protein groups.

    Scans all ``FreeEnergyEntry`` objects across every condition to discover
    which protein groups belong to each partition.  This reconstructs the
    partition→groups structure that is lost by the flat ``protein_groups``
    list on ``BindingFreeEnergyResult``.

    Parameters
    ----------
    result : BindingFreeEnergyResult
        Loaded BFE comparison result.

    Returns
    -------
    dict[str, list[str]]
        Mapping of partition name to its sorted group list.  The sort order
        uses ``_sorted_groups`` (canonical AA-class ordering first, then
        alphabetical for non-canonical groups).
    """
    partition_groups: dict[str, set[str]] = {}
    for cond in result.conditions:
        for entry in cond.entries:
            partition_groups.setdefault(entry.partition_name, set()).add(entry.protein_group)
    # Stable ordering: aa_class first (most common), then alphabetical
    ordered_partitions: dict[str, list[str]] = {}
    partition_names = sorted(partition_groups.keys())
    if "aa_class" in partition_names:
        partition_names.remove("aa_class")
        partition_names.insert(0, "aa_class")
    for pname in partition_names:
        ordered_partitions[pname] = _sorted_groups(list(partition_groups[pname]))
    return ordered_partitions


def _partition_display_name(partition_name: str) -> str:
    """Convert a partition name to a human-readable display string.

    Examples: "aa_class" → "AA Class", "lid_helices" → "Lid Helices".
    """
    return partition_name.replace("_", " ").title()


# ---------------------------------------------------------------------------
# Heatmap plotter
# ---------------------------------------------------------------------------


[docs] @PlotterRegistry.register("bfe_heatmap") class BFEHeatmapPlotter(BasePlotter): """Generate ΔG_sel heatmap comparing binding free energy across conditions. Creates one figure per (partition, polymer_type) combination: - Rows: protein groups belonging to that partition - Columns: Conditions (e.g., 0% SBMA, 25% SBMA, …) - Color: ΔG_sel value with diverging colormap centered at 0 When only a single partition exists (e.g., just "aa_class"), filenames and titles match the previous single-partition behavior for backward compatibility. Loads ``BindingFreeEnergyResult`` from ``results/`` adjacent to ``comparison.yaml`` (accepts both ``binding_free_energy_comparison_*.json`` and ``bfe_comparison_*.json`` naming conventions). Sign convention --------------- Blue (negative ΔG_sel) = preferential contact Red (positive ΔG_sel) = contact avoidance """
[docs] @classmethod def plot_type(cls) -> str: return "bfe_heatmap"
[docs] def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool: """Return True for 'binding_free_energy' when heatmap is enabled.""" if analysis_type != "binding_free_energy": return False return self.settings.binding_free_energy.generate_heatmap
[docs] def plot( self, data: dict[str, Any], labels: Sequence[str], output_dir: Path, **kwargs: Any, ) -> list[Path]: """Generate ΔG_sel heatmaps, one per (partition, polymer_type). Parameters ---------- data : dict Mapping of condition_label -> condition data dict from ``ComparisonPlotter._load_analysis_data()``. labels : sequence of str Condition labels in desired display order. output_dir : Path Directory to save plot files. **kwargs Unused; for interface compatibility. Returns ------- list[Path] Paths to generated plot files, or empty list. """ import matplotlib.pyplot as plt result = _find_bfe_result(data, labels) if result is None: return [] t = self.theme bfe_settings = self.settings.binding_free_energy units = result.units # Determine display labels cond_labels = [c.label for c in result.conditions] display_labels = [lbl for lbl in labels if lbl in cond_labels] if not display_labels: display_labels = cond_labels polymer_types = result.polymer_types partitions = _get_partitions(result) if not polymer_types or not partitions: logger.warning("BFE result has no polymer types or protein groups - skipping heatmap") return [] n_conds = len(display_labels) n_poly = len(polymer_types) n_partitions = len(partitions) multi_partition = n_partitions > 1 # Temperature string (shared across all figures) temp_str = "" if result.conditions: temps = {c.temperature_K for c in result.conditions} if len(temps) == 1: temp_str = f" at {next(iter(temps)):.0f} K" output_paths: list[Path] = [] for partition_name, protein_groups in partitions.items(): n_groups = len(protein_groups) # Compute per-partition color range from entries in this partition only partition_vals: list[float] = [] for cond_summary in result.conditions: for entry in cond_summary.entries: if entry.partition_name == partition_name and entry.delta_G is not None: partition_vals.append(entry.delta_G) if not partition_vals: logger.debug(f"No ΔG_sel values for partition '{partition_name}' - skipping") continue vmin, vmax = self._symmetric_clim(partition_vals, pad=0.05) max_abs = vmax - 0.05 # needed for annotation threshold below for poly_type in polymer_types: # Auto-size if bfe_settings.figsize_heatmap is not None: figsize = bfe_settings.figsize_heatmap else: figsize = ( max(6, 1.5 * n_conds + 1.5), max(4, 0.9 * n_groups + 1.5), ) fig, ax = plt.subplots(figsize=figsize, dpi=self.settings.dpi) # Build matrix: rows = protein groups, columns = conditions matrix = np.full((n_groups, n_conds), np.nan) sem_matrix = np.full((n_groups, n_conds), np.nan) for col_idx, cond_label in enumerate(display_labels): try: cond_summary = result.get_condition(cond_label) except KeyError: continue for row_idx, group in enumerate(protein_groups): entry = cond_summary.get_entry( poly_type, group, partition_name=partition_name ) if entry is not None and entry.delta_G is not None: matrix[row_idx, col_idx] = entry.delta_G if entry.delta_G_uncertainty is not None: sem_matrix[row_idx, col_idx] = entry.delta_G_uncertainty valid = matrix[~np.isnan(matrix)] if len(valid) == 0: logger.debug( f"No ΔG_sel data for partition '{partition_name}', " f"polymer '{poly_type}' - skipping" ) plt.close(fig) continue im = ax.imshow( matrix, cmap=bfe_settings.colormap, vmin=vmin, vmax=vmax, aspect="auto", ) # Annotate cells with ΔG_sel ± σ if bfe_settings.annotate_heatmap: self._annotate_cells( ax, matrix, fontsize=t.small_fontsize, threshold=0.35 * max_abs, sem_matrix=sem_matrix, linespacing=1.2, ) ax.set_xticks(range(n_conds)) ax.set_xticklabels(display_labels, rotation=35, ha="right") ax.set_yticks(range(n_groups)) ax.set_yticklabels(protein_groups) # Y-axis label includes partition name when multiple partitions if multi_partition: ylabel = f"Protein Group ({_partition_display_name(partition_name)})" else: ylabel = "Amino Acid Group" # Title: include partition and polymer info as needed poly_label = poly_type if n_poly > 1 else "" if multi_partition: part_label = _partition_display_name(partition_name) title_parts = [r"$\Delta G_{\mathrm{sel}}$", part_label] if poly_label: title_parts.append(poly_label) if temp_str: title_parts.append(temp_str.strip()) title = " — ".join(title_parts[:2]) if poly_label: title += f" ({poly_label})" if temp_str: title += temp_str else: parts = [r"$\Delta G_{\mathrm{sel}}$"] if poly_label: parts.append(poly_label) if temp_str: parts.append(temp_str.strip()) title = ( " ".join(parts) if len(parts) > 1 else r"$\Delta G_{\mathrm{sel}}$ Binding Selectivity" ) self._apply_axis_style(ax, title=title, xlabel="Condition", ylabel=ylabel) cbar = fig.colorbar(im, ax=ax, shrink=0.85) unit_lbl = _unit_label_mpl(units) cbar.set_label( r"$\Delta G_{\mathrm{sel}}$" + f" ({unit_lbl})", rotation=270, labelpad=14, fontsize=t.legend_fontsize, ) cbar.ax.axhline( y=0.0, color=t.reference_line_color, linewidth=t.reference_line_width, linestyle=t.reference_line_style, ) plt.tight_layout() # Filename: include partition when multiple, polymer when multiple stem = self._build_stem( "bfe_heatmap", partition_name, poly_type, multi_partition, n_poly > 1, ) output_path = self._get_output_path(output_dir, stem) output_paths.append( self._save_figure( fig, output_path, experimental_features=("binding_free_energy",), ) ) return output_paths
@staticmethod def _build_stem( prefix: str, partition_name: str, poly_type: str, multi_partition: bool, multi_poly: bool, ) -> str: """Build output filename stem from partition and polymer type. Single partition + single polymer → ``prefix`` Single partition + multi polymer → ``prefix_{poly}`` Multi partition + single polymer → ``prefix_{partition}`` Multi partition + multi polymer → ``prefix_{partition}_{poly}`` """ parts = [prefix] if multi_partition: parts.append(partition_name.lower()) if multi_poly: parts.append(poly_type.lower()) return "_".join(parts)
# --------------------------------------------------------------------------- # Bar chart plotter # ---------------------------------------------------------------------------
[docs] @PlotterRegistry.register("bfe_bars") class BFEBarPlotter(BasePlotter): """Generate ΔG_sel grouped bar charts comparing binding free energy across conditions. Creates one figure per (partition, polymer_type) combination with: - Groups on x-axis: protein groups from that partition - Bars within each group: one per condition - Error bars: between-replicate SEM on ΔG_sel (delta-method fallback) - Reference line at ΔG_sel = 0 When only a single partition exists, filenames and titles match the previous single-partition behavior for backward compatibility. Loads ``BindingFreeEnergyResult`` from ``results/`` adjacent to ``comparison.yaml`` (accepts both ``binding_free_energy_comparison_*.json`` and ``bfe_comparison_*.json`` naming conventions). """
[docs] @classmethod def plot_type(cls) -> str: return "bfe_bars"
[docs] def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool: """Return True for 'binding_free_energy' when bar charts are enabled.""" if analysis_type != "binding_free_energy": return False return self.settings.binding_free_energy.generate_bars
[docs] def plot( self, data: dict[str, Any], labels: Sequence[str], output_dir: Path, **kwargs: Any, ) -> list[Path]: """Generate ΔG_sel grouped bar charts, one per (partition, polymer_type). Parameters ---------- data : dict Mapping of condition_label -> condition data dict from ``ComparisonPlotter._load_analysis_data()``. labels : sequence of str Condition labels in desired display order. output_dir : Path Directory to save plot files. **kwargs Unused; for interface compatibility. Returns ------- list[Path] Paths to generated plot files, or empty list. """ import matplotlib.pyplot as plt result = _find_bfe_result(data, labels) if result is None: return [] t = self.theme bfe_settings = self.settings.binding_free_energy units = result.units cond_labels = [c.label for c in result.conditions] display_labels = [lbl for lbl in labels if lbl in cond_labels] if not display_labels: display_labels = cond_labels # Filter to conditions that have data valid_labels = [ lbl for lbl in display_labels if any(e.delta_G is not None for e in result.get_condition(lbl).entries) if lbl in cond_labels ] if not valid_labels: logger.info("No conditions with ΔG_sel values - skipping bar charts") return [] polymer_types = result.polymer_types partitions = _get_partitions(result) if not polymer_types or not partitions: return [] n_conds = len(valid_labels) colors = self._get_colors(n_conds) n_poly = len(polymer_types) n_partitions = len(partitions) multi_partition = n_partitions > 1 # Temperature string (shared) temp_str = "" if result.conditions: temps = {c.temperature_K for c in result.conditions} if len(temps) == 1: temp_str = f" ({next(iter(temps)):.0f} K)" # kT guide lines (shared across all figures) if units == "kT": kt: float | None = 1.0 # Already in k_bT units; guide lines at ±1.0 else: temps_list = [c.temperature_K for c in result.conditions] kt = None if temps_list: t_med = float(np.median(temps_list)) from polyzymd.compare.settings import BindingFreeEnergyAnalysisSettings tmp_settings = BindingFreeEnergyAnalysisSettings(units=units) kt = tmp_settings.k_b() * t_med output_paths: list[Path] = [] for partition_name, protein_groups in partitions.items(): n_groups = len(protein_groups) for poly_type in polymer_types: figsize = bfe_settings.figsize_bars fig, ax = plt.subplots(figsize=figsize, dpi=self.settings.dpi) x = np.arange(n_groups) series: list[tuple[str, list[float], list[float]]] = [] for cond_label in valid_labels: cond_summary = result.get_condition(cond_label) means: list[float] = [] sems: list[float] = [] for group in protein_groups: entry = cond_summary.get_entry( poly_type, group, partition_name=partition_name ) if entry is not None and entry.delta_G is not None: means.append(entry.delta_G) # Prefer between-replicate SEM, fall back to delta-method per_rep = entry.delta_G_per_replicate if len(per_rep) >= 2: sem = float(np.std(per_rep, ddof=1) / np.sqrt(len(per_rep))) elif entry.delta_G_uncertainty is not None: sem = entry.delta_G_uncertainty else: sem = 0.0 sems.append(sem) else: means.append(0.0) sems.append(0.0) series.append((cond_label, means, sems)) self._grouped_bars( ax, x, series, colors, show_error=bfe_settings.show_error_bars, reference_label=r"$\Delta G_{\mathrm{sel}}$ = 0 (neutral)", bar_edgecolor="none", ) # Title: include partition and polymer info as needed poly_label = f": {poly_type}" if n_poly > 1 else "" if multi_partition: part_label = _partition_display_name(partition_name) title = r"$\Delta G_{\mathrm{sel}}$" + f" — {part_label}{poly_label}{temp_str}" else: title = r"$\Delta G_{\mathrm{sel}}$" + f"{poly_label}{temp_str}" # X-axis label if multi_partition: xlabel = f"Protein Group ({_partition_display_name(partition_name)})" else: xlabel = "Amino Acid Group" unit_lbl = _unit_label_mpl(units) ylabel = r"$\Delta G_{\mathrm{sel}}$" + f" ({unit_lbl})" self._apply_axis_style(ax, title=title, xlabel=xlabel, ylabel=ylabel) ax.set_xticks(x) ax.set_xticklabels(protein_groups, rotation=35, ha="right") self._apply_legend( ax, fontsize=t.small_fontsize, framealpha=0.7, ) # Horizontal guide lines at ±kT if kt is not None: ax.axhline(y=kt, color="gray", linestyle=":", linewidth=1.0, alpha=0.6) ax.axhline(y=-kt, color="gray", linestyle=":", linewidth=1.0, alpha=0.6) kt_label = r"$k_\mathrm{b}T$" ax.text( n_groups - 0.5, kt, f"+{kt_label}", color="gray", fontsize=t.tiny_fontsize, va="bottom", ha="right", ) ax.text( n_groups - 0.5, -kt, f"\u2212{kt_label}", color="gray", fontsize=t.tiny_fontsize, va="top", ha="right", ) plt.tight_layout() # Filename: reuse _build_stem from heatmap plotter stem = BFEHeatmapPlotter._build_stem( "bfe_bars", partition_name, poly_type, multi_partition, n_poly > 1, ) output_path = self._get_output_path(output_dir, stem) output_paths.append( self._save_figure( fig, output_path, experimental_features=("binding_free_energy",), ) ) return output_paths