Extending the Plotter Framework
This guide shows developers how to create custom plotters using PolyzyMD’s registry-based framework. The framework follows the Open-Closed Principle (open for extension, closed for modification) and provides automatic discovery of plotter implementations.
Note
This guide covers the plotter subsystem for generating comparison visualizations. For statistical comparisons, see Extending the Comparison Framework.
Overview
The plotter framework provides:
BasePlotter: Abstract base class with shared utilities (_save_figure, _get_output_path)
PlotterRegistry: Registry for auto-discovery of plotter types via the @register decorator
PlotSettings: Configuration model for plot appearance (DPI, colors, formats)
Automatic discovery: plot_all() finds all registered plotters for an analysis type
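To make the registry idea concrete, here is a minimal, self-contained sketch of a decorator-based registry. It is not PolyzyMD's implementation: SketchRegistry, for_analysis, and the two toy plotter classes are hypothetical names chosen for illustration, and the real can_plot() also takes a comparison config.

```python
from typing import Callable


class SketchRegistry:
    """Hypothetical stand-in for PlotterRegistry (illustration only)."""

    _plotters: dict[str, type] = {}

    @classmethod
    def register(cls, key: str) -> Callable[[type], type]:
        """Return a class decorator that stores the class under `key`."""
        def decorator(plotter_cls: type) -> type:
            if key in cls._plotters:
                raise ValueError(f"Duplicate plotter key: {key!r}")
            cls._plotters[key] = plotter_cls
            return plotter_cls
        return decorator

    @classmethod
    def for_analysis(cls, analysis_type: str) -> list[str]:
        """Return the keys of registered plotters that accept the analysis type."""
        return sorted(
            key
            for key, plotter_cls in cls._plotters.items()
            if plotter_cls().can_plot(analysis_type)
        )


@SketchRegistry.register("rmsf_bars")
class RMSFBars:
    def can_plot(self, analysis_type: str) -> bool:
        return analysis_type == "rmsf"


@SketchRegistry.register("triad_kde")
class TriadKDE:
    def can_plot(self, analysis_type: str) -> bool:
        return analysis_type == "catalytic_triad"


# Discovery: ask the registry which plotters apply, without naming any class.
rmsf_keys = SketchRegistry.for_analysis("rmsf")
```

Because registration happens when the class body is executed, adding a plotter only requires importing its module; no existing code has to be edited.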
Architecture
compare/
├── plotter.py # BasePlotter, PlotterRegistry, ComparisonPlotter
├── plotters/
│ ├── triad.py # TriadKDEPanelPlotter, TriadThresholdBarsPlotter
│ ├── contacts.py # BindingPreferenceHeatmapPlotter, BindingPreferenceBarPlotter
│ ├── rmsf.py # RMSFBarPlotter, RMSFLinePlotter
│ └── __init__.py
└── config.py # PlotSettings, PlotSettingsTriad, PlotSettingsContacts
Quick Start: Minimal Plotter
Here’s the minimal code to create a new plotter:
from pathlib import Path
from typing import Any, Sequence

from polyzymd.compare.plotter import BasePlotter, PlotterRegistry


@PlotterRegistry.register("my_custom_plot")
class MyCustomPlotter(BasePlotter):
    """Generate custom visualization for my analysis type."""

    @classmethod
    def plot_type(cls) -> str:
        return "my_custom_plot"

    def can_plot(self, comparison_config, analysis_type: str) -> bool:
        """Return True if this plotter handles the analysis type."""
        return analysis_type == "my_analysis"

    def plot(
        self,
        data: dict[str, Any],
        labels: Sequence[str],
        output_dir: Path,
        **kwargs,
    ) -> list[Path]:
        """Generate and save plot(s)."""
        import matplotlib.pyplot as plt

        # Load your data from each condition's analysis_dir
        results = self._load_results(data, labels)
        if not results:
            return []

        # Create your visualization
        fig, ax = plt.subplots(figsize=(10, 6), dpi=self.settings.dpi)
        # ... plotting logic ...

        # Save using helper method
        output_path = self._get_output_path(output_dir, "my_custom_plot")
        return [self._save_figure(fig, output_path)]

    def _load_results(self, data, labels):
        """Load analysis results from filesystem."""
        from my_module.results import MyResult

        results = {}
        for label in labels:
            cond_data = data.get(label)
            if cond_data is None:
                continue
            analysis_dir = Path(cond_data["analysis_dir"])
            result_file = analysis_dir / "my_result.json"
            if result_file.exists():
                results[label] = MyResult.load(result_file)
        return results
Critical: The Data Contract
Warning
The most common mistake when implementing plotters is expecting data to be
passed via kwargs. This does not work! The orchestrator only provides
filesystem paths—plotters must load their own data.
What plot() Receives
The ComparisonPlotter.plot_analysis() method calls your plotter with:
plotter.plot(
    data=data,              # Dict of condition metadata (see structure below)
    labels=labels,          # List of condition labels in display order
    output_dir=output_dir,  # Where to save plots
    # **kwargs is reserved for future use; DO NOT rely on it
)
The data Dictionary Structure
The data dict contains metadata for each condition, not analysis results:
data = {
    "SBMA_75_25": {
        "condition": ConditionConfig(...),      # Condition metadata
        "sim_config": SimulationConfig(...),    # Full simulation config
        "analysis_dir": Path("path/to/analysis/contacts/"),  # CRITICAL!
        "aggregated_dir": Path("path/to/analysis/contacts/aggregated/"),
        "replicates": [1, 2, 3],                # Replicate numbers
    },
    "EGMA_75_25": {
        # Same structure...
    },
}
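One way to keep this contract visible in your own code is to write it down as a type. The sketch below is an assumption built from the example above: ConditionData and missing_keys are hypothetical names, not part of PolyzyMD, and the two config objects are typed as Any because their classes are not shown here.

```python
from pathlib import Path
from typing import Any, TypedDict


class ConditionData(TypedDict):
    """Assumed shape of one entry in `data`, taken from the example above."""

    condition: Any        # ConditionConfig in the real framework
    sim_config: Any       # SimulationConfig in the real framework
    analysis_dir: Path
    aggregated_dir: Path
    replicates: list[int]


def missing_keys(entry: dict[str, Any]) -> list[str]:
    """Return the expected keys that `entry` lacks, sorted by name."""
    return sorted(set(ConditionData.__annotations__) - set(entry))


# A partial entry, such as a hand-built dict in a unit test.
partial_entry = {"analysis_dir": Path("analysis/contacts"), "replicates": [1, 2, 3]}
gaps = missing_keys(partial_entry)
```

A check like this is mainly useful in tests, where the data dict is built by hand and a missing analysis_dir would otherwise surface only as a KeyError inside plot().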
Correct Pattern: Load From Filesystem
def plot(self, data, labels, output_dir, **kwargs):
    # Load data from filesystem paths
    for label in labels:
        analysis_dir = Path(data[label]["analysis_dir"])
        result_file = analysis_dir / "my_result_aggregated.json"
        result = MyAggregatedResult.load(result_file)
        # ... use result for plotting ...
Incorrect Pattern: Expecting Data in kwargs
def plot(self, data, labels, output_dir, **kwargs):
    # WRONG: Expecting pre-loaded result in kwargs
    comparison_result = kwargs.get("comparison_result")
    if comparison_result is None:
        return []  # Always returns empty!
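The difference between the two patterns can be reproduced without the framework. In this sketch, call_like_orchestrator is a hypothetical stand-in that mirrors the call shown under "What plot() Receives" (no extra keyword arguments), and both plotter classes are toy examples that only return the paths they would write.

```python
from pathlib import Path
from typing import Any, Sequence


class KwargsPlotter:
    """Anti-pattern: waits for a pre-loaded result that never arrives."""

    def plot(self, data: dict[str, Any], labels: Sequence[str],
             output_dir: Path, **kwargs: Any) -> list[Path]:
        if kwargs.get("comparison_result") is None:
            return []
        return [output_dir / "never_reached.png"]


class FilesystemPlotter:
    """Uses only the paths that arrive in `data`."""

    def plot(self, data: dict[str, Any], labels: Sequence[str],
             output_dir: Path, **kwargs: Any) -> list[Path]:
        usable = [label for label in labels if "analysis_dir" in data.get(label, {})]
        # A real plotter would load files and save figures here.
        return [output_dir / f"{label}.png" for label in usable]


def call_like_orchestrator(plotter, data, labels, output_dir):
    """Hypothetical stand-in for the orchestrator: passes no extra kwargs."""
    return plotter.plot(data=data, labels=labels, output_dir=output_dir)


data = {"SBMA_75_25": {"analysis_dir": Path("analysis/contacts")}}
labels = ["SBMA_75_25"]
out = Path("plots")

kwargs_outputs = call_like_orchestrator(KwargsPlotter(), data, labels, out)
fs_outputs = call_like_orchestrator(FilesystemPlotter(), data, labels, out)
```

The kwargs-based plotter silently produces nothing, which is why this mistake tends to show up as "no plots generated" rather than as an exception.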
Step-by-Step Guide
Step 1: Choose Your Analysis Type
Each plotter handles one analysis type (e.g., "contacts", "catalytic_triad", "rmsf").
The can_plot() method determines which analysis types your plotter supports:
def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool:
    """Check if this plotter can handle the analysis type."""
    if analysis_type != "contacts":
        return False
    # Optionally check settings to enable/disable
    return self.settings.contacts.generate_my_plot
Step 2: Register Your Plotter
Use the @PlotterRegistry.register() decorator with a unique key:
@PlotterRegistry.register("binding_preference_heatmap")
class BindingPreferenceHeatmapPlotter(BasePlotter):
    ...
The registry key should be:
Lowercase with underscores
Descriptive of the plot type
Unique across all plotters
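These three rules are easy to check mechanically. The following sketch is an assumption about how one might do so in a test or pre-commit hook: KEY_PATTERN and check_registry_key are hypothetical, and allowing digits in keys is a guess that the guide does not confirm.

```python
import re

# Assumed convention: lowercase words (digits allowed) joined by single underscores.
KEY_PATTERN = re.compile(r"[a-z][a-z0-9]*(?:_[a-z0-9]+)*")


def check_registry_key(key: str, existing: set[str]) -> list[str]:
    """Return a list of problems with `key`; an empty list means it is fine."""
    problems = []
    if not KEY_PATTERN.fullmatch(key):
        problems.append("use lowercase letters, digits, and single underscores")
    if key in existing:
        problems.append("key is already registered")
    return problems


existing_keys = {"binding_preference_heatmap", "binding_preference_bars"}

ok = check_registry_key("rmsf_line", existing_keys)
bad_case = check_registry_key("RMSFLine", existing_keys)
duplicate = check_registry_key("binding_preference_bars", existing_keys)
```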
Step 3: Implement Required Methods
plot_type() (classmethod)
Return the registry key:
@classmethod
def plot_type(cls) -> str:
    return "binding_preference_heatmap"
can_plot()
Check if this plotter should be called:
def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool:
    # Only handle contacts analysis
    if analysis_type != "contacts":
        return False
    # Check if this specific plot type is enabled in settings
    return self.settings.contacts.generate_enrichment_heatmap
plot()
Generate and save the visualization. Key responsibilities:
Load data from filesystem (not kwargs!)
Create matplotlib figure(s)
Save using _save_figure()
Return list of output paths
def plot(
    self,
    data: dict[str, Any],
    labels: Sequence[str],
    output_dir: Path,
    **kwargs,
) -> list[Path]:
    import matplotlib.pyplot as plt

    # 1. Load data from each condition
    results = {}
    for label in labels:
        analysis_dir = Path(data[label]["analysis_dir"])
        result_file = analysis_dir / "binding_preference_aggregated_reps1-3.json"
        if result_file.exists():
            results[label] = AggregatedBindingPreferenceResult.load(result_file)
    if not results:
        return []

    # 2. Create figure
    fig, ax = plt.subplots(figsize=self.settings.contacts.figsize_enrichment_heatmap)
    # ... plotting logic ...

    # 3. Save and return
    output_path = self._get_output_path(output_dir, "my_plot_name")
    return [self._save_figure(fig, output_path)]
Step 4: Access Plot Settings
The self.settings attribute provides access to PlotSettings:
# Global settings
self.settings.dpi # int, default 150
self.settings.format # str, "png" or "svg"
self.settings.color_palette # str, default "Set2"
# Analysis-specific settings
self.settings.contacts.figsize_enrichment_heatmap # tuple
self.settings.contacts.enrichment_colormap # str
self.settings.contacts.show_enrichment_error # bool
self.settings.triad.figsize_kde_panel # tuple
self.settings.triad.kde_fill_alpha # float
To add new settings, extend the settings models in compare/config.py.
Step 5: Handle Multiple Plot Files
If your plotter generates multiple files (e.g., one per polymer type), return all paths:
def plot(self, data, labels, output_dir, **kwargs) -> list[Path]:
    output_paths = []
    for polymer_type in polymer_types:
        fig, ax = plt.subplots(...)
        # ... plot for this polymer type ...
        output_path = self._get_output_path(
            output_dir, f"binding_bars_{polymer_type.lower()}"
        )
        output_paths.append(self._save_figure(fig, output_path))
    return output_paths
Complete Example: Bar Chart Plotter
Here’s a complete example of a plotter that generates grouped bar charts:
"""Example plotter for binding preference bar charts."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

from polyzymd.compare.plotter import BasePlotter, PlotterRegistry

if TYPE_CHECKING:
    from polyzymd.analysis.contacts.binding_preference import (
        AggregatedBindingPreferenceResult,
    )
    from polyzymd.compare.config import ComparisonConfig

logger = logging.getLogger(__name__)


@PlotterRegistry.register("binding_preference_bars")
class BindingPreferenceBarPlotter(BasePlotter):
    """Generate grouped bar chart of binding preference enrichment.

    Creates a figure showing enrichment ratios as grouped bars with:
    - Groups: Protein groups (e.g., aromatic, polar, charged)
    - Bars within group: One per condition
    - Error bars: SEM across replicates
    - Reference line at 1.0 (neutral enrichment)

    One plot is generated per polymer type.
    """

    @classmethod
    def plot_type(cls) -> str:
        return "binding_preference_bars"

    def can_plot(self, comparison_config: "ComparisonConfig", analysis_type: str) -> bool:
        """Check if this plotter can handle the analysis type."""
        if analysis_type != "contacts":
            return False
        return self.settings.contacts.generate_enrichment_bars

    def plot(
        self,
        data: dict[str, Any],
        labels: Sequence[str],
        output_dir: Path,
        **kwargs,
    ) -> list[Path]:
        """Generate enrichment bar chart from filesystem data."""
        import matplotlib.pyplot as plt

        # Load binding preference results from each condition
        binding_results = self._load_binding_preference_results(data, labels)
        if not binding_results:
            logger.info("No binding preference data found - skipping bar plots")
            return []

        # Collect all polymer types and protein groups
        all_polymer_types: set[str] = set()
        all_protein_groups: set[str] = set()
        for result in binding_results.values():
            all_polymer_types.update(result.polymer_types())
            all_protein_groups.update(result.protein_groups())
        polymer_types = sorted(all_polymer_types)
        protein_groups = sorted(all_protein_groups)
        if not polymer_types or not protein_groups:
            return []

        # Generate one plot per polymer type
        output_paths: list[Path] = []
        valid_labels = [label for label in labels if label in binding_results]
        for poly_type in polymer_types:
            fig, ax = plt.subplots(
                figsize=self.settings.contacts.figsize_enrichment_bars,
                dpi=self.settings.dpi,
            )
            n_groups = len(protein_groups)
            n_conditions = len(valid_labels)
            bar_width = 0.8 / n_conditions
            x = np.arange(n_groups)

            for i, cond_label in enumerate(valid_labels):
                result = binding_results[cond_label]
                means = []
                sems = []
                for prot_group in protein_groups:
                    entry = result.get_entry(poly_type, prot_group)
                    if entry and entry.mean_enrichment is not None:
                        means.append(entry.mean_enrichment)
                        sems.append(entry.sem_enrichment or 0.0)
                    else:
                        means.append(0.0)
                        sems.append(0.0)
                # Center the group of bars on each tick
                offset = (i - n_conditions / 2 + 0.5) * bar_width
                ax.bar(x + offset, means, bar_width, yerr=sems, label=cond_label)

            ax.axhline(y=1.0, color="black", linestyle="--", label="Neutral")
            ax.set_xlabel("Protein Group")
            ax.set_ylabel("Enrichment Ratio")
            ax.set_title(f"Binding Preference: {poly_type}")
            ax.set_xticks(x)
            ax.set_xticklabels(protein_groups, rotation=45, ha="right")
            ax.legend()
            plt.tight_layout()

            output_path = self._get_output_path(
                output_dir, f"binding_preference_bars_{poly_type.lower()}"
            )
            output_paths.append(self._save_figure(fig, output_path))
        return output_paths

    def _load_binding_preference_results(
        self,
        data: dict[str, Any],
        labels: Sequence[str],
    ) -> dict[str, "AggregatedBindingPreferenceResult"]:
        """Load aggregated binding preference results for each condition."""
        from polyzymd.analysis.contacts.binding_preference import (
            AggregatedBindingPreferenceResult,
        )

        results: dict[str, AggregatedBindingPreferenceResult] = {}
        for label in labels:
            cond_data = data.get(label)
            if cond_data is None:
                continue
            # Check the raw value first: a Path object is always truthy,
            # so testing Path("") would never skip a missing directory.
            raw_dir = cond_data.get("analysis_dir")
            if not raw_dir:
                continue
            analysis_dir = Path(raw_dir)

            # Find aggregated file
            agg_files = list(analysis_dir.glob("binding_preference_aggregated*.json"))
            if not agg_files:
                continue
            # Lexicographically last filename (not necessarily the newest file)
            result_file = sorted(agg_files)[-1]
            try:
                results[label] = AggregatedBindingPreferenceResult.load(result_file)
            except Exception as e:
                logger.warning("Failed to load %s: %s", result_file, e)
        return results
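The bar placement in this example relies on one line of arithmetic: offset = (i - n_conditions / 2 + 0.5) * bar_width. The short sketch below isolates that formula so the centering is easy to verify; bar_offsets is a hypothetical helper, and the 0.8 group width is the value used in the example.

```python
def bar_offsets(n_conditions: int, total_width: float = 0.8) -> list[float]:
    """Return the x-offset of each condition's bar relative to the group tick."""
    bar_width = total_width / n_conditions
    return [(i - n_conditions / 2 + 0.5) * bar_width for i in range(n_conditions)]


offsets = bar_offsets(3)

# Bars sit symmetrically around the tick, so the offsets cancel out.
centered = abs(sum(offsets)) < 1e-12

# Distance between the outer bar centers plus one bar width is the group width.
span = (offsets[-1] - offsets[0]) + 0.8 / 3
```

With three conditions the middle bar lands exactly on the tick and the group occupies 0.8 of the unit spacing between ticks, leaving a gap of 0.2 between neighboring groups.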
Testing Your Plotter
Unit Test
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch

# Import MyCustomPlotter from the module where you defined it.


class TestMyCustomPlotter:
    """Tests for MyCustomPlotter."""

    def test_plot_type_returns_registry_key(self):
        plotter = MyCustomPlotter(settings=MagicMock())
        assert plotter.plot_type() == "my_custom_plot"

    def test_can_plot_returns_true_for_correct_type(self):
        settings = MagicMock()
        settings.my_analysis.generate_custom_plot = True
        plotter = MyCustomPlotter(settings=settings)
        assert plotter.can_plot(MagicMock(), "my_analysis") is True
        assert plotter.can_plot(MagicMock(), "other_type") is False

    def test_plot_returns_empty_when_no_data(self):
        plotter = MyCustomPlotter(settings=MagicMock())
        result = plotter.plot({}, [], Path("/tmp"))
        assert result == []
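Because plotters read from disk, it is also worth testing the loading path against real files in a temporary directory. The sketch below uses a hypothetical StandInPlotter with the same loading shape as the Quick Start example, and plain JSON in place of a result class, so that it runs without PolyzyMD installed.

```python
import json
import tempfile
from pathlib import Path


class StandInPlotter:
    """Hypothetical plotter with the same loading shape as the Quick Start."""

    def load_results(self, data, labels):
        results = {}
        for label in labels:
            cond = data.get(label)
            if cond is None:
                continue
            result_file = Path(cond["analysis_dir"]) / "my_result.json"
            if result_file.exists():
                results[label] = json.loads(result_file.read_text())
        return results


def test_loads_only_conditions_with_files():
    with tempfile.TemporaryDirectory() as tmp:
        with_file = Path(tmp) / "cond_a"
        without_file = Path(tmp) / "cond_b"
        with_file.mkdir()
        without_file.mkdir()
        (with_file / "my_result.json").write_text(json.dumps({"mean": 1.5}))

        data = {
            "cond_a": {"analysis_dir": with_file},
            "cond_b": {"analysis_dir": without_file},
        }
        # "cond_c" is requested but absent from data; it should be skipped.
        results = StandInPlotter().load_results(data, ["cond_a", "cond_b", "cond_c"])
        assert results == {"cond_a": {"mean": 1.5}}


test_loads_only_conditions_with_files()
```

Under pytest, the built-in tmp_path fixture can replace the explicit TemporaryDirectory.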
Integration Test
# Run plot-all to test discovery and execution
pixi run -e build polyzymd compare plot-all -f comparison.yaml
Common Patterns
Pattern: Load Aggregated Results
Most plotters load aggregated (replicate-averaged) results:
def _load_aggregated(self, data, labels):
    from my_module import MyAggregatedResult

    results = {}
    for label in labels:
        # Skip conditions without an aggregated directory instead of
        # falling back to the current working directory
        agg_dir = data[label].get("aggregated_dir")
        if not agg_dir:
            continue
        result_file = Path(agg_dir) / "my_aggregated.json"
        if result_file.exists():
            results[label] = MyAggregatedResult.load(result_file)
    return results
Pattern: Pool Per-Replicate Data
For distribution plots (KDEs, histograms), pool raw data across replicates:
# Assumes `import numpy as np` at module level and a MyResult class
# whose `values` attribute is a sequence of numbers.
def _pool_replicate_data(self, data, labels):
    pooled = {}
    for label in labels:
        analysis_dir = Path(data[label]["analysis_dir"])
        replicates = data[label]["replicates"]
        all_values = []
        for rep in replicates:
            rep_file = analysis_dir / f"run_{rep}" / "result.json"
            if rep_file.exists():
                result = MyResult.load(rep_file)
                all_values.extend(result.values)
        if all_values:
            pooled[label] = np.array(all_values)
    return pooled
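The pooling step can be tried in isolation with throwaway files. This sketch assumes each replicate writes a result.json containing a "values" list, which mirrors the pattern above but is not a documented PolyzyMD file format; pool_values is a hypothetical helper and plain lists stand in for NumPy arrays.

```python
import json
import tempfile
from pathlib import Path


def pool_values(analysis_dir: Path, replicates: list[int]) -> list[float]:
    """Concatenate the `values` list from every replicate file that exists."""
    pooled: list[float] = []
    for rep in replicates:
        rep_file = analysis_dir / f"run_{rep}" / "result.json"
        if rep_file.exists():
            pooled.extend(json.loads(rep_file.read_text())["values"])
    return pooled


with tempfile.TemporaryDirectory() as tmp:
    analysis_dir = Path(tmp)
    # Replicates 1 and 3 have output; replicate 2 is missing on purpose.
    for rep, values in {1: [3.1, 3.4], 3: [2.9]}.items():
        rep_dir = analysis_dir / f"run_{rep}"
        rep_dir.mkdir()
        (rep_dir / "result.json").write_text(json.dumps({"values": values}))

    pooled = pool_values(analysis_dir, [1, 2, 3])
```

Note that pooling weights each replicate by its number of samples, so a replicate with a longer trajectory contributes more to the resulting distribution.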
Pattern: Conditional Plot Generation
Only generate plots when data is available:
def plot(self, data, labels, output_dir, **kwargs):
    results = self._load_results(data, labels)
    if not results:
        logger.info("No data found - skipping plot")
        return []  # Return an empty list rather than raising an exception
    # Continue with plotting...
Adding Plot Settings
To add configurable settings for your plotter:
1. Extend Settings Model
In compare/config.py:
class PlotSettingsMyAnalysis(BaseModel):
    """Plot settings for my_analysis type."""

    generate_custom_plot: bool = True
    figsize_custom: tuple[int, int] = (10, 6)
    custom_colormap: str = "viridis"


class PlotSettings(BaseModel):
    """Root plot settings."""

    # Existing fields...
    my_analysis: PlotSettingsMyAnalysis = Field(default_factory=PlotSettingsMyAnalysis)
2. Access in Plotter
def plot(self, data, labels, output_dir, **kwargs):
    fig, ax = plt.subplots(figsize=self.settings.my_analysis.figsize_custom)
    ax.imshow(matrix, cmap=self.settings.my_analysis.custom_colormap)
Troubleshooting
“No data found” but data exists
Check that your file pattern matches:
# Verify file exists
agg_files = list(analysis_dir.glob("binding_preference_aggregated*.json"))
print(f"Found files: {agg_files}") # Debug
Plot not appearing in plot-all
Verify registration:
from polyzymd.compare.plotter import PlotterRegistry
print(PlotterRegistry.list())  # Should include your key
Check that can_plot() returns True:
assert plotter.can_plot(config, "my_analysis") is True
Verify that the settings enable the plot:
plot_settings:
  my_analysis:
    generate_custom_plot: true
Type errors in matplotlib
Static type checkers may reject a list for the rect argument of add_axes and tight_layout. Pass a tuple instead:
# Correct
cbar_ax = fig.add_axes((0.92, 0.15, 0.02, 0.7))  # tuple
plt.tight_layout(rect=(0, 0, 0.9, 0.95))  # tuple
# Flagged by type checkers (a list still works at runtime)
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # list
See Also
Extending the Comparison Framework — Creating custom statistical comparators
Binding Preference Analysis — Binding preference analysis details
Architecture — Overall system architecture