# Source code for polyzymd.analyses.shared.binding_preference._aggregate

"""Binding preference aggregation helpers."""

from __future__ import annotations

import logging

import numpy as np

from ._models import (
    AggregatedBindingPreferenceEntry,
    AggregatedBindingPreferenceResult,
    AggregatedPartitionBindingEntry,
    AggregatedPartitionBindingResult,
    AggregatedPartitionCoverageEntry,
    AggregatedPartitionCoverageResult,
    AggregatedPolymerBindingPreferenceResult,
    AggregatedSystemCoverageResult,
    BindingPreferenceResult,
    PartitionBindingResult,
    PartitionCoverageResult,
    PolymerBindingPreferenceResult,
    SystemCoverageResult,
)

logger = logging.getLogger(__name__)


def aggregate_binding_preference(
    results: list[BindingPreferenceResult],
) -> "AggregatedBindingPreferenceResult":
    """Aggregate binding preference across replicates.

    For every (polymer type, protein group) pair seen in any replicate,
    computes the mean and SEM of the enrichment and of the mean contact
    fraction, plus the mean contact share. System coverage and the
    partition-based binding preference are aggregated as well, but only
    when every replicate provides them.

    Parameters
    ----------
    results : list[BindingPreferenceResult]
        Binding preference results from multiple replicates.

    Returns
    -------
    AggregatedBindingPreferenceResult
        Aggregated entries plus run-level metadata copied from the first
        replicate.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("No results to aggregate")

    def _compute_stats(values: list[float]) -> tuple[float | None, float | None]:
        """Return ``(mean, SEM)`` of ``values``, or ``(None, None)`` if empty."""
        n = len(values)
        if n == 0:
            return None, None
        mean_val = float(np.mean(values))
        # With a single value there is no spread to estimate; report SEM as 0.
        sem_val = float(np.std(values, ddof=1) / np.sqrt(n)) if n > 1 else 0.0
        return mean_val, sem_val

    # Every (polymer_type, protein_group) pair that occurs in any replicate.
    all_pairs: set[tuple[str, str]] = {
        (e.polymer_type, e.protein_group) for r in results for e in r.entries
    }

    entries = []
    for poly_type, prot_group in sorted(all_pairs):
        # Look each replicate up once and keep only those that have the pair.
        replicate_entries = []
        for r in results:
            entry = r.get_entry(poly_type, prot_group)
            if entry is not None:
                replicate_entries.append(entry)

        # Enrichment may be undefined for a replicate; such replicates still
        # contribute to the contact metrics.
        enrichments = [
            e.enrichment for e in replicate_entries if e.enrichment is not None
        ]
        contact_fractions = [e.mean_contact_fraction for e in replicate_entries]
        contact_shares = [e.contact_share for e in replicate_entries]

        mean_enrichment, sem_enrichment = _compute_stats(enrichments)
        mean_contact_fraction, sem_contact_fraction = _compute_stats(contact_fractions)
        mean_contact_share = float(np.mean(contact_shares)) if contact_shares else 0.0

        # Group metadata comes from the first replicate that actually has the
        # pair, so a pair missing from replicate 0 does not get zeroed metadata.
        meta = replicate_entries[0] if replicate_entries else None
        n_exposed = meta.n_exposed_in_group if meta is not None else 0
        n_total = meta.n_residues_in_group if meta is not None else 0
        expected_share = meta.expected_share if meta is not None else 0.0

        entries.append(
            AggregatedBindingPreferenceEntry(
                polymer_type=poly_type,
                protein_group=prot_group,
                # Enrichment
                mean_enrichment=mean_enrichment,
                sem_enrichment=sem_enrichment,
                per_replicate_enrichments=enrichments,
                # Contact metrics
                mean_contact_fraction=(
                    mean_contact_fraction if mean_contact_fraction is not None else 0.0
                ),
                sem_contact_fraction=(
                    sem_contact_fraction if sem_contact_fraction is not None else 0.0
                ),
                mean_contact_share=mean_contact_share,
                # Expected share
                expected_share=expected_share,
                # Group metadata
                n_exposed_in_group=n_exposed,
                n_residues_in_group=n_total,
                # Counts only replicates with a defined enrichment.
                n_replicates=len(enrichments),
            )
        )

    # Aggregate system coverage only if present in all results.
    aggregated_system_coverage = None
    system_coverages = [r.system_coverage for r in results if r.system_coverage is not None]
    if len(system_coverages) == len(results):
        aggregated_system_coverage = aggregate_system_coverage(system_coverages)
        logger.debug(
            "Aggregated system coverage: %d AA classes, %d custom groups from %d replicates",
            len(aggregated_system_coverage.aa_class_coverage.entries),
            len(aggregated_system_coverage.custom_group_coverages),
            len(system_coverages),
        )

    # Aggregate partition-based binding preference only if present in all results.
    aggregated_binding_preference = None
    binding_preferences = [
        r.binding_preference for r in results if r.binding_preference is not None
    ]
    if len(binding_preferences) == len(results):
        aggregated_binding_preference = aggregate_polymer_binding_preference(binding_preferences)
        logger.debug(
            "Aggregated binding preference: %d polymer types, %d user partitions "
            "from %d replicates",
            len(aggregated_binding_preference.aa_class_binding),
            len(aggregated_binding_preference.user_defined_partitions),
            len(binding_preferences),
        )

    # Run-level metadata is copied from the first replicate.
    first = results[0]
    return AggregatedBindingPreferenceResult(
        entries=entries,  # DEPRECATED: kept for backward compat
        n_replicates=len(results),
        total_exposed_residues=first.total_exposed_residues,
        surface_exposure_threshold=first.surface_exposure_threshold,
        protein_groups_used=first.protein_groups_used,
        polymer_types_used=first.polymer_types_used,
        polymer_composition=first.polymer_composition,
        system_coverage=aggregated_system_coverage,
        binding_preference=aggregated_binding_preference,  # NEW: partition-based per-polymer
    )
def _aggregate_partition_binding(
    partitions: list[PartitionBindingResult],
) -> AggregatedPartitionBindingResult:
    """Aggregate partition binding results across replicates.

    For every element name reported by any replicate, computes the mean and
    SEM of the contact share and of the enrichment. Partition name, partition
    type and polymer type are taken from the first replicate.

    Parameters
    ----------
    partitions : list[PartitionBindingResult]
        Partition binding results from multiple replicates (same polymer type).

    Returns
    -------
    AggregatedPartitionBindingResult
        Aggregated result with mean and SEM per partition element.

    Raises
    ------
    ValueError
        If ``partitions`` is empty.
    """
    if not partitions:
        raise ValueError("No partitions to aggregate")

    first = partitions[0]
    polymer_type = first.polymer_type

    def _compute_stats(values: list[float]) -> tuple[float | None, float | None]:
        """Return ``(mean, SEM)`` of ``values``, or ``(None, None)`` if empty."""
        n = len(values)
        if n == 0:
            return None, None
        mean_val = float(np.mean(values))
        # With a single value there is no spread to estimate; report SEM as 0.
        sem_val = float(np.std(values, ddof=1) / np.sqrt(n)) if n > 1 else 0.0
        return mean_val, sem_val

    # Union of element names across replicates.
    all_elements: set[str] = set()
    for p in partitions:
        all_elements.update(p.element_names())

    entries = []
    for element_name in sorted(all_elements):
        # Look each replicate up once and keep only those that have the element.
        replicate_entries = []
        for p in partitions:
            entry = p.get_entry(element_name)
            if entry is not None:
                replicate_entries.append(entry)

        contact_shares = [e.contact_share for e in replicate_entries]
        # Enrichment may be undefined for a replicate; skip those values.
        enrichments = [
            e.enrichment for e in replicate_entries if e.enrichment is not None
        ]

        mean_contact_share, sem_contact_share = _compute_stats(contact_shares)
        mean_enrichment, sem_enrichment = _compute_stats(enrichments)

        # Element metadata comes from the first replicate that actually has
        # the element, so an element missing from replicate 0 is not zeroed.
        meta = replicate_entries[0] if replicate_entries else None
        n_exposed = meta.n_exposed_in_element if meta is not None else 0
        n_total = meta.n_residues_in_element if meta is not None else 0
        expected_share = meta.expected_share if meta is not None else 0.0

        entries.append(
            AggregatedPartitionBindingEntry(
                partition_element=element_name,
                polymer_type=polymer_type,
                mean_contact_share=(
                    mean_contact_share if mean_contact_share is not None else 0.0
                ),
                sem_contact_share=(
                    sem_contact_share if sem_contact_share is not None else 0.0
                ),
                mean_enrichment=mean_enrichment,
                sem_enrichment=sem_enrichment,
                per_replicate_enrichments=enrichments,
                expected_share=expected_share,
                n_exposed_in_element=n_exposed,
                n_residues_in_element=n_total,
                # Counts only replicates with a defined enrichment.
                n_replicates=len(enrichments),
            )
        )

    # Mean total contact share (validation); `partitions` is non-empty here.
    mean_total_share = float(np.mean([p.total_contact_share for p in partitions]))

    return AggregatedPartitionBindingResult(
        partition_name=first.partition_name,
        partition_type=first.partition_type,
        polymer_type=polymer_type,
        entries=entries,
        mean_total_contact_share=mean_total_share,
        n_replicates=len(partitions),
    )
def aggregate_polymer_binding_preference(
    results: list[PolymerBindingPreferenceResult],
) -> AggregatedPolymerBindingPreferenceResult:
    """Aggregate per-polymer binding preference across replicates.

    The AA-class partition and each user-defined partition are aggregated
    separately for every polymer type, using whichever replicates contain
    that partition/polymer combination.

    Parameters
    ----------
    results : list[PolymerBindingPreferenceResult]
        Per-polymer binding preference results from multiple replicates.

    Returns
    -------
    AggregatedPolymerBindingPreferenceResult
        Aggregated results for all partitions; the exposed-residue count and
        exposure threshold are copied from the first replicate.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("No results to aggregate")

    # Union of polymer types across replicates, in a stable order.
    all_polymer_types: set[str] = set()
    for r in results:
        all_polymer_types.update(r.polymer_types)
    polymer_types = sorted(all_polymer_types)

    # Aggregate AA class binding for each polymer type.
    aggregated_aa_class: dict[str, AggregatedPartitionBindingResult] = {}
    for poly_type in polymer_types:
        poly_partitions = [
            r.aa_class_binding[poly_type] for r in results if poly_type in r.aa_class_binding
        ]
        if poly_partitions:
            aggregated_aa_class[poly_type] = _aggregate_partition_binding(poly_partitions)

    # Aggregate user-defined partitions.
    all_partition_names: set[str] = set()
    for r in results:
        all_partition_names.update(r.user_defined_partitions.keys())

    aggregated_user_partitions: dict[str, dict[str, AggregatedPartitionBindingResult]] = {}
    for partition_name in sorted(all_partition_names):
        # Every partition name gets a key, even if no polymer type ends up in it.
        per_polymer: dict[str, AggregatedPartitionBindingResult] = {}
        aggregated_user_partitions[partition_name] = per_polymer
        for poly_type in polymer_types:
            poly_partitions = [
                r.user_defined_partitions[partition_name][poly_type]
                for r in results
                if partition_name in r.user_defined_partitions
                and poly_type in r.user_defined_partitions[partition_name]
            ]
            if poly_partitions:
                per_polymer[poly_type] = _aggregate_partition_binding(poly_partitions)

    first = results[0]
    return AggregatedPolymerBindingPreferenceResult(
        aa_class_binding=aggregated_aa_class,
        user_defined_partitions=aggregated_user_partitions,
        n_replicates=len(results),
        total_exposed_residues=first.total_exposed_residues,
        surface_exposure_threshold=first.surface_exposure_threshold,
        polymer_types=polymer_types,
    )
def _aggregate_partition_coverage(
    partitions: list[PartitionCoverageResult],
) -> AggregatedPartitionCoverageResult:
    """Aggregate a partition's coverage across replicates.

    For every partition element seen in any replicate, computes the mean and
    SEM of the coverage share and of the coverage enrichment, plus the mean
    contribution of each polymer type. Partition name and type are taken from
    the first replicate.

    Parameters
    ----------
    partitions : list[PartitionCoverageResult]
        Same partition from multiple replicates.

    Returns
    -------
    AggregatedPartitionCoverageResult
        Aggregated partition coverage.

    Raises
    ------
    ValueError
        If ``partitions`` is empty.
    """
    if not partitions:
        raise ValueError("No partitions to aggregate")

    def _compute_stats(values: list[float]) -> tuple[float | None, float | None]:
        """Return ``(mean, SEM)`` of ``values``, or ``(None, None)`` if empty."""
        n = len(values)
        if n == 0:
            return None, None
        mean_val = float(np.mean(values))
        # With a single value there is no spread to estimate; report SEM as 0.
        sem_val = float(np.std(values, ddof=1) / np.sqrt(n)) if n > 1 else 0.0
        return mean_val, sem_val

    first_partition = partitions[0]

    # Union of element names and of contributing polymer types across replicates.
    all_elements: set[str] = set()
    all_polymer_types: set[str] = set()
    for p in partitions:
        for e in p.entries:
            all_elements.add(e.partition_element)
            all_polymer_types.update(e.polymer_contributions.keys())
    # Sorted so the per-polymer dict has a reproducible key order between runs.
    polymer_types = sorted(all_polymer_types)

    entries = []
    for element_name in sorted(all_elements):
        # Look each replicate up once and keep only those that have the element.
        replicate_entries = []
        for p in partitions:
            entry = p.get_entry(element_name)
            if entry is not None:
                replicate_entries.append(entry)

        coverage_shares = [e.coverage_share for e in replicate_entries]
        # Enrichment may be undefined for a replicate; skip those values.
        enrichments = [
            e.coverage_enrichment
            for e in replicate_entries
            if e.coverage_enrichment is not None
        ]

        mean_coverage_share, sem_coverage_share = _compute_stats(coverage_shares)
        mean_enrichment, sem_enrichment = _compute_stats(enrichments)

        # A polymer type absent from a replicate's entry counts as 0.0 for that
        # replicate; with no replicate entries at all the mean is 0.0.
        mean_polymer_contributions = {}
        for poly_type in polymer_types:
            contribs = [e.polymer_contributions.get(poly_type, 0.0) for e in replicate_entries]
            mean_polymer_contributions[poly_type] = (
                float(np.mean(contribs)) if contribs else 0.0
            )

        # Element metadata comes from the first replicate that actually has
        # the element, so an element missing from replicate 0 is not zeroed.
        meta = replicate_entries[0] if replicate_entries else None
        n_exposed = meta.n_exposed_in_element if meta is not None else 0
        n_total = meta.n_residues_in_element if meta is not None else 0
        expected_share = meta.expected_share if meta is not None else 0.0

        entries.append(
            AggregatedPartitionCoverageEntry(
                partition_element=element_name,
                mean_coverage_share=(
                    mean_coverage_share if mean_coverage_share is not None else 0.0
                ),
                sem_coverage_share=(
                    sem_coverage_share if sem_coverage_share is not None else 0.0
                ),
                mean_coverage_enrichment=mean_enrichment,
                sem_coverage_enrichment=sem_enrichment,
                per_replicate_enrichments=enrichments,
                expected_share=expected_share,
                n_exposed_in_element=n_exposed,
                n_residues_in_element=n_total,
                # Counts only replicates with a defined enrichment.
                n_replicates=len(enrichments),
                mean_polymer_contributions=mean_polymer_contributions,
            )
        )

    return AggregatedPartitionCoverageResult(
        partition_name=first_partition.partition_name,
        partition_type=first_partition.partition_type,
        entries=entries,
        n_replicates=len(partitions),
    )
def aggregate_system_coverage(
    results: list[SystemCoverageResult],
) -> AggregatedSystemCoverageResult:
    """Aggregate system coverage across replicates.

    Aggregates the AA-class partition, every custom-group partition, the
    combined custom partition (only when all replicates have it) and every
    user-defined partition.

    Parameters
    ----------
    results : list[SystemCoverageResult]
        System coverage results from multiple replicates.

    Returns
    -------
    AggregatedSystemCoverageResult
        Aggregated coverage for all partitions; the exposed-residue count,
        exposure threshold and custom group selections are copied from the
        first replicate.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("No results to aggregate")

    # AA class partition: read from every replicate.
    aggregated_aa_class = _aggregate_partition_coverage(
        [r.aa_class_coverage for r in results]
    )

    # Custom group partitions: aggregate each over the replicates that have it.
    all_custom_groups: set[str] = set()
    for r in results:
        all_custom_groups.update(r.custom_group_coverages.keys())

    aggregated_custom_groups: dict[str, AggregatedPartitionCoverageResult] = {}
    for group_name in sorted(all_custom_groups):
        group_partitions = [
            r.custom_group_coverages[group_name]
            for r in results
            if group_name in r.custom_group_coverages
        ]
        if group_partitions:
            aggregated_custom_groups[group_name] = _aggregate_partition_coverage(group_partitions)

    # Combined custom partition: only when every replicate provides it.
    aggregated_combined: AggregatedPartitionCoverageResult | None = None
    combined_partitions = [
        r.combined_custom_coverage for r in results if r.combined_custom_coverage is not None
    ]
    if len(combined_partitions) == len(results):
        aggregated_combined = _aggregate_partition_coverage(combined_partitions)

    # User-defined partitions: aggregate each over the replicates that have it.
    all_user_partitions: set[str] = set()
    for r in results:
        all_user_partitions.update(r.user_defined_partitions.keys())

    aggregated_user_partitions: dict[str, AggregatedPartitionCoverageResult] = {}
    for partition_name in sorted(all_user_partitions):
        partition_results = [
            r.user_defined_partitions[partition_name]
            for r in results
            if partition_name in r.user_defined_partitions
        ]
        if partition_results:
            aggregated_user_partitions[partition_name] = _aggregate_partition_coverage(
                partition_results
            )

    # Union of polymer types included by any replicate.
    all_polymer_types: set[str] = set()
    for r in results:
        all_polymer_types.update(r.polymer_types_included)

    # Flag overlaps if any single replicate reports them.
    has_overlaps = any(r.has_overlapping_custom_groups for r in results)

    first = results[0]
    return AggregatedSystemCoverageResult(
        aa_class_coverage=aggregated_aa_class,
        custom_group_coverages=aggregated_custom_groups,
        combined_custom_coverage=aggregated_combined,
        user_defined_partitions=aggregated_user_partitions,
        n_replicates=len(results),
        total_exposed_residues=first.total_exposed_residues,
        surface_exposure_threshold=first.surface_exposure_threshold,
        custom_group_selections=first.custom_group_selections,
        polymer_types_included=sorted(all_polymer_types),
        has_overlapping_custom_groups=has_overlaps,
    )