"""Shared plotting utilities for analysis plugins.
This module provides reusable plotting helper functions extracted from the
plotter infrastructure. Analysis plugins import these functions to apply
consistent styling, save figures with watermarks, and render common chart
elements (grouped bars, heatmap annotations, etc.) without inheriting from
a base class.
All functions accept a ``PlotSettings`` object (from
``polyzymd.config.comparison``) so that user-configured themes, palettes,
and DPI settings are respected automatically.
Examples
--------
>>> from polyzymd.analyses.shared.plotting import (
... apply_axis_style, apply_legend, get_palette_colors, save_figure,
... )
>>>
>>> fig, ax = plt.subplots()
>>> colors = get_palette_colors(3, plot_settings)
>>> ax.bar(x, y, color=colors[0])
>>> apply_axis_style(ax, plot_settings, title="My Plot", ylabel="Value (Å)")
>>> apply_legend(ax, plot_settings)
>>> save_figure(fig, output_dir / "my_plot.png", plot_settings)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence
from polyzymd.core.branding import PLOT_WATERMARK
if TYPE_CHECKING:
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from polyzymd.config.comparison import PlotSettings, PlotTheme
logger = logging.getLogger(__name__)
_UNSET = object() # sentinel for apply_legend defaults
[docs]
@dataclass(frozen=True)
class ArtifactPlotData:
    """Canonical artifacts loaded for plot-time data access.

    Instances are produced by ``load_canonical_plot_artifacts`` in this
    module; the attribute descriptions below reflect how that loader fills
    them.

    Attributes
    ----------
    analysis_dir : Path
        Condition-level analysis directory the artifacts were loaded from.
    condition_artifact : Any or None
        Artifact read from ``aggregated/result.json``, or ``None`` when that
        file was absent and not required.
    replicate_artifacts : dict of int to Any
        Artifacts read from ``run_N/result.json``, keyed by replicate ID.
        Replicates whose file was absent (and not required) have no entry.
    aggregated_dir : Path
        The ``aggregated`` subdirectory of ``analysis_dir``.
    run_dirs : dict of int to Path
        ``run_N`` directory for every requested replicate ID, including
        replicates whose result file was missing.

    Notes
    -----
    ``frozen=True`` prevents rebinding attributes; the dict attributes
    themselves are ordinary mutable dicts.
    """

    analysis_dir: Path
    condition_artifact: Any | None
    replicate_artifacts: dict[int, Any]
    aggregated_dir: Path
    run_dirs: dict[int, Path]
[docs]
def load_canonical_plot_artifacts(
    analysis_dir: Path,
    replicates: Sequence[int],
    *,
    require_condition: bool = False,
    require_replicates: bool = True,
) -> ArtifactPlotData:
    """Read the canonical condition and replicate artifacts for plotting.

    Only two kinds of fixed paths are consulted: ``aggregated/result.json``
    and ``run_N/result.json`` for each configured replicate. Both are read
    through :class:`ArtifactStore`; no directory listing is performed.

    Parameters
    ----------
    analysis_dir : Path
        Condition-level analysis directory holding the ``aggregated`` and
        ``run_N`` subdirectories.
    replicates : sequence of int
        Replicate IDs to load. Each entry is coerced with ``int``.
    require_condition : bool, optional
        When True, a missing ``aggregated/result.json`` is an error.
        Defaults to False.
    require_replicates : bool, optional
        When True, a missing ``run_N/result.json`` for any requested
        replicate is an error. Defaults to True.

    Returns
    -------
    ArtifactPlotData
        The loaded artifacts together with the directories they came from.

    Raises
    ------
    ArtifactStoreError
        If a required artifact file does not exist.
    """
    from polyzymd.analyses.mda import ArtifactStore, ArtifactStoreError

    base_dir = Path(analysis_dir)

    # Condition-level artifact (optional unless require_condition is set).
    condition_dir = base_dir / "aggregated"
    condition_file = condition_dir / "result.json"
    if condition_file.exists():
        condition = ArtifactStore(condition_dir).read_condition_result("result.json")
    else:
        if require_condition:
            raise ArtifactStoreError(f"Missing canonical condition artifact: {condition_file}")
        condition = None

    # Replicate-level artifacts; the directory is recorded even when the
    # result file turns out to be missing.
    loaded: dict[int, Any] = {}
    directories: dict[int, Path] = {}
    for raw_id in replicates:
        rep_id = int(raw_id)
        rep_dir = base_dir / f"run_{rep_id}"
        directories[rep_id] = rep_dir
        rep_file = rep_dir / "result.json"
        if not rep_file.exists():
            if require_replicates:
                raise ArtifactStoreError(f"Missing canonical replicate artifact: {rep_file}")
            continue
        loaded[rep_id] = ArtifactStore(rep_dir).read_replicate_result("result.json")

    return ArtifactPlotData(
        analysis_dir=base_dir,
        condition_artifact=condition,
        replicate_artifacts=loaded,
        aggregated_dir=condition_dir,
        run_dirs=directories,
    )
# ---------------------------------------------------------------------------
# Theme access
# ---------------------------------------------------------------------------
[docs]
def get_theme(plot_settings: "PlotSettings") -> "PlotTheme":
    """Look up the theme carried by *plot_settings*.

    Parameters
    ----------
    plot_settings : PlotSettings
        Global plot settings exposing a ``.theme`` attribute.

    Returns
    -------
    PlotTheme
        The value of ``plot_settings.theme``.
    """
    theme = plot_settings.theme
    return theme
# ---------------------------------------------------------------------------
# Axis styling
# ---------------------------------------------------------------------------
[docs]
def apply_axis_style(
    ax: "Axes",
    plot_settings: "PlotSettings",
    *,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
) -> None:
    """Style an axes according to the configured theme.

    The top/right spines are hidden when the theme asks for it, tick labels
    are resized, and any of title / xlabel / ylabel that are provided are
    set using the themed font sizes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to style in place.
    plot_settings : PlotSettings
        Global plot settings; only ``.theme`` is read.
    title : str, optional
        Axes title. Left untouched when ``None``.
    xlabel : str, optional
        X-axis label. Left untouched when ``None``.
    ylabel : str, optional
        Y-axis label. Left untouched when ``None``.
    """
    theme = plot_settings.theme

    spine_rules = (
        ("top", theme.hide_top_spine),
        ("right", theme.hide_right_spine),
    )
    for side, hidden in spine_rules:
        if hidden:
            ax.spines[side].set_visible(False)

    ax.tick_params(axis="both", labelsize=theme.tick_fontsize)

    if title is not None:
        ax.set_title(title, fontsize=theme.title_fontsize, fontweight=theme.title_fontweight)

    # Both axis labels share the same themed font size.
    for setter, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        if text is not None:
            setter(text, fontsize=theme.label_fontsize)
[docs]
def apply_legend(
    ax: "Axes",
    plot_settings: "PlotSettings",
    *,
    loc: str | None = None,
    bbox_to_anchor: tuple[float, float] | None | object = _UNSET,
    fontsize: int | None = None,
    **kwargs: Any,
) -> None:
    """Apply legend with themed defaults.

    Uses ``theme.legend_loc``, ``theme.legend_bbox`` and
    ``theme.legend_fontsize`` unless overridden by the caller. Extra
    *kwargs* are forwarded to ``ax.legend()`` and take precedence over the
    resolved values.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    plot_settings : PlotSettings
        Global plot settings.
    loc : str, optional
        Override ``theme.legend_loc``. ``None`` or an empty string selects
        the theme value; any other value (including matplotlib's integer
        location code ``0``) is passed through unchanged.
    bbox_to_anchor : tuple of float or None, optional
        Override ``theme.legend_bbox``. Pass ``None`` explicitly to
        suppress the bbox (e.g. for inside-axes placement).
    fontsize : int, optional
        Override ``theme.legend_fontsize``. ``None`` or ``0`` selects the
        theme value.
    **kwargs
        Forwarded to ``ax.legend()``.
    """
    t = plot_settings.theme

    # Only None / "" mean "not provided". A plain truthiness test would also
    # throw away the integer location code 0, which matplotlib reads as "best".
    resolved_loc = t.legend_loc if loc is None or loc == "" else loc
    resolved_fs = fontsize or t.legend_fontsize

    legend_kwargs: dict[str, Any] = {"loc": resolved_loc, "fontsize": resolved_fs}

    # Resolve bbox_to_anchor: _UNSET -> theme default, None -> omit entirely.
    if bbox_to_anchor is _UNSET:
        legend_kwargs["bbox_to_anchor"] = t.legend_bbox
    elif bbox_to_anchor is not None:
        legend_kwargs["bbox_to_anchor"] = bbox_to_anchor

    # Caller-supplied kwargs win over everything resolved above.
    legend_kwargs.update(kwargs)
    ax.legend(**legend_kwargs)
# ---------------------------------------------------------------------------
# Colors
# ---------------------------------------------------------------------------
[docs]
def get_palette_colors(n: int, plot_settings: "PlotSettings") -> list:
    """Return *n* colors drawn from the configured palette.

    seaborn is preferred when it can be imported; otherwise the palette name
    is treated as a matplotlib colormap. In both cases an unrecognised
    palette name is logged and ``'tab10'`` is used instead.

    Parameters
    ----------
    n : int
        Number of colors requested.
    plot_settings : PlotSettings
        Global plot settings; ``color_palette`` names the palette.

    Returns
    -------
    list
        Color values (RGB tuples or matplotlib color specs).
    """
    requested = plot_settings.color_palette

    try:
        import seaborn as sns

        try:
            swatches = sns.color_palette(requested, n)
        except ValueError:
            logger.warning(
                "Color palette %r is not a valid seaborn palette. Falling back to 'tab10'.",
                requested,
            )
            return _matplotlib_colormap_colors("tab10", n)
        return list(swatches)
    except ImportError:
        # seaborn is optional; continue with the matplotlib-only path below.
        pass

    try:
        return _matplotlib_colormap_colors(requested, n)
    except (KeyError, ValueError):
        logger.warning(
            "Color palette %r is not a valid matplotlib colormap "
            "(seaborn is not installed). Falling back to 'tab10'.",
            requested,
        )
    return _matplotlib_colormap_colors("tab10", n)
def _matplotlib_colormap_colors(colormap_name: str, n: int) -> list:
    """Pick *n* evenly spaced colors from a named matplotlib colormap.

    Parameters
    ----------
    colormap_name : str
        Name looked up in ``matplotlib.pyplot.colormaps``.
    n : int
        Number of colors to return.

    Returns
    -------
    list
        Colors obtained by calling the colormap at ``i / max(1, n - 1)`` for
        ``i`` in ``range(n)``.
    """
    import matplotlib.pyplot as plt

    colormap = plt.colormaps[colormap_name]
    # max(1, ...) keeps the single-color case from dividing by zero.
    denominator = max(1, n - 1)
    samples = []
    for index in range(n):
        samples.append(colormap(index / denominator))
    return samples
[docs]
def order_condition_labels(labels: Sequence[str], plot_settings: "PlotSettings") -> list[str]:
    """Arrange condition labels in semantic display order.

    When semantic colors are absent or disabled the labels are returned in
    their original order. Otherwise the result is built in three tiers:

    1. labels named in ``semantic.order``, in that order;
    2. remaining labels whose condition entry has a non-``None`` ``order``,
       sorted by that value (ties keep their relative input order);
    3. all other labels, in input order.

    Parameters
    ----------
    labels : sequence of str
        Condition labels in their original order.
    plot_settings : PlotSettings
        Global plot settings; ``semantic_colors`` is read if present.

    Returns
    -------
    list of str
        Labels in plotting order.
    """
    original = list(labels)
    semantic = getattr(plot_settings, "semantic_colors", None)
    if semantic is None or not semantic.enabled:
        return original

    pending = list(original)
    result: list[str] = []

    # Tier 1: explicit global ordering.
    for name in semantic.order:
        if name in pending:
            result.append(name)
            pending.remove(name)

    # Tiers 2 and 3: split what is left by whether a per-condition rank exists.
    ranked: list[tuple[int, int, str]] = []
    unranked: list[str] = []
    for position, name in enumerate(pending):
        entry = semantic.conditions.get(name)
        if entry is None or entry.order is None:
            unranked.append(name)
        else:
            ranked.append((entry.order, position, name))

    ranked.sort(key=lambda item: (item[0], item[1]))
    result.extend(name for _, _, name in ranked)
    result.extend(unranked)
    return result
[docs]
def get_condition_colors(
    labels: Sequence[str],
    plot_settings: "PlotSettings",
    *,
    control_label: str | None = None,
) -> list:
    """Resolve one color per condition label, in label order.

    This is a list-shaped convenience over ``get_condition_color_map``.

    Parameters
    ----------
    labels : sequence of str
        Condition labels in plot order.
    plot_settings : PlotSettings
        Global plot settings carrying optional semantic color settings.
    control_label : str, optional
        Label that should use the configured semantic control color.

    Returns
    -------
    list
        Colors aligned index-for-index with ``labels``.
    """
    resolved = get_condition_color_map(labels, plot_settings, control_label=control_label)
    colors = []
    for name in labels:
        colors.append(resolved[name])
    return colors
[docs]
def get_condition_color_map(
    labels: Sequence[str],
    plot_settings: "PlotSettings",
    *,
    control_label: str | None = None,
) -> dict[str, Any]:
    """Build a label-to-color mapping for the given conditions.

    A palette color is first assigned to every label by position. When
    semantic colors are absent or disabled that palette mapping is returned
    as is. Otherwise each label is resolved through the semantic rules
    (manual color, condition color, control color, family/value color,
    missing-metadata color), with the palette color as the last resort.

    Parameters
    ----------
    labels : sequence of str
        Condition labels in their original palette-alignment order.
    plot_settings : PlotSettings
        Global plot settings carrying optional semantic color settings.
    control_label : str, optional
        Label that should use the configured semantic control color.

    Returns
    -------
    dict of str to Any
        Mapping from each label to its resolved color.
    """
    ordered = list(labels)
    fallback = dict(zip(ordered, get_palette_colors(len(ordered), plot_settings)))

    semantic = getattr(plot_settings, "semantic_colors", None)
    if semantic is None or not semantic.enabled:
        return dict(fallback)

    # Family values are gathered once so every label is scaled against the
    # same set of observed values.
    family_values = _collect_family_values(ordered, semantic.conditions)
    return {
        name: _resolve_condition_color(
            name,
            semantic,
            fallback,
            family_values,
            control_label=control_label,
        )
        for name in ordered
    }
def _resolve_condition_color(
    label: str,
    semantic: Any,
    palette_by_label: dict[str, Any],
    observed_values: dict[str, list[Any]],
    *,
    control_label: str | None,
) -> Any:
    """Pick the color for one condition by walking the semantic rules.

    Rules are tried in this order and the first valid color wins: manual
    color for the label; (for labels without a condition entry) the semantic
    default color; the condition's own color; the control color when the
    label is the control; the family/value color; the missing-metadata
    color. The palette color for the label is used when nothing else
    applies.
    """
    manual = _validated_color(semantic.manual_colors.get(label), f"manual color for {label!r}")
    if manual is not None:
        return manual

    entry = semantic.conditions.get(label)
    if entry is None:
        # No metadata for this label: default color, else palette.
        default = _validated_color(semantic.default_color, "semantic default_color")
        if default is None:
            return palette_by_label[label]
        return default

    explicit = _validated_color(entry.color, f"condition color for {label!r}")
    if explicit is not None:
        return explicit

    if (control_label is not None and label == control_label) or entry.role == "control":
        control = _validated_color(semantic.control_color, "semantic control_color")
        if control is not None:
            return control

    from_family = _resolve_family_color(entry, semantic.families, observed_values, label)
    if from_family is not None:
        return from_family

    missing = _validated_color(semantic.missing_color, "semantic missing_color")
    if missing is None:
        return palette_by_label[label]
    return missing
def _collect_family_values(
labels: Sequence[str], conditions: dict[str, Any]
) -> dict[str, list[Any]]:
"""Collect observed semantic values by family in label order."""
observed_values: dict[str, list[Any]] = {}
for label in labels:
condition = conditions.get(label)
if condition is None or condition.family is None or condition.value is None:
continue
values = observed_values.setdefault(condition.family, [])
if condition.value not in values:
values.append(condition.value)
return observed_values
def _validated_color(color: Any, context: str) -> Any | None:
"""Return a color when matplotlib accepts it, otherwise warn."""
if color is None:
return None
from matplotlib.colors import is_color_like
if is_color_like(color):
return color
logger.warning("Invalid %s %r. Falling back to the next available color rule.", context, color)
return None
def _resolve_family_color(
condition: Any,
families: dict[str, Any],
observed_values: dict[str, list[Any]],
label: str,
) -> Any | None:
"""Resolve a condition color from its semantic family and value."""
if condition.family is None or condition.value is None:
return None
family = families.get(condition.family)
if family is None:
logger.warning(
"Condition %r references unknown semantic color family %r.",
label,
condition.family,
)
return None
value_color = _resolve_explicit_value_color(family.value_colors, condition.value, label)
if value_color is not None:
return value_color
cmap = _get_colormap(family.colormap, label)
if cmap is None:
return None
if family.scale == "ordinal":
fraction = _ordinal_fraction(
condition.value, family, observed_values.get(condition.family, [])
)
else:
fraction = _linear_fraction(
condition.value, family, observed_values.get(condition.family, [])
)
if fraction is None:
return None
if family.reverse:
fraction = 1.0 - fraction
low, high = family.colormap_range
return cmap(low + ((high - low) * fraction))
def _resolve_explicit_value_color(
    value_colors: dict[str, Any], value: Any, label: str
) -> Any | None:
    """Look up a configured color for an exact semantic value.

    The value itself is tried as the key first; if it is not present, its
    ``str()`` form is tried. The result is validated before being returned,
    so an absent or invalid color yields ``None``.
    """
    key = value if value in value_colors else str(value)
    return _validated_color(value_colors.get(key), f"value color for {label!r}")
def _get_colormap(colormap_name: str, label: str) -> Any | None:
    """Fetch a matplotlib colormap by name, or ``None`` if it is unknown.

    A failed lookup is logged together with the condition *label* that
    requested it.
    """
    import matplotlib.pyplot as plt

    try:
        colormap = plt.colormaps[colormap_name]
    except (KeyError, ValueError):
        logger.warning(
            "Semantic color colormap %r for condition %r is invalid. Falling back.",
            colormap_name,
            label,
        )
        colormap = None
    return colormap
def _ordinal_fraction(value: Any, family: Any, observed_values: list[Any]) -> float | None:
"""Return an ordinal colormap fraction for a semantic value."""
value_order = family.value_order or observed_values
if value not in value_order:
logger.warning("Semantic ordinal value %r is not present in value_order.", value)
return None
if len(value_order) <= 1:
return 0.5
return value_order.index(value) / (len(value_order) - 1)
def _linear_fraction(value: Any, family: Any, observed_values: list[Any]) -> float | None:
"""Return a linear colormap fraction for a semantic numeric value."""
try:
numeric_value = float(value)
except (TypeError, ValueError):
logger.warning("Semantic linear value %r is not numeric.", value)
return None
numeric_values: list[float] = []
for observed_value in observed_values:
try:
numeric_values.append(float(observed_value))
except (TypeError, ValueError):
continue
vmin = family.vmin if family.vmin is not None else min(numeric_values, default=numeric_value)
vmax = family.vmax if family.vmax is not None else max(numeric_values, default=numeric_value)
if vmax == vmin:
return 0.5
return min(1.0, max(0.0, (numeric_value - vmin) / (vmax - vmin)))
# ---------------------------------------------------------------------------
# Figure saving
# ---------------------------------------------------------------------------
[docs]
def get_output_path(output_dir: Path, name: str, plot_settings: "PlotSettings") -> Path:
    """Build the output file path for a figure.

    Parameters
    ----------
    output_dir : Path
        Directory the figure will be written to.
    name : str
        Base filename without an extension.
    plot_settings : PlotSettings
        Global plot settings; ``format`` supplies the extension.

    Returns
    -------
    Path
        ``output_dir / "<name>.<format>"``.
    """
    extension = plot_settings.format
    filename = f"{name}.{extension}"
    return output_dir / filename
# ---------------------------------------------------------------------------
# Grouped bar chart
# ---------------------------------------------------------------------------
[docs]
def finite_numeric_values(values: Any) -> "np.ndarray":
"""Return finite numeric values as a one-dimensional float array.
Parameters
----------
values : Any
Candidate scalar or sequence of replicate values. ``None`` and
non-numeric inputs are treated as missing data.
Returns
-------
numpy.ndarray
One-dimensional array containing only finite floats. The array is
empty when no finite numeric values are available.
"""
import numpy as np
if values is None:
return np.array([], dtype=float)
try:
value_array = np.asarray(values, dtype=float)
except (TypeError, ValueError):
if isinstance(values, str | bytes):
return np.array([], dtype=float)
try:
iterator = iter(values)
except TypeError:
try:
scalar_value = float(values)
except (TypeError, ValueError):
return np.array([], dtype=float)
if np.isfinite(scalar_value):
return np.array([scalar_value], dtype=float)
return np.array([], dtype=float)
finite_values: list[float] = []
for value in iterator:
try:
numeric_value = float(value)
except (TypeError, ValueError):
continue
if np.isfinite(numeric_value):
finite_values.append(numeric_value)
return np.array(finite_values, dtype=float)
if value_array.ndim == 0:
value_array = value_array.reshape(1)
value_array = value_array.ravel()
return value_array[np.isfinite(value_array)]
[docs]
def replicate_jitter_offsets(n_values: int, bar_width: float) -> "np.ndarray":
    """Compute evenly spaced, reproducible offsets for replicate dots.

    No randomness is involved: the offsets span a quarter of the bar width
    on either side of the bar centre.

    Parameters
    ----------
    n_values : int
        Number of replicate values to display for one bar.
    bar_width : float
        Width or height of the corresponding bar, depending on orientation.

    Returns
    -------
    numpy.ndarray
        Empty for ``n_values <= 0``, ``[0.0]`` for a single value, otherwise
        ``n_values`` points from ``-0.25 * bar_width`` to ``0.25 * bar_width``.
    """
    import numpy as np

    if n_values <= 0:
        return np.zeros(0, dtype=float)
    if n_values == 1:
        # A single dot sits exactly on the bar centre.
        return np.zeros(1, dtype=float)

    half_span = bar_width * 0.25
    return np.linspace(-half_span, half_span, n_values)
[docs]
def has_replicate_uncertainty(
    replicate_values: Any = None,
    *,
    n_replicates: int | None = None,
) -> bool:
    """Decide whether replicate-level uncertainty is displayable.

    Parameters
    ----------
    replicate_values : Any, optional
        Per-condition or per-bar replicate values. Only consulted when
        ``n_replicates`` is ``None``; finite numeric entries are counted.
    n_replicates : int or None, optional
        Explicit replicate count. When given it alone decides the result.

    Returns
    -------
    bool
        True when the replicate count (explicit, or the number of finite
        values) is at least two.
    """
    if n_replicates is None:
        return finite_numeric_values(replicate_values).size >= 2
    return n_replicates >= 2
[docs]
def suppress_singleton_errors(
errors: Sequence[float],
replicate_values: Sequence[Any] | None,
) -> list[float] | None:
"""Return errors with singleton replicate uncertainties suppressed.
Parameters
----------
errors : sequence of float
SEM or uncertainty values aligned to ``replicate_values``.
replicate_values : sequence or None
Per-bar replicate values used to decide whether an error bar is
statistically displayable.
Returns
-------
list of float or None
Sanitized error values. Returns ``None`` when no bar has replicate
uncertainty, allowing callers to omit error bars entirely.
"""
if replicate_values is None:
return list(errors)
sanitized: list[float] = []
has_any_uncertainty = False
for error, values in zip(errors, replicate_values, strict=False):
if has_replicate_uncertainty(values):
sanitized.append(float(error))
has_any_uncertainty = True
else:
sanitized.append(0.0)
return sanitized if has_any_uncertainty else None
[docs]
def scatter_replicate_values(
    ax: "Axes",
    bar_positions: "Sequence[float] | np.ndarray",
    replicate_values: "Sequence[Any]",
    plot_settings: "PlotSettings",
    *,
    orientation: str = "vertical",
    bar_width: float = 0.8,
    dot_color: Any | None = None,
    dot_size: float | None = None,
    dot_alpha: float | None = None,
    zorder: float = 5,
) -> int:
    """Draw per-replicate dots on top of bars.

    Dots are spread along the categorical axis with deterministic jitter and
    placed at their replicate value along the value axis. For vertical bars
    the categorical axis is x; for horizontal bars it is y.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes containing the bar chart.
    bar_positions : sequence of float or numpy.ndarray
        Bar centre positions aligned to ``replicate_values``.
    replicate_values : sequence of Any
        Per-bar replicate values. Each item may be a scalar or sequence;
        only finite numeric values are plotted.
    plot_settings : PlotSettings
        Global plot settings whose theme provides default dot styling.
    orientation : {"vertical", "horizontal"}, optional
        Bar orientation, by default ``"vertical"``.
    bar_width : float, optional
        Width or height of the bars, by default ``0.8``.
    dot_color : Any, optional
        Override for theme dot colour.
    dot_size : float, optional
        Override for theme dot size. Nothing is drawn when non-positive.
    dot_alpha : float, optional
        Override for theme dot alpha. Nothing is drawn when non-positive.
    zorder : float, optional
        Matplotlib z-order for dot overlays, by default ``5``.

    Returns
    -------
    int
        Number of ``ax.scatter`` calls made (one per bar with finite data).

    Raises
    ------
    ValueError
        If ``orientation`` is not ``"vertical"`` or ``"horizontal"``, or if
        ``bar_positions`` and ``replicate_values`` differ in length. The
        length check is only reached when dots are actually going to be drawn.
    """
    import numpy as np

    if orientation not in {"vertical", "horizontal"}:
        raise ValueError("orientation must be 'vertical' or 'horizontal'")

    style = plot_settings.theme
    size = dot_size if dot_size is not None else style.dot_size
    alpha = dot_alpha if dot_alpha is not None else style.dot_alpha
    if size <= 0 or alpha <= 0:
        # Invisible dots: skip all work.
        return 0
    color = dot_color if dot_color is not None else style.dot_color

    centres = np.asarray(bar_positions, dtype=float)
    if len(replicate_values) != len(centres):
        raise ValueError(
            "replicate_values length must match bar_positions length "
            f"({len(replicate_values)} != {len(centres)})"
        )

    drawn = 0
    for centre, values in zip(centres, replicate_values):
        finite = finite_numeric_values(values)
        if finite.size == 0:
            continue

        # Spread dots around the bar centre along the categorical axis.
        along_bar = float(centre) + replicate_jitter_offsets(finite.size, bar_width)
        coords = (along_bar, finite) if orientation == "vertical" else (finite, along_bar)
        ax.scatter(
            *coords,
            color=color,
            s=size,
            zorder=zorder,
            alpha=alpha,
            edgecolors="none",
        )
        drawn += 1
    return drawn
[docs]
def scatter_stacked_segment_replicates(
    ax: "Axes",
    x_position: float,
    bottom_value: float,
    replicate_values: Sequence[Any],
    plot_settings: "PlotSettings",
    *,
    replicate_base_values: Sequence[Any] | None = None,
    positive_base_values: Sequence[Any] | None = None,
    negative_base_values: Sequence[Any] | None = None,
    bar_width: float = 0.8,
    dot_color: Any | None = None,
    dot_size: float | None = None,
    dot_alpha: float | None = None,
    placement: str = "center",
    zorder: float = 5,
) -> int:
    """Draw replicate dots for one segment of a stacked bar.

    Each replicate value is a segment *height*, not an absolute coordinate.
    The dot is drawn at ``base + height / 2`` (``placement="center"``) or at
    ``base + height`` (``placement="end"``). The base is chosen as follows:

    * signed stacks: when both ``positive_base_values`` and
      ``negative_base_values`` are given, each replicate uses its positive
      base if its height is ``>= 0`` and its negative base otherwise;
    * unsigned stacks: when ``replicate_base_values`` is given, each
      replicate uses its own base;
    * otherwise ``bottom_value`` is used for every replicate.

    Signed bases take precedence over ``replicate_base_values`` when both
    are supplied. Replicates whose height or base is non-numeric or
    non-finite are dropped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes containing the stacked bar chart.
    x_position : float
        Center x-coordinate of the condition bar.
    bottom_value : float
        Aggregate stack baseline for the current segment.
    replicate_values : sequence of Any
        Component-specific per-replicate segment heights.
    plot_settings : PlotSettings
        Plot configuration used for dot styling.
    replicate_base_values : sequence of Any, optional
        Per-replicate cumulative stack bases for unsigned stacks.
    positive_base_values : sequence of Any, optional
        Per-replicate cumulative positive stack bases for signed stacks.
    negative_base_values : sequence of Any, optional
        Per-replicate cumulative negative stack bases for signed stacks.
    bar_width : float, optional
        Width used for deterministic jitter, by default ``0.8``.
    dot_color : Any, optional
        Override for theme dot colour.
    dot_size : float, optional
        Override for theme dot size.
    dot_alpha : float, optional
        Override for theme dot alpha.
    placement : {"center", "end"}, optional
        Dot placement within each replicate segment.
    zorder : float, optional
        Matplotlib z-order for dot overlays, by default ``5``.

    Returns
    -------
    int
        Number of scatter calls emitted.

    Raises
    ------
    ValueError
        If ``placement`` is not ``"center"`` or ``"end"``, if only one of
        the signed base sequences is given, or if a base sequence does not
        have the same length as ``replicate_values``.
    """
    import math

    import numpy as np

    if placement not in {"center", "end"}:
        raise ValueError("placement must be 'center' or 'end'")

    heights = list(replicate_values)

    # Build (height, base-if-non-negative, base-if-negative) triples. The
    # unsigned case simply uses the same base for both signs. ``None`` means
    # "no per-replicate bases: use the aggregate baseline".
    triples = None
    if positive_base_values is not None or negative_base_values is not None:
        if positive_base_values is None or negative_base_values is None:
            raise ValueError(
                "positive_base_values and negative_base_values must be provided together"
            )
        upward = list(positive_base_values)
        downward = list(negative_base_values)
        if len(upward) != len(heights) or len(downward) != len(heights):
            raise ValueError("signed replicate base lengths must match replicate_values length")
        triples = zip(heights, upward, downward)
    elif replicate_base_values is not None:
        shared = list(replicate_base_values)
        if len(shared) != len(heights):
            raise ValueError("replicate_base_values length must match replicate_values length")
        triples = zip(heights, shared, shared)

    if triples is None:
        segment_values = finite_numeric_values(heights)
        bases = np.full(segment_values.shape, float(bottom_value), dtype=float)
    else:
        kept_heights: list[float] = []
        kept_bases: list[float] = []
        for raw_height, base_up, base_down in triples:
            try:
                height = float(raw_height)
                base = float(base_up if height >= 0.0 else base_down)
            except (TypeError, ValueError):
                continue
            # Height and base are kept or dropped together so they stay aligned.
            if math.isfinite(height) and math.isfinite(base):
                kept_heights.append(height)
                kept_bases.append(base)
        segment_values = np.asarray(kept_heights, dtype=float)
        bases = np.asarray(kept_bases, dtype=float)

    if segment_values.size == 0:
        return 0

    divisor = {"center": 2.0, "end": 1.0}[placement]
    dot_positions = []
    for base, height in zip(bases, segment_values):
        dot_positions.append(float(base) + float(height) / divisor)

    # All replicate dots belong to a single bar at x_position.
    return scatter_replicate_values(
        ax,
        [x_position],
        [dot_positions],
        plot_settings,
        orientation="vertical",
        bar_width=bar_width,
        dot_color=dot_color,
        dot_size=dot_size,
        dot_alpha=dot_alpha,
        zorder=zorder,
    )
[docs]
def grouped_bars(
ax: "Axes",
x: "np.ndarray",
series: "Sequence[tuple[str, Sequence[float], Sequence[float]]]",
colors: "Sequence",
plot_settings: "PlotSettings",
*,
bar_width: float | None = None,
show_error: bool = True,
reference_line: float | None = 0.0,
reference_label: str = "Neutral (0)",
replicate_values: "Sequence[Sequence[Sequence[float]]] | None" = None,
**style_overrides,
) -> None:
"""Render grouped bars with optional error bars and reference line.
Style values (alpha, capsize, edgecolor, linewidth, dot_size, etc.)
are read from ``plot_settings.theme``. Callers can override any of
them via ``**style_overrides`` using the theme field names as keys.
Parameters
----------
ax : matplotlib Axes
Target axes.
x : np.ndarray
1-D array of group centre positions (e.g. ``np.arange(n_groups)``).
series : sequence of (label, means, errors)
One tuple per condition. *means* and *errors* must have the same
length as *x*.
colors : sequence
One colour per condition (same length as *series*).
plot_settings : PlotSettings
Global plot settings.
bar_width : float | None, optional
Width of each individual bar. When ``None`` (default) the width
is computed as ``0.8 / len(series)``.
show_error : bool, optional
If ``False``, error bars are suppressed, by default ``True``.
reference_line : float | None, optional
Y-value for a horizontal reference line. Set to ``None`` to
skip, by default ``0.0``.
reference_label : str, optional
Legend label for the reference line, by default ``"Neutral (0)"``.
replicate_values : sequence or None, optional
Per-replicate values for jittered dot overlay. Indexed as
``replicate_values[condition_idx][group_idx]`` -> sequence of
floats (one per replicate). When ``None`` (default), no dots
are drawn.
**style_overrides
Override any theme field for this call only. Accepted keys:
``bar_alpha``, ``bar_capsize``, ``bar_edgecolor``,
``bar_linewidth``, ``dot_size``, ``dot_alpha``, ``dot_color``,
``reference_line_color``, ``reference_line_style``,
``reference_line_width``.
"""
import numpy as np
t = plot_settings.theme
alpha = style_overrides.get("bar_alpha", t.bar_alpha)
capsize = style_overrides.get("bar_capsize", t.bar_capsize)
edgecolor = style_overrides.get("bar_edgecolor", t.bar_edgecolor)
linewidth = style_overrides.get("bar_linewidth", t.bar_linewidth)
dot_s = style_overrides.get("dot_size", t.dot_size)
dot_a = style_overrides.get("dot_alpha", t.dot_alpha)
dot_c = style_overrides.get("dot_color", t.dot_color)
ref_color = style_overrides.get("reference_line_color", t.reference_line_color)
ref_style = style_overrides.get("reference_line_style", t.reference_line_style)
ref_width = style_overrides.get("reference_line_width", t.reference_line_width)
n = len(series)
w = bar_width if bar_width is not None else 0.8 / max(n, 1)
if replicate_values is not None:
if len(replicate_values) != n:
raise ValueError(
f"replicate_values length must match series length ({len(replicate_values)} != {n})"
)
for idx, series_replicates in enumerate(replicate_values):
if len(series_replicates) != len(x):
raise ValueError(
"replicate_values entries must match x length "
f"for series {idx} ({len(series_replicates)} != {len(x)})"
)
for i, (label, means, errors) in enumerate(series):
offset = (i - n / 2 + 0.5) * w
bar_kwargs: dict = {
"width": w,
"label": label,
"color": colors[i],
"alpha": alpha,
"capsize": capsize,
"edgecolor": edgecolor,
"linewidth": linewidth,
}
if show_error:
errors_for_plot = errors
if replicate_values is not None:
errors_for_plot = suppress_singleton_errors(errors, replicate_values[i])
if errors_for_plot is not None:
bar_kwargs["yerr"] = errors_for_plot
bar_positions = np.asarray(x) + offset
ax.bar(bar_positions, means, **bar_kwargs)
# Overlay jittered replicate dots
if replicate_values is not None:
scatter_replicate_values(
ax,
bar_positions,
replicate_values[i],
plot_settings,
orientation="vertical",
bar_width=w,
dot_color=dot_c,
dot_size=dot_s,
dot_alpha=dot_a,
)
if reference_line is not None:
ax.axhline(
y=reference_line,
color=ref_color,
linestyle=ref_style,
linewidth=ref_width,
label=reference_label,
)
# ---------------------------------------------------------------------------
# Heatmap cell annotation
# ---------------------------------------------------------------------------
[docs]
def annotate_cells(
    ax: "Axes",
    matrix: "np.ndarray",
    plot_settings: "PlotSettings",
    *,
    fmt: str = ".2f",
    fontsize: int | None = None,
    threshold: float = 0.3,
    sem_matrix: "np.ndarray | None" = None,
    show_sign: bool = True,
    linespacing: float | None = None,
) -> None:
    """Annotate heatmap cells with formatted values.

    Iterates over every element of *matrix* and places a centred text
    label at the corresponding ``(col, row)`` position on *ax*.
    Non-finite cells (NaN or ±inf) are skipped.  Text colour is white
    when the absolute cell value exceeds *threshold* and black otherwise.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes containing the heatmap image.
    matrix : np.ndarray
        2-D array-like of values (rows x cols) matching the heatmap.
    plot_settings : PlotSettings
        Global plot settings.
    fmt : str, optional
        Format spec applied to the value (and to the SEM, when shown),
        by default ``".2f"``.
    fontsize : int | None, optional
        Annotation font size.  When ``None`` (default), uses
        ``plot_settings.theme.annotation_fontsize``.
    threshold : float, optional
        Absolute-value threshold above which text turns white,
        by default ``0.3``.
    sem_matrix : np.ndarray | None, optional
        If provided, a second line ``±{sem}`` is appended to a cell's
        label when the corresponding SEM value is finite.  NaN and
        infinite SEM entries leave the label as a single line.
    show_sign : bool, optional
        Prefix strictly positive values with ``"+"``, by default ``True``.
    linespacing : float | None, optional
        Passed to ``ax.text(linespacing=...)`` when not ``None``.

    Raises
    ------
    ValueError
        If *matrix* is not 2-D (its shape cannot be unpacked into rows
        and columns).
    """
    import numpy as np

    # asanyarray accepts nested sequences while leaving ndarray subclasses
    # untouched, so existing array callers see no change.
    matrix = np.asanyarray(matrix)
    if sem_matrix is not None:
        sem_matrix = np.asanyarray(sem_matrix)

    fs = fontsize if fontsize is not None else plot_settings.theme.annotation_fontsize

    # Keyword arguments that are identical for every cell are built once.
    text_kwargs: dict = {"ha": "center", "va": "center", "fontsize": fs}
    if linespacing is not None:
        text_kwargs["linespacing"] = linespacing

    n_rows, n_cols = matrix.shape
    for i in range(n_rows):
        for j in range(n_cols):
            val = matrix[i, j]
            if not np.isfinite(val):
                continue

            sign = "+" if show_sign and val > 0 else ""
            label_str = f"{sign}{val:{fmt}}"

            if sem_matrix is not None:
                sem = sem_matrix[i, j]
                # Only a finite SEM is meaningful; an infinite one would
                # otherwise be rendered as "±inf".
                if np.isfinite(sem):
                    label_str = f"{label_str}\n\u00b1{sem:{fmt}}"

            text_color = "white" if abs(val) > threshold else "black"
            # The x coordinate is the column index, y is the row index.
            ax.text(j, i, label_str, color=text_color, **text_kwargs)
# ---------------------------------------------------------------------------
# Symmetric colour-limit helper
# ---------------------------------------------------------------------------
[docs]
def symmetric_clim(
    values: "Sequence[float] | np.ndarray",
    pad: float = 0.1,
) -> tuple[float, float]:
    """Return zero-centred colour limits that cover every finite value.

    Parameters
    ----------
    values : sequence of float or ndarray
        Data values to derive limits from.  Non-finite entries (NaN,
        ±inf) are discarded before the limits are computed.
    pad : float, optional
        Margin added beyond the largest absolute value on both sides,
        by default ``0.1``.

    Returns
    -------
    tuple[float, float]
        ``(vmin, vmax)`` with ``vmin == -vmax``.  When no finite value
        is available the result is ``(-pad, pad)``.
    """
    import numpy as np

    data = np.asarray(values, dtype=float)
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        # Nothing usable to scale against: fall back to the padding alone.
        return (-pad, pad)

    lowest, highest = finite.min(), finite.max()
    half_range = float(max(abs(lowest), abs(highest)))
    limit = half_range + pad
    return (-limit, limit)