Source code for polyzymd.analyses.shared.binding_preference._compute

"""Binding preference compute kernels and partition logic."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Literal

from polyzymd.analyses.shared.aa_classification import DEFAULT_AA_CLASS_SELECTIONS

from ._models import (
    BindingPreferenceEntry,
    BindingPreferenceResult,
    PartitionBindingEntry,
    PartitionBindingResult,
    PartitionCoverageEntry,
    PartitionCoverageResult,
    PolymerBindingPreferenceResult,
    PolymerComposition,
    SystemCoverageResult,
)

if TYPE_CHECKING:
    from polyzymd.analyses.shared.surface_exposure import SurfaceExposureResult

logger = logging.getLogger(__name__)


def _detect_overlapping_groups(groups: dict[str, set[int]]) -> list[tuple[str, str]]:
    """Detect pairs of groups that share residue IDs.

    Parameters
    ----------
    groups : dict[str, set[int]]
        Mapping of group name to residue IDs

    Returns
    -------
    list[tuple[str, str]]
        List of (group_a, group_b) pairs that have overlapping residue IDs
    """
    group_names = list(groups.keys())
    overlaps = []

    for i, g1 in enumerate(group_names):
        for g2 in group_names[i + 1 :]:
            if groups[g1] & groups[g2]:  # Non-empty intersection
                overlaps.append((g1, g2))

    return overlaps


def _compute_enrichment(coverage_share: float, expected_share: float) -> float | None:
    """Calculate zero-centered enrichment: (observed / expected) - 1.

    Parameters
    ----------
    coverage_share : float
        Fraction of contacts to this element
    expected_share : float
        Expected fraction based on surface availability

    Returns
    -------
    float | None
        Enrichment value, or None if expected_share is 0
    """
    if expected_share > 0 and coverage_share > 0:
        return (coverage_share / expected_share) - 1
    elif expected_share > 0 and coverage_share == 0:
        return -1.0  # Complete avoidance
    else:
        return None  # Cannot compute (no expected share)


def _compute_partition_binding(
    partition_name: str,
    partition_type: Literal["aa_class", "user_defined"],
    partition_groups: dict[str, set[int]],
    exposed_partition: dict[str, set[int]],
    polymer_type: str,
    contact_data_for_polymer: dict[str, dict[str, Any]],
    total_exposed: int,
    total_contact_frames_for_polymer: int,
) -> "PartitionBindingResult":
    """Score how one polymer type spreads its contacts over a partition.

    For each element of the partition, the polymer's observed share of
    contact frames is compared with the share expected from the element's
    exposed-residue count; the two are combined into an enrichment by
    ``_compute_enrichment``. Elements are processed in sorted name order.

    Parameters
    ----------
    partition_name : str
        Name of the partition (e.g., "aa_class", "lid_helices")
    partition_type : str
        One of: "aa_class", "user_defined"
    partition_groups : dict[str, set[int]]
        ALL residue IDs of each element, keyed by element name
    exposed_partition : dict[str, set[int]]
        EXPOSED residue IDs of each element; an element missing here is
        treated as having no exposed residues
    polymer_type : str
        Polymer type the result describes (e.g., "SBM")
    contact_data_for_polymer : dict[str, dict[str, Any]]
        ``{element_name: {"total_frames": int, "residues_contacted": set}}``;
        a missing element or key counts as zero frames / no residues
    total_exposed : int
        Denominator of the expected share (exposed residues in the partition)
    total_contact_frames_for_polymer : int
        Denominator of the contact share (this polymer's contact frames)

    Returns
    -------
    PartitionBindingResult
        Per-element entries plus the summed contact and expected shares
    """
    # Fallback record for elements the polymer never touched. It is only read,
    # never mutated, so a single shared instance is safe.
    no_contacts: dict[str, Any] = {"total_frames": 0, "residues_contacted": set()}

    # A zero denominator means the corresponding share is defined as 0.0
    has_surface = total_exposed > 0
    has_contacts = total_contact_frames_for_polymer > 0

    rows = []
    contact_share_sum = 0.0
    expected_share_sum = 0.0

    for element in sorted(partition_groups):
        exposed_count = len(exposed_partition.get(element, set()))
        residue_count = len(partition_groups[element])

        record = contact_data_for_polymer.get(element, no_contacts)
        frames = record.get("total_frames", 0)
        touched = record.get("residues_contacted", set())

        # Share expected if contacts simply followed surface availability
        expected = exposed_count / total_exposed if has_surface else 0.0
        # Share of this polymer's contact frames that landed on the element
        observed = frames / total_contact_frames_for_polymer if has_contacts else 0.0

        expected_share_sum += expected
        contact_share_sum += observed

        rows.append(
            PartitionBindingEntry(
                partition_element=element,
                polymer_type=polymer_type,
                total_contact_frames=frames,
                contact_share=observed,
                expected_share=expected,
                enrichment=_compute_enrichment(observed, expected),
                n_exposed_in_element=exposed_count,
                n_residues_in_element=residue_count,
                n_residues_contacted=len(touched),
            )
        )

    return PartitionBindingResult(
        partition_name=partition_name,
        partition_type=partition_type,
        polymer_type=polymer_type,
        entries=rows,
        total_contact_share=contact_share_sum,
        total_expected_share=expected_share_sum,
        total_contact_frames=total_contact_frames_for_polymer,
    )


def _compute_polymer_binding_preference(
    contact_data: dict[str, dict[str, dict[str, Any]]],
    total_contacts_by_polymer: dict[str, int],
    protein_groups: dict[str, set[int]],
    exposed_groups: dict[str, set[int]],
    protein_partitions: dict[str, list[str]] | None,
    total_exposed: int,
    n_frames: int,
    surface_exposure_threshold: float | None,
    polymer_composition: "PolymerComposition | None",
    protein_groups_used: dict[str, str] | None,
) -> "PolymerBindingPreferenceResult":
    """Compute per-polymer binding preference using partition-based analysis.

    Two families of partitions are evaluated for every polymer type that is a
    key of ``contact_data``:

    1. The amino-acid-class partition, made of the entries of
       ``protein_groups`` whose names are keys of
       ``DEFAULT_AA_CLASS_SELECTIONS``.
    2. Each partition listed in ``protein_partitions``. Group names that are
       not in ``protein_groups`` are skipped with a warning, and a partition
       with no valid group is skipped entirely. When the partition's exposed
       residues do not cover every exposed residue, a synthetic
       ``rest_of_protein`` element is added; its contacts are the sum over all
       of the polymer's contact groups not named in the partition.

    Within each partition the contact-share denominator is the sum of contact
    frames over that partition's elements.

    Parameters
    ----------
    contact_data : dict[str, dict[str, dict[str, Any]]]
        Contact data: {polymer_type: {group_name: {"total_frames": int, "residues_contacted": set}}}
    total_contacts_by_polymer : dict[str, int]
        Total contact frames per polymer type. Not consulted by this function
        (denominators are recomputed per partition from ``contact_data``);
        retained so the call signature stays stable.
    protein_groups : dict[str, set[int]]
        All protein groups (AA classes + custom)
    exposed_groups : dict[str, set[int]]
        Exposed residues per group
    protein_partitions : dict[str, list[str]] | None
        User-defined partitions: {partition_name: [group_name, ...]}
    total_exposed : int
        Total exposed residues (stored on the result as metadata)
    n_frames : int
        Number of frames analyzed (stored on the result as metadata)
    surface_exposure_threshold : float | None
        SASA threshold used (stored on the result as metadata)
    polymer_composition : PolymerComposition | None
        Polymer composition metadata
    protein_groups_used : dict[str, str] | None
        Selection strings used (for metadata); ``None`` is stored as ``{}``

    Returns
    -------
    PolymerBindingPreferenceResult
        Per-polymer partition-based binding preference
    """
    all_polymer_types = sorted(contact_data.keys())

    # Keep only the groups that are amino-acid classes, in protein_groups order
    aa_class_names = set(DEFAULT_AA_CLASS_SELECTIONS.keys())
    aa_class_groups: dict[str, set[int]] = {
        group_name: resids
        for group_name, resids in protein_groups.items()
        if group_name in aa_class_names
    }
    aa_class_exposed: dict[str, set[int]] = {
        group_name: exposed_groups.get(group_name, set()) for group_name in aa_class_groups
    }
    aa_class_total_exposed = sum(len(resids) for resids in aa_class_exposed.values())

    # Union of every exposed residue ID (reference set for "rest_of_protein")
    all_exposed_resids: set[int] = set()
    for resids in exposed_groups.values():
        all_exposed_resids.update(resids)

    # ---------------------------------------------------------------------
    # 1. Compute AA Class Partition for each polymer type
    # ---------------------------------------------------------------------
    aa_class_binding: dict[str, PartitionBindingResult] = {}

    for poly_type in all_polymer_types:
        poly_contact_data = contact_data.get(poly_type, {})

        # Restrict the polymer's contacts to AA class groups; classes the
        # polymer never touched get an explicit empty record.
        aa_contact_data: dict[str, dict[str, Any]] = {
            group_name: poly_contact_data.get(
                group_name, {"total_frames": 0, "residues_contacted": set()}
            )
            for group_name in aa_class_groups
        }

        # Denominator counts AA class contacts only, not custom-group contacts
        aa_total_frames = sum(d.get("total_frames", 0) for d in aa_contact_data.values())

        aa_class_binding[poly_type] = _compute_partition_binding(
            partition_name="aa_class",
            partition_type="aa_class",
            partition_groups=aa_class_groups,
            exposed_partition=aa_class_exposed,
            polymer_type=poly_type,
            contact_data_for_polymer=aa_contact_data,
            total_exposed=aa_class_total_exposed,
            total_contact_frames_for_polymer=aa_total_frames,
        )

    # ---------------------------------------------------------------------
    # 2. Compute User-Defined Partitions for each polymer type
    # ---------------------------------------------------------------------
    user_defined_partitions: dict[str, dict[str, PartitionBindingResult]] = {}

    for partition_name, group_names in (protein_partitions or {}).items():
        # Build partition groups from referenced group names
        partition_groups_map: dict[str, set[int]] = {}
        partition_exposed_map: dict[str, set[int]] = {}
        all_partition_exposed: set[int] = set()

        for group_name in group_names:
            if group_name not in protein_groups:
                logger.warning(
                    f"Partition '{partition_name}' references undefined group "
                    f"'{group_name}' - skipping this group"
                )
                continue

            partition_groups_map[group_name] = protein_groups[group_name]
            partition_exposed_map[group_name] = exposed_groups.get(group_name, set())
            all_partition_exposed.update(partition_exposed_map[group_name])

        if not partition_groups_map:
            logger.warning(f"Partition '{partition_name}' has no valid groups - skipping")
            continue

        # Add rest_of_protein when the partition leaves exposed residues uncovered
        rest_exposed = all_exposed_resids - all_partition_exposed
        if rest_exposed:
            # NOTE(review): the "all residues" set of rest_of_protein is derived
            # from exposed residues only, so it always equals rest_exposed and
            # buried residues outside the partition are not counted. Confirm
            # whether protein_groups should be the source instead.
            partition_groups_map["rest_of_protein"] = set(rest_exposed)
            partition_exposed_map["rest_of_protein"] = rest_exposed
            logger.debug(
                f"Partition '{partition_name}': auto-added 'rest_of_protein' "
                f"with {len(rest_exposed)} exposed residues"
            )

        # Compute total exposed for this partition
        user_partition_total_exposed = sum(len(r) for r in partition_exposed_map.values())

        # Every name the user listed (valid or not) is excluded from the
        # rest_of_protein aggregate; a frozenset gives O(1) membership tests.
        listed_groups = frozenset(group_names)

        per_polymer: dict[str, PartitionBindingResult] = {}
        user_defined_partitions[partition_name] = per_polymer

        for poly_type in all_polymer_types:
            poly_contact_data = contact_data.get(poly_type, {})

            # Build contact data for this partition
            partition_contact_data: dict[str, dict[str, Any]] = {}
            for group_name in partition_groups_map:
                if group_name == "rest_of_protein":
                    # Aggregate contacts from groups NOT in this partition
                    rest_frames = 0
                    rest_residues: set[int] = set()
                    for gname, gdata in poly_contact_data.items():
                        if gname not in listed_groups:
                            rest_frames += gdata.get("total_frames", 0)
                            rest_residues.update(gdata.get("residues_contacted", set()))
                    partition_contact_data["rest_of_protein"] = {
                        "total_frames": rest_frames,
                        "residues_contacted": rest_residues,
                    }
                else:
                    partition_contact_data[group_name] = poly_contact_data.get(
                        group_name, {"total_frames": 0, "residues_contacted": set()}
                    )

            # Compute total frames for this partition
            partition_total_frames = sum(
                d.get("total_frames", 0) for d in partition_contact_data.values()
            )

            per_polymer[poly_type] = _compute_partition_binding(
                partition_name=partition_name,
                partition_type="user_defined",
                partition_groups=partition_groups_map,
                exposed_partition=partition_exposed_map,
                polymer_type=poly_type,
                contact_data_for_polymer=partition_contact_data,
                total_exposed=user_partition_total_exposed,
                total_contact_frames_for_polymer=partition_total_frames,
            )

        logger.info(
            f"Computed user-defined partition '{partition_name}' binding for "
            f"{len(all_polymer_types)} polymer types"
        )

    return PolymerBindingPreferenceResult(
        aa_class_binding=aa_class_binding,
        user_defined_partitions=user_defined_partitions,
        n_frames=n_frames,
        total_exposed_residues=total_exposed,
        surface_exposure_threshold=surface_exposure_threshold,
        polymer_types=all_polymer_types,
        polymer_composition=polymer_composition,
        protein_groups_used=protein_groups_used or {},
    )


def _compute_partition_coverage(
    partition_name: str,
    partition_type: Literal["aa_class", "binary_custom", "combined_custom", "user_defined"],
    partition_groups: dict[str, set[int]],
    exposed_partition: dict[str, set[int]],
    entries: list["BindingPreferenceEntry"],
    total_exposed: int,
    all_polymer_types: list[str],
) -> PartitionCoverageResult:
    """Compute coverage for a single partition.

    Contact frames from ``entries`` are summed per partition element, pooled
    over polymer types, and each element's share of the pooled total is
    compared with the share expected from its exposed-residue count.

    Entries are matched to elements by ``entry.protein_group``; an entry whose
    group is not an element of the partition is ignored. The elements are
    assumed to be mutually exclusive, but this function does not check it.

    Parameters
    ----------
    partition_name : str
        Name of the partition (e.g., "aa_class", "lid_helix_5_vs_rest")
    partition_type : str
        One of: "aa_class", "binary_custom", "combined_custom", "user_defined"
    partition_groups : dict[str, set[int]]
        Mapping of element name to ALL residue IDs in that element
    exposed_partition : dict[str, set[int]]
        Mapping of element name to EXPOSED residue IDs only; an element
        missing here is treated as having no exposed residues
    entries : list[BindingPreferenceEntry]
        Binding preference entries to aggregate
    total_exposed : int
        Denominator of the expected share (exposed residues in the partition)
    all_polymer_types : list[str]
        Polymer types that every element reports a contribution for

    Returns
    -------
    PartitionCoverageResult
        Per-element coverage entries (sorted by element name) plus the summed
        coverage and expected shares
    """
    # Running contact-frame totals per element, overall and split by polymer
    frames_by_element: dict[str, int] = dict.fromkeys(partition_groups, 0)
    frames_by_polymer: dict[str, dict[str, int]] = {
        element_name: dict.fromkeys(all_polymer_types, 0) for element_name in partition_groups
    }

    for entry in entries:
        element_name = entry.protein_group
        if element_name not in frames_by_element:
            continue  # Entry belongs to a group outside this partition

        frames_by_element[element_name] += entry.total_contact_frames
        by_polymer = frames_by_polymer[element_name]
        # .get() tolerates a polymer type that is absent from all_polymer_types
        by_polymer[entry.polymer_type] = (
            by_polymer.get(entry.polymer_type, 0) + entry.total_contact_frames
        )

    # Calculate grand total of contacts across all elements
    grand_total = sum(frames_by_element.values())

    # Build partition entries
    coverage_entries = []
    total_coverage_share = 0.0
    total_expected_share = 0.0

    for element_name in sorted(partition_groups):
        n_total = len(partition_groups[element_name])
        n_exposed = len(exposed_partition.get(element_name, set()))
        total_frames = frames_by_element[element_name]

        # Share expected if contacts simply followed surface availability
        expected_share = n_exposed / total_exposed if total_exposed > 0 else 0.0
        total_expected_share += expected_share

        # Share of all pooled contact frames that landed on this element
        coverage_share = total_frames / grand_total if grand_total > 0 else 0.0
        total_coverage_share += coverage_share

        enrichment = _compute_enrichment(coverage_share, expected_share)

        # Fraction of this element's contact frames owed to each polymer type
        polymer_contributions: dict[str, float]
        if total_frames > 0:
            polymer_contributions = {
                pt: pf / total_frames for pt, pf in frames_by_polymer[element_name].items()
            }
        else:
            polymer_contributions = dict.fromkeys(all_polymer_types, 0.0)

        coverage_entries.append(
            PartitionCoverageEntry(
                partition_element=element_name,
                total_contact_frames=total_frames,
                coverage_share=coverage_share,
                expected_share=expected_share,
                coverage_enrichment=enrichment,
                n_exposed_in_element=n_exposed,
                n_residues_in_element=n_total,
                polymer_contributions=polymer_contributions,
            )
        )

    return PartitionCoverageResult(
        partition_name=partition_name,
        partition_type=partition_type,
        entries=coverage_entries,
        total_coverage_share=total_coverage_share,
        total_expected_share=total_expected_share,
    )


def _compute_system_coverage(
    entries: list[BindingPreferenceEntry],
    protein_groups: dict[str, set[int]],
    exposed_groups: dict[str, set[int]],
    total_exposed: int,
    n_frames: int,
    surface_exposure_threshold: float | None,
    protein_group_selections: dict[str, str] | None,
    protein_partitions: dict[str, list[str]] | None = None,
) -> SystemCoverageResult:
    """Compute system-level coverage using partition-based analysis.

    This function computes coverage enrichments using proper partitions to
    avoid the overlap bug where custom groups and AA classes can inflate
    the expected_share denominator.

    Partition Strategy
    ------------------
    1. **AA Class Partition**: 5-way partition by amino acid class.
       Always computed, uses only the 5 default AA classes.

    2. **Binary Custom Partitions**: For each custom group, compute a
       binary partition (group vs rest_of_protein).

    3. **Combined Custom Partition**: If custom groups don't overlap,
       combine them all + rest_of_protein into a single partition.

    4. **User-Defined Partitions**: Custom partitions from config.
       Each references groups from protein_groups and must be mutually
       exclusive. 'rest_of_protein' is auto-added if needed.

    Parameters
    ----------
    entries : list[BindingPreferenceEntry]
        Binding preference entries (per polymer type × protein group)
    protein_groups : dict[str, set[int]]
        Mapping of group name to ALL residue IDs in that group
    exposed_groups : dict[str, set[int]]
        Mapping of group name to EXPOSED residue IDs only
    total_exposed : int
        Total number of exposed residues
    n_frames : int
        Number of trajectory frames analyzed
    surface_exposure_threshold : float | None
        SASA threshold used for surface filtering
    protein_group_selections : dict[str, str] | None
        Original MDAnalysis selections (for metadata)
    protein_partitions : dict[str, list[str]] | None
        User-defined partitions: {partition_name: [group1, group2, ...]}
        Groups must exist in protein_groups.

    Returns
    -------
    SystemCoverageResult
        Partition-based coverage metrics (schema v2)
    """
    from polyzymd.analyses.shared.aa_classification import DEFAULT_AA_CLASS_SELECTIONS

    # Collect all polymer types
    all_polymer_types = sorted({e.polymer_type for e in entries})

    # Separate AA class groups from custom groups
    aa_class_names = set(DEFAULT_AA_CLASS_SELECTIONS.keys())
    aa_class_groups: dict[str, set[int]] = {}
    aa_class_exposed: dict[str, set[int]] = {}
    custom_groups: dict[str, set[int]] = {}
    custom_exposed: dict[str, set[int]] = {}
    custom_selections: dict[str, str] = {}

    for group_name, resids in protein_groups.items():
        if group_name in aa_class_names:
            aa_class_groups[group_name] = resids
            aa_class_exposed[group_name] = exposed_groups.get(group_name, set())
        else:
            custom_groups[group_name] = resids
            custom_exposed[group_name] = exposed_groups.get(group_name, set())
            if protein_group_selections and group_name in protein_group_selections:
                custom_selections[group_name] = protein_group_selections[group_name]

    # Get all exposed residue IDs (for computing "rest_of_protein")
    all_exposed_resids: set[int] = set()
    for resids in exposed_groups.values():
        all_exposed_resids.update(resids)

    # Filter entries by group type for proper partition computation
    aa_class_entries = [e for e in entries if e.protein_group in aa_class_names]
    custom_entries = [e for e in entries if e.protein_group not in aa_class_names]

    # ---------------------------------------------------------------------
    # 1. Compute AA Class Partition (always)
    # ---------------------------------------------------------------------
    # Compute total exposed for AA class partition
    aa_class_total_exposed = sum(len(resids) for resids in aa_class_exposed.values())

    aa_class_coverage = _compute_partition_coverage(
        partition_name="aa_class",
        partition_type="aa_class",
        partition_groups=aa_class_groups,
        exposed_partition=aa_class_exposed,
        entries=aa_class_entries,
        total_exposed=aa_class_total_exposed,
        all_polymer_types=all_polymer_types,
    )

    # ---------------------------------------------------------------------
    # 2. Compute Binary Custom Partitions (one per custom group)
    # ---------------------------------------------------------------------
    custom_group_coverages: dict[str, PartitionCoverageResult] = {}

    for group_name, group_resids in custom_groups.items():
        group_exposed = custom_exposed.get(group_name, set())

        # Compute "rest_of_protein" as all exposed residues NOT in this group
        rest_exposed = all_exposed_resids - group_exposed
        rest_all = set()
        for gname, gresids in protein_groups.items():
            if gname != group_name:
                rest_all.update(gresids)
        # Actually, rest_all should be all residues NOT in this custom group
        # We need the full protein residue set - but we only have groups
        # For now, use the exposed residues as the proxy

        # Build binary partition
        binary_partition_groups = {
            group_name: group_resids,
            "rest_of_protein": rest_all - group_resids,
        }
        binary_partition_exposed = {
            group_name: group_exposed,
            "rest_of_protein": rest_exposed,
        }

        # Create synthetic entries for "rest_of_protein" by aggregating all other entries
        # We need to compute the contact frames to "rest_of_protein"
        rest_contact_frames: dict[str, int] = dict.fromkeys(all_polymer_types, 0)
        group_contact_frames: dict[str, int] = dict.fromkeys(all_polymer_types, 0)

        for entry in entries:
            if entry.protein_group == group_name:
                group_contact_frames[entry.polymer_type] = entry.total_contact_frames
            elif entry.protein_group not in custom_groups:
                # It's an AA class group - add to rest
                rest_contact_frames[entry.polymer_type] = (
                    rest_contact_frames.get(entry.polymer_type, 0) + entry.total_contact_frames
                )
            # Note: Other custom groups are NOT added to rest - they'll have their own partition

        # Build synthetic entries for the binary partition
        binary_entries: list[BindingPreferenceEntry] = []

        # Add entries for the custom group
        for entry in entries:
            if entry.protein_group == group_name:
                binary_entries.append(entry)

        # Create synthetic entries for rest_of_protein
        for pt, frames in rest_contact_frames.items():
            binary_entries.append(
                BindingPreferenceEntry(
                    polymer_type=pt,
                    protein_group="rest_of_protein",
                    total_contact_frames=frames,
                    mean_contact_fraction=0.0,  # Not used for partition coverage
                    n_residues_contacted=0,  # Not used for partition coverage
                    contact_share=0.0,
                    expected_share=0.0,
                    enrichment=None,
                    n_exposed_in_group=len(rest_exposed),
                    n_residues_in_group=len(binary_partition_groups["rest_of_protein"]),
                )
            )

        binary_total_exposed = len(group_exposed) + len(rest_exposed)

        binary_coverage = _compute_partition_coverage(
            partition_name=f"{group_name}_vs_rest",
            partition_type="binary_custom",
            partition_groups=binary_partition_groups,
            exposed_partition=binary_partition_exposed,
            entries=binary_entries,
            total_exposed=binary_total_exposed,
            all_polymer_types=all_polymer_types,
        )

        custom_group_coverages[group_name] = binary_coverage

    # ---------------------------------------------------------------------
    # 3. Check for overlaps among custom groups
    # ---------------------------------------------------------------------
    overlapping_pairs = _detect_overlapping_groups(custom_exposed)
    has_overlaps = len(overlapping_pairs) > 0

    if has_overlaps:
        logger.warning(
            f"Custom protein groups have overlapping residues: {overlapping_pairs}. "
            f"Combined custom partition will not be computed."
        )

    # ---------------------------------------------------------------------
    # 4. Compute Combined Custom Partition (if no overlaps)
    # ---------------------------------------------------------------------
    combined_custom_coverage: PartitionCoverageResult | None = None

    if custom_groups and not has_overlaps:
        # Build combined partition: all custom groups + rest_of_protein
        combined_partition_groups: dict[str, set[int]] = dict(custom_groups)
        combined_partition_exposed: dict[str, set[int]] = dict(custom_exposed)

        # Compute rest_of_protein for combined partition
        all_custom_exposed: set[int] = set()
        all_custom_resids: set[int] = set()
        for resids in custom_exposed.values():
            all_custom_exposed.update(resids)
        for resids in custom_groups.values():
            all_custom_resids.update(resids)

        rest_exposed_combined = all_exposed_resids - all_custom_exposed
        rest_all_combined: set[int] = set()
        for gname, gresids in protein_groups.items():
            if gname not in custom_groups:
                rest_all_combined.update(gresids)
        rest_all_combined = rest_all_combined - all_custom_resids

        combined_partition_groups["rest_of_protein"] = rest_all_combined
        combined_partition_exposed["rest_of_protein"] = rest_exposed_combined

        # Build entries for combined partition
        combined_entries: list[BindingPreferenceEntry] = list(custom_entries)

        # Add synthetic entries for rest_of_protein
        rest_contact_frames_combined: dict[str, int] = dict.fromkeys(all_polymer_types, 0)
        for entry in aa_class_entries:
            rest_contact_frames_combined[entry.polymer_type] = (
                rest_contact_frames_combined.get(entry.polymer_type, 0) + entry.total_contact_frames
            )

        for pt, frames in rest_contact_frames_combined.items():
            combined_entries.append(
                BindingPreferenceEntry(
                    polymer_type=pt,
                    protein_group="rest_of_protein",
                    total_contact_frames=frames,
                    mean_contact_fraction=0.0,  # Not used for partition coverage
                    n_residues_contacted=0,  # Not used for partition coverage
                    contact_share=0.0,
                    expected_share=0.0,
                    enrichment=None,
                    n_exposed_in_group=len(rest_exposed_combined),
                    n_residues_in_group=len(rest_all_combined),
                )
            )

        combined_total_exposed = sum(len(r) for r in combined_partition_exposed.values())

        combined_custom_coverage = _compute_partition_coverage(
            partition_name="combined_custom",
            partition_type="combined_custom",
            partition_groups=combined_partition_groups,
            exposed_partition=combined_partition_exposed,
            entries=combined_entries,
            total_exposed=combined_total_exposed,
            all_polymer_types=all_polymer_types,
        )

    # ---------------------------------------------------------------------
    # 5. Compute User-Defined Partitions (from protein_partitions config)
    # ---------------------------------------------------------------------
    user_defined_partitions: dict[str, PartitionCoverageResult] = {}

    if protein_partitions:
        for partition_name, group_names in protein_partitions.items():
            # Build partition groups from referenced group names
            partition_groups_map: dict[str, set[int]] = {}
            partition_exposed_map: dict[str, set[int]] = {}
            partition_entries_list: list[BindingPreferenceEntry] = []

            # Collect residues from all specified groups
            all_partition_exposed: set[int] = set()

            for group_name in group_names:
                if group_name not in protein_groups:
                    logger.warning(
                        f"Partition '{partition_name}' references undefined group "
                        f"'{group_name}' - skipping this group"
                    )
                    continue

                partition_groups_map[group_name] = protein_groups[group_name]
                partition_exposed_map[group_name] = exposed_groups.get(group_name, set())
                all_partition_exposed.update(partition_exposed_map[group_name])

                # Add entries for this group
                for entry in entries:
                    if entry.protein_group == group_name:
                        partition_entries_list.append(entry)

            if not partition_groups_map:
                logger.warning(f"Partition '{partition_name}' has no valid groups - skipping")
                continue

            # Check if we need to add rest_of_protein
            # (if partition doesn't cover all exposed residues)
            rest_exposed = all_exposed_resids - all_partition_exposed
            if rest_exposed:
                # Partition doesn't cover all residues - add rest_of_protein
                partition_groups_map["rest_of_protein"] = (
                    set().union(*exposed_groups.values()) - all_partition_exposed
                )
                partition_exposed_map["rest_of_protein"] = rest_exposed

                # Create synthetic entries for rest_of_protein
                rest_contact_frames_user: dict[str, int] = dict.fromkeys(all_polymer_types, 0)
                for entry in entries:
                    if entry.protein_group not in group_names:
                        rest_contact_frames_user[entry.polymer_type] = (
                            rest_contact_frames_user.get(entry.polymer_type, 0)
                            + entry.total_contact_frames
                        )

                for pt, frames in rest_contact_frames_user.items():
                    partition_entries_list.append(
                        BindingPreferenceEntry(
                            polymer_type=pt,
                            protein_group="rest_of_protein",
                            total_contact_frames=frames,
                            mean_contact_fraction=0.0,
                            n_residues_contacted=0,
                            contact_share=0.0,
                            expected_share=0.0,
                            enrichment=None,
                            n_exposed_in_group=len(rest_exposed),
                            n_residues_in_group=len(partition_groups_map["rest_of_protein"]),
                        )
                    )

                logger.debug(
                    f"Partition '{partition_name}': auto-added 'rest_of_protein' "
                    f"with {len(rest_exposed)} exposed residues"
                )
            else:
                logger.debug(
                    f"Partition '{partition_name}': covers all exposed residues, "
                    f"no 'rest_of_protein' needed"
                )

            # Compute total exposed for this partition
            user_partition_total_exposed = sum(len(r) for r in partition_exposed_map.values())

            # Compute the partition coverage
            user_partition_coverage = _compute_partition_coverage(
                partition_name=partition_name,
                partition_type="user_defined",
                partition_groups=partition_groups_map,
                exposed_partition=partition_exposed_map,
                entries=partition_entries_list,
                total_exposed=user_partition_total_exposed,
                all_polymer_types=all_polymer_types,
            )

            user_defined_partitions[partition_name] = user_partition_coverage
            logger.info(
                f"Computed user-defined partition '{partition_name}' with "
                f"{len(partition_groups_map)} groups"
            )

    # ---------------------------------------------------------------------
    # 6. Compute total contact frames
    # ---------------------------------------------------------------------
    total_contact_frames = sum(e.total_contact_frames for e in entries)

    return SystemCoverageResult(
        aa_class_coverage=aa_class_coverage,
        custom_group_coverages=custom_group_coverages,
        combined_custom_coverage=combined_custom_coverage,
        user_defined_partitions=user_defined_partitions,
        n_frames=n_frames,
        total_contact_frames=total_contact_frames,
        total_exposed_residues=total_exposed,
        surface_exposure_threshold=surface_exposure_threshold,
        custom_group_selections=custom_selections,
        polymer_types_included=all_polymer_types,
        has_overlapping_custom_groups=has_overlaps,
        overlapping_group_pairs=overlapping_pairs,
    )


def compute_binding_preference(
    contact_result: Any,
    surface_exposure: "SurfaceExposureResult",
    protein_groups: dict[str, set[int]],
    polymer_composition: PolymerComposition,
    polymer_types: list[str] | None = None,
    protein_group_selections: dict[str, str] | None = None,
    polymer_type_selections: dict[str, str] | None = None,
    protein_partitions: dict[str, list[str]] | None = None,
) -> BindingPreferenceResult:
    """Compute binding preference from contact results.

    Computes a zero-centered enrichment for each (polymer_type, protein_group)
    combination, answering: "Does this polymer type preferentially bind this
    protein group compared to random chance?"

    The calculation accounts for:

    1. Surface exposure (only exposed residues are considered)
    2. Contact duration (contact frames are summed, not binary counts)
    3. Surface availability (the expected share of a group is its fraction
       of all exposed residues)

    Parameters
    ----------
    contact_result : Any
        Raw contact analysis results from trajectory. Must provide
        ``n_frames`` and ``residue_contacts``; each residue contact must
        provide ``protein_resid`` and ``contacts_by_polymer_type(n_frames)``.
    surface_exposure : SurfaceExposureResult
        Surface exposure data for filtering buried residues
        (``exposed_resids`` and ``threshold`` are read).
    protein_groups : dict[str, set[int]]
        Mapping of group name to residue IDs in that group.
        Only surface-exposed residues in each group are counted.
        Example: {"aromatic": {12, 45, 67}, "charged": {23, 34}}
    polymer_composition : PolymerComposition
        Polymer composition data (residue and heavy atom counts per type).
        The counts are stored on every entry as metadata and forwarded to the
        partition-based computation; they do not enter the enrichment
        computed in this function.
    polymer_types : list[str], optional
        If provided, only compute for these polymer residue types.
        If None or empty, all polymer types found in contacts are used.
    protein_group_selections : dict[str, str], optional
        Original MDAnalysis selections (for metadata/reproducibility)
    polymer_type_selections : dict[str, str], optional
        Original MDAnalysis selections (for metadata/reproducibility)
    protein_partitions : dict[str, list[str]], optional
        User-defined partitions for system coverage plots.
        Each partition maps a name to a list of group names from
        protein_groups. Groups within a partition must be mutually exclusive.
        Example: {"lid_helices": ["lid_helix_5", "lid_helix_10"]}

    Returns
    -------
    BindingPreferenceResult
        Per-pair entries (deprecated, kept for backward compatibility),
        the system-level coverage, and the partition-based per-polymer
        binding preference.

    Notes
    -----
    Enrichment is centered at zero. For each (polymer_type, protein_group)
    pair::

        contact_share  = polymer_contacts_to_group / polymer_total_contacts
        expected_share = n_exposed_in_group / total_exposed_residues
        enrichment     = (contact_share / expected_share) - 1

    ``polymer_total_contacts`` only counts contact frames on exposed residues
    that belong to at least one group in ``protein_groups``. A residue that
    belongs to several groups contributes its frames to each of them, so
    with overlapping groups the contact shares of one polymer type can sum
    to more than 1.

    Per-residue contact fractions are converted back to whole frame counts
    with ``int(round(frac * n_frames))``; contacts that round to zero frames
    are ignored.

    This function itself computes only the surface-availability enrichment.
    Residue and heavy-atom counts from ``polymer_composition`` are attached
    to each entry as metadata for secondary (composition-based) analysis.

    Interpretation:

    - enrichment > 0: Preferential binding (more contacts than expected)
    - enrichment = 0: Neutral (matches random chance)
    - enrichment < 0: Avoidance (fewer contacts than expected)
    - enrichment = -1: Complete avoidance (no contacts at all)
    - enrichment is None: expected share is 0 (see ``_compute_enrichment``)
    """
    exposed_resids = surface_exposure.exposed_resids
    n_frames = contact_result.n_frames
    total_exposed = len(exposed_resids)

    logger.info(
        f"Computing binding preference: {total_exposed} exposed residues, "
        f"{len(protein_groups)} protein groups"
    )
    logger.info(
        f"Polymer composition: {polymer_composition.residue_counts} residues, "
        f"{polymer_composition.heavy_atom_counts} heavy atoms"
    )

    # Filter protein_groups to only exposed residues
    exposed_groups: dict[str, set[int]] = {}
    for group_name, resids in protein_groups.items():
        exposed_groups[group_name] = resids & exposed_resids
        n_orig = len(resids)
        n_exposed = len(exposed_groups[group_name])
        logger.debug(f" {group_name}: {n_exposed}/{n_orig} residues are exposed")

    # Collect contact frame counts per (polymer_type, protein_group).
    # Structure:
    #   {polymer_type: {protein_group: {"total_frames": int,
    #                                   "residues_contacted": set[int]}}}
    # Plain dicts are used on purpose: both mappings are handed to
    # _compute_polymer_binding_preference below.
    contact_data: dict[str, dict[str, dict[str, Any]]] = {}

    # Total contact frames per polymer type (contact_share denominator).
    # Each residue is counted once here, even if it sits in several groups.
    total_contacts_by_polymer: dict[str, int] = {}

    for rc in contact_result.residue_contacts:
        resid = rc.protein_resid

        if resid not in exposed_resids:
            continue  # Skip buried residues

        # Determine which protein groups this residue belongs to
        residue_groups = [
            gname for gname, gresids in exposed_groups.items() if resid in gresids
        ]
        if not residue_groups:
            # Residue is exposed but not in any defined group
            continue

        # Get contacts by polymer type for this residue
        for poly_type, frac in rc.contacts_by_polymer_type(n_frames).items():
            if polymer_types and poly_type not in polymer_types:
                continue

            # Convert fraction to frame count
            contact_frames = int(round(frac * n_frames))
            if contact_frames == 0:
                continue

            # A polymer type is only registered once it has a non-zero contact
            per_group = contact_data.setdefault(poly_type, {})
            total_contacts_by_polymer[poly_type] = (
                total_contacts_by_polymer.get(poly_type, 0) + contact_frames
            )

            # Add to each group this residue belongs to
            for gname in residue_groups:
                gdata = per_group.setdefault(
                    gname, {"total_frames": 0, "residues_contacted": set()}
                )
                gdata["total_frames"] += contact_frames
                gdata["residues_contacted"].add(resid)

    # Get polymer composition totals (for metadata/secondary analysis)
    total_poly_residues = polymer_composition.total_residues
    total_poly_atoms = polymer_composition.total_heavy_atoms

    # Per-group quantities do not depend on the polymer type, so compute them
    # once. expected_share is based on PROTEIN SURFACE AVAILABILITY: how much
    # of the exposed surface is this protein group?
    group_stats: list[tuple[str, int, int, float]] = []
    for group_name in sorted(protein_groups):
        n_total_in_group = len(protein_groups[group_name])
        n_exposed_in_group = len(exposed_groups[group_name])
        if total_exposed > 0:
            expected_share = n_exposed_in_group / total_exposed
        else:
            expected_share = 0.0
        group_stats.append(
            (group_name, n_total_in_group, n_exposed_in_group, expected_share)
        )

    # Build result entries with enrichment calculations
    entries: list[BindingPreferenceEntry] = []

    for poly_type in sorted(contact_data):
        polymer_groups = contact_data[poly_type]
        total_poly_contacts = total_contacts_by_polymer.get(poly_type, 0)

        # Get polymer composition for this type (stored as metadata)
        poly_res_count = polymer_composition.residue_counts.get(poly_type, 0)
        poly_atom_count = polymer_composition.heavy_atom_counts.get(poly_type, 0)

        for group_name, n_total_in_group, n_exposed_in_group, expected_share in group_stats:
            # Get contact data for this (polymer, group) pair
            gdata = polymer_groups.get(group_name, {})
            contact_frames = gdata.get("total_frames", 0)
            n_residues_contacted = len(gdata.get("residues_contacted", set()))

            # Per-residue average: total_frames / (n_frames * n_exposed)
            if n_exposed_in_group > 0 and n_frames > 0:
                mean_frac = contact_frames / (n_frames * n_exposed_in_group)
            else:
                mean_frac = 0.0

            # What fraction of this polymer's contacts went to this group?
            if total_poly_contacts > 0:
                contact_share = contact_frames / total_poly_contacts
            else:
                contact_share = 0.0

            # enrichment = (contact_share / expected_share) - 1
            # Positive = prefers this group, Negative = avoids this group
            enrichment = _compute_enrichment(contact_share, expected_share)

            entries.append(
                BindingPreferenceEntry(
                    polymer_type=poly_type,
                    protein_group=group_name,
                    total_contact_frames=contact_frames,
                    mean_contact_fraction=mean_frac,
                    n_residues_in_group=n_total_in_group,
                    n_exposed_in_group=n_exposed_in_group,
                    n_residues_contacted=n_residues_contacted,
                    contact_share=contact_share,
                    expected_share=expected_share,
                    enrichment=enrichment,
                    # Polymer composition metadata (for secondary analysis)
                    polymer_residue_count=poly_res_count,
                    total_polymer_residues=total_poly_residues,
                    polymer_heavy_atom_count=poly_atom_count,
                    total_polymer_heavy_atoms=total_poly_atoms,
                )
            )

    # Compute system-level coverage (collapsed across polymer types)
    system_coverage = _compute_system_coverage(
        entries=entries,
        protein_groups=protein_groups,
        exposed_groups=exposed_groups,
        total_exposed=total_exposed,
        n_frames=n_frames,
        surface_exposure_threshold=surface_exposure.threshold,
        protein_group_selections=protein_group_selections,
        protein_partitions=protein_partitions,
    )

    # Compute per-polymer partition-based binding preference (NEW in v5)
    # This is the primary output - contact_share sums to 1.0 within each partition
    binding_preference = _compute_polymer_binding_preference(
        contact_data=contact_data,
        total_contacts_by_polymer=total_contacts_by_polymer,
        protein_groups=protein_groups,
        exposed_groups=exposed_groups,
        protein_partitions=protein_partitions,
        total_exposed=total_exposed,
        n_frames=n_frames,
        surface_exposure_threshold=surface_exposure.threshold,
        polymer_composition=polymer_composition,
        protein_groups_used=protein_group_selections,
    )

    result = BindingPreferenceResult(
        entries=entries,  # DEPRECATED: kept for backward compat
        n_frames=n_frames,
        total_exposed_residues=total_exposed,
        surface_exposure_threshold=surface_exposure.threshold,
        protein_groups_used=protein_group_selections or {},
        polymer_types_used=polymer_type_selections or {},
        polymer_composition=polymer_composition,
        system_coverage=system_coverage,
        binding_preference=binding_preference,  # NEW: partition-based per-polymer
    )

    # Log summary
    polymer_types_found = result.polymer_types()
    logger.info(
        f"Binding preference computed: {len(entries)} entries for "
        f"{len(polymer_types_found)} polymer types × {len(protein_groups)} groups"
    )

    n_aa_classes = len(system_coverage.aa_class_coverage.entries)
    n_custom_groups = len(system_coverage.custom_group_coverages)
    logger.info(
        f"System coverage computed: {n_aa_classes} AA classes, "
        f"{n_custom_groups} custom groups, "
        f"{system_coverage.total_contact_frames} total contact frames"
    )

    # Log partition-based binding preference summary
    if binding_preference:
        n_user_partitions = len(binding_preference.user_defined_partitions)
        logger.info(
            f"Partition-based binding preference computed: "
            f"{len(binding_preference.polymer_types)} polymer types, "
            f"AA class partition + {n_user_partitions} user partitions"
        )

        # Validate contact_share sums to ~1.0
        for poly_type, aa_result in binding_preference.aa_class_binding.items():
            total_share = aa_result.total_contact_share
            if abs(total_share - 1.0) > 0.01:
                logger.warning(
                    f"AA class partition for {poly_type}: contact_share sums to "
                    f"{total_share:.4f} (expected ~1.0)"
                )

    return result