Source code for polyzymd.analyses.shared.autocorrelation

"""Autocorrelation analysis for independent sampling.

MD trajectories are highly correlated in time - consecutive frames are not
independent samples. This module provides tools to:

1. Compute the autocorrelation function (ACF) of an observable
2. Estimate the correlation time (τ) from the ACF
3. Compute statistical inefficiency (g) for proper uncertainty quantification
4. Select independent frames based on τ for proper statistics

Key Concepts
------------
- **Autocorrelation function (ACF)**: Measures how correlated a signal is with
  itself at different time lags. ACF(0) = 1, and ACF decays toward 0.

- **Correlation time (τ)**: Characteristic time for decorrelation. Frames
  separated by > 2τ are approximately independent.

- **Statistical inefficiency (g)**: Factor by which variance is inflated due
  to correlation. g = 1 + 2*Σ C(t)*(1-t/N). N_eff = N/g.

- **Independent samples**: For proper SEM calculation, we need N_eff independent
  samples, not N_frames correlated observations.

Methods for τ estimation
------------------------
- **First zero crossing**: τ is lag where ACF first crosses zero
- **Exponential fit**: Fit ACF = exp(-t/τ) and extract τ
- **Integration**: τ = ∫ACF(t)dt from 0 to first zero (or cutoff)

Statistical Validity
--------------------
The number of effective independent samples (N_eff) is computed as:
    N_eff = N / g = N / (1 + 2*Σ C(t)*(1-t/N))

This matches the algorithm from Chodera et al. (2007) with the finite-size
correction factor (1-t/N). When N_eff < 10, statistical estimates (mean, SEM)
may be unreliable, and users should be warned per LiveCoMS best practices
(Grossfield et al., 2018).

For multiple timeseries of different lengths (e.g., replicates), use
`statistical_inefficiency_multiple()` which correctly handles the averaging.

References
----------
- Flyvbjerg & Petersen (1989) J. Chem. Phys. 91:461 (block averaging)
- Chodera et al. (2007) J. Chem. Theory Comput. 3:26 (statistical inefficiency)
- Grossfield et al. (2018) LiveCoMS 1:5067 (uncertainty quantification)
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum
from typing import Literal

import numpy as np
from numpy.typing import ArrayLike, NDArray

# Module-level logger; the low-sample-count warnings in this module go through it.
logger = logging.getLogger(__name__)

# Minimum recommended independent samples for reliable statistics.
# Used as the default reliability threshold: results with fewer effective
# independent samples than this are flagged with a warning.
MIN_RECOMMENDED_N_INDEPENDENT = 10


class CorrelationTimeMethod(str, Enum):
    """Strategies available for turning an ACF into a correlation time.

    Members are ``str`` subclasses, so each one compares equal to its plain
    string value (e.g. ``CorrelationTimeMethod.INTEGRATION == "integration"``).
    """

    # Lag at which the ACF first becomes non-positive.
    FIRST_ZERO = "first_zero"
    # Slope of a log-linear fit, assuming ACF = exp(-t/tau).
    EXPONENTIAL_FIT = "exponential_fit"
    # Area under the ACF up to the first zero crossing (or a cutoff).
    INTEGRATION = "integration"
@dataclass
class ACFResult:
    """Container for a computed autocorrelation function.

    Attributes
    ----------
    lags : NDArray[np.float64]
        Time lags, expressed in the same units as ``timestep``.
    acf : NDArray[np.float64]
        Normalized autocorrelation values (``acf[0] == 1``).
    timestep : float
        Time between consecutive frames.
    timestep_unit : str
        Unit label for ``timestep`` (e.g. "ps", "ns").
    n_samples : int
        Length of the timeseries the ACF was computed from.
    """

    lags: NDArray[np.float64]
    acf: NDArray[np.float64]
    timestep: float
    timestep_unit: str
    n_samples: int

    def __len__(self) -> int:
        """Return the number of lags stored in this result."""
        return len(self.lags)

    def to_dict(self) -> dict:
        """Return a plain-Python dictionary suitable for serialization.

        The two arrays are converted to lists; scalar fields are passed
        through unchanged.
        """
        payload = dict(
            lags=self.lags.tolist(),
            acf=self.acf.tolist(),
            timestep=self.timestep,
            timestep_unit=self.timestep_unit,
            n_samples=self.n_samples,
        )
        return payload
@dataclass
class CorrelationTimeResult:
    """Outcome of a correlation-time estimate.

    Attributes
    ----------
    tau : float
        Estimated correlation time.
    tau_unit : str
        Unit of ``tau`` (same as the timestep unit).
    method : str
        Name of the estimation method that produced ``tau``.
    n_independent : int
        Estimated number of independent samples in the trajectory.
    statistical_inefficiency : float
        g = 1 + 2*tau/dt, the factor by which variance is inflated.
    warning : str | None
        Message describing why statistics may be unreliable
        (e.g. N_ind < 10), or None when no warning was generated.
    """

    tau: float
    tau_unit: str
    method: str
    n_independent: int
    statistical_inefficiency: float
    warning: str | None = None

    @property
    def is_reliable(self) -> bool:
        """Whether ``n_independent`` reaches the recommended minimum (N_ind >= 10)."""
        return MIN_RECOMMENDED_N_INDEPENDENT <= self.n_independent

    def to_dict(self) -> dict:
        """Return a plain dictionary for serialization.

        Includes every dataclass field plus the derived ``is_reliable`` flag.
        """
        payload = dict(
            tau=self.tau,
            tau_unit=self.tau_unit,
            method=self.method,
            n_independent=self.n_independent,
            statistical_inefficiency=self.statistical_inefficiency,
            warning=self.warning,
            is_reliable=self.is_reliable,
        )
        return payload
def compute_acf(
    timeseries: ArrayLike,
    max_lag: int | None = None,
    timestep: float = 1.0,
    timestep_unit: str = "frames",
) -> ACFResult:
    """Compute the normalized autocorrelation function of a 1D timeseries.

    Uses an FFT-based computation for efficiency.

    Parameters
    ----------
    timeseries : array_like
        1D array of values (e.g., RMSD over time, distance over time).
    max_lag : int, optional
        Maximum lag to compute (in frames). Default is N//4 where N is the
        length of the timeseries. Values larger than N-1 are clipped to N-1.
    timestep : float, optional
        Time between frames. Default is 1.0.
    timestep_unit : str, optional
        Unit of timestep. Default is "frames".

    Returns
    -------
    ACFResult
        Container with lags (always float64), acf values, and metadata.

    Raises
    ------
    ValueError
        If the timeseries is not 1D, has fewer than 10 points, or if
        ``max_lag`` is negative.

    Examples
    --------
    >>> # Compute ACF of RMSD timeseries
    >>> rmsd = np.array([1.2, 1.3, 1.25, 1.4, ...])  # from MDAnalysis
    >>> acf_result = compute_acf(rmsd, timestep=10.0, timestep_unit="ps")
    >>> print(f"ACF at lag 100ps: {acf_result.acf[10]:.3f}")

    Notes
    -----
    The ACF is normalized so that ACF[0] = 1. For a stationary process:

        ACF(τ) = <(x(t) - μ)(x(t+τ) - μ)> / σ²

    For constant or near-constant timeseries (variance below a small
    epsilon), this function returns a defined degenerate ACF with
    ACF[0] = 1 and all positive lags set to 0.
    """
    x = np.asarray(timeseries, dtype=np.float64)
    if x.ndim != 1:
        # The FFT path below is only meaningful for a single 1D signal.
        raise ValueError(f"Timeseries must be 1D, got array with shape {x.shape}.")

    n = len(x)
    if n < 10:
        raise ValueError(f"Timeseries too short ({n} points). Need at least 10.")

    if max_lag is None:
        max_lag = n // 4  # Reasonable default
    elif max_lag < 0:
        # A negative value would slice the ACF from the end while producing
        # an empty lag axis, i.e. arrays of mismatched length.
        raise ValueError(f"max_lag must be non-negative, got {max_lag}.")
    max_lag = min(max_lag, n - 1)
    n_lags = max_lag + 1

    # Same float64 lag axis for both the degenerate and the regular branch.
    lags = np.arange(n_lags, dtype=np.float64) * timestep

    # Remove mean
    x_centered = x - np.mean(x)
    variance = float(np.var(x_centered))

    if variance < 1e-12:
        # Stable degenerate ACF for near-constant timeseries: dividing by a
        # vanishing variance would otherwise amplify round-off noise.
        acf = np.zeros(n_lags, dtype=np.float64)
        acf[0] = 1.0
    else:
        # FFT-based autocorrelation (much faster than direct computation).
        # Zero-pad to the next power of 2 >= 2n-1 so the circular
        # correlation does not wrap around.
        n_fft = 2 ** int(np.ceil(np.log2(2 * n - 1)))
        fft_x = np.fft.fft(x_centered, n_fft)
        acf_full = np.fft.ifft(fft_x * np.conj(fft_x)).real[:n]

        # Normalize by the number of pairs at each lag (n, n-1, ..., 1)
        # and by the variance, so that acf[0] == 1.
        acf_full = acf_full / (np.arange(n, 0, -1) * variance)

        # Take only up to max_lag
        acf = acf_full[:n_lags]

    return ACFResult(
        lags=lags,
        acf=acf,
        timestep=timestep,
        timestep_unit=timestep_unit,
        n_samples=n,
    )
def estimate_correlation_time(
    acf_or_timeseries: ACFResult | ArrayLike,
    timestep: float = 1.0,
    timestep_unit: str = "frames",
    method: Literal["first_zero", "exponential_fit", "integration"] = "integration",
    n_frames: int | None = None,
) -> CorrelationTimeResult:
    """Estimate correlation time from ACF or raw timeseries.

    Parameters
    ----------
    acf_or_timeseries : ACFResult or array_like
        Either an ACFResult from compute_acf(), or a raw timeseries
    timestep : float, optional
        Time between frames (only used if passing raw timeseries)
    timestep_unit : str, optional
        Unit of timestep (only used if passing raw timeseries)
    method : {"first_zero", "exponential_fit", "integration"}
        Method for estimating τ:
        - "first_zero": Lag where ACF first crosses zero
        - "exponential_fit": Fit ACF = exp(-t/τ)
        - "integration": τ = ∫ACF(t)dt (recommended, most robust)
    n_frames : int, optional
        Total number of frames (for computing n_independent).
        Only needed if passing ACFResult; when a raw timeseries is passed
        its own length is always used.

    Returns
    -------
    CorrelationTimeResult
        Contains tau, method used, n_independent, statistical_inefficiency

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names, or if the
        effective timestep is not strictly positive.

    Examples
    --------
    >>> acf_result = compute_acf(rmsd, timestep=10.0, timestep_unit="ps")
    >>> tau_result = estimate_correlation_time(acf_result, method="integration")
    >>> print(f"Correlation time: {tau_result.tau:.1f} {tau_result.tau_unit}")
    >>> print(f"Independent samples: {tau_result.n_independent}")

    Notes
    -----
    The "integration" method is most robust for noisy ACFs. It computes:

        τ = ∫₀^∞ ACF(t) dt ≈ Σ ACF[i] * dt

    Integration stops at first zero crossing to avoid noise contribution.
    """
    # Validate the method up front so an unknown name is rejected on every
    # path, including the constant-series shortcut further down.
    if method not in ("first_zero", "exponential_fit", "integration"):
        raise ValueError(f"Unknown method: {method}")

    # Handle input type
    if isinstance(acf_or_timeseries, ACFResult):
        acf_result = acf_or_timeseries
        dt = acf_result.timestep
        unit = acf_result.timestep_unit
        if n_frames is None:
            n_frames = acf_result.n_samples
    else:
        # Compute ACF from raw timeseries
        acf_result = compute_acf(
            acf_or_timeseries,
            timestep=timestep,
            timestep_unit=timestep_unit,
        )
        dt = timestep
        unit = timestep_unit
        # compute_acf records the series length, so no second conversion is needed.
        n_frames = acf_result.n_samples

    acf = acf_result.acf
    lags = acf_result.lags

    # g = 1 + 2*tau/dt below is undefined for dt == 0 and meaningless for dt < 0.
    if not dt > 0:
        raise ValueError(f"timestep must be positive, got {dt}")

    # Handle degenerate constant-series ACF explicitly
    if len(acf) > 1 and np.isclose(acf[0], 1.0) and np.allclose(acf[1:], 0.0):
        g = 1.0
        n_independent = max(1, int(n_frames / g))
        return CorrelationTimeResult(
            tau=0.0,
            tau_unit=unit,
            method=method,
            n_independent=n_independent,
            statistical_inefficiency=g,
            warning=None,
        )

    # Find first zero crossing
    zero_crossing_idx = _find_first_zero_crossing(acf)

    if method == "first_zero":
        # If the ACF never crosses zero, fall back to the longest lag available.
        last_idx = -1 if zero_crossing_idx is None else zero_crossing_idx
        tau = float(lags[last_idx])
    elif method == "exponential_fit":
        tau = _fit_exponential_acf(lags, acf, zero_crossing_idx)
    else:  # "integration" (method was validated above)
        tau = _integrate_acf(lags, acf, zero_crossing_idx, dt, n_frames=n_frames)

    # Compute statistical inefficiency: g = 1 + 2*τ/dt
    # This is the factor by which variance is inflated due to correlation.
    # NOTE(review): _integrate_acf floors tau at dt and _fit_exponential_acf
    # floors it at lags[1], so for those methods g >= 3 whenever lags[1] == dt;
    # statistical_inefficiency() has no such floor, so the two estimates can
    # differ for weakly correlated data.
    g = 1.0 + 2.0 * tau / dt

    # Number of independent samples
    n_independent = max(1, int(n_frames / g))

    # Generate warning if statistics may be unreliable
    warning = None
    if n_independent < MIN_RECOMMENDED_N_INDEPENDENT:
        warning = (
            f"Low statistical reliability: only {n_independent} independent samples "
            f"(recommended >= {MIN_RECOMMENDED_N_INDEPENDENT}). "
            f"Correlation time τ = {tau:.1f} {unit} is comparable to or longer than "
            f"the trajectory sampling window. Consider: (1) extending simulation time, "
            f"(2) using multiple independent trajectories, or (3) interpreting results "
            f"with caution. See Grossfield et al. (2018) LiveCoMS 1:5067."
        )
        logger.warning(warning)

    return CorrelationTimeResult(
        tau=tau,
        tau_unit=unit,
        method=method,
        n_independent=n_independent,
        statistical_inefficiency=g,
        warning=warning,
    )
def get_independent_indices(
    n_frames: int,
    correlation_time: float,
    timestep: float = 1.0,
    start_frame: int = 0,
) -> NDArray[np.int64]:
    """Get frame indices for independent samples.

    Selects frames separated by at least 2*τ (correlation time) to ensure
    approximate independence for statistical analysis.

    Parameters
    ----------
    n_frames : int
        Total number of frames in trajectory
    correlation_time : float
        Correlation time τ (in same units as timestep). Values that give a
        spacing below one frame (including τ <= 0) select every frame.
    timestep : float, optional
        Time between frames. Must be positive. Default is 1.0.
    start_frame : int, optional
        First frame to consider (after equilibration). Must be
        non-negative. Default is 0.
        Note: Frame indices are 0-indexed internally, but user-facing
        documentation uses 1-indexed (PyMOL convention).

    Returns
    -------
    NDArray[np.int64]
        Array of frame indices (0-indexed) that are approximately independent

    Raises
    ------
    ValueError
        If ``start_frame`` is negative or not smaller than ``n_frames``,
        if ``timestep`` is not strictly positive, or if
        ``correlation_time`` is not finite.

    Examples
    --------
    >>> # Get independent frames for RMSF calculation
    >>> tau_result = estimate_correlation_time(rmsd, timestep=10.0)
    >>> indices = get_independent_indices(
    ...     n_frames=10000,
    ...     correlation_time=tau_result.tau,
    ...     timestep=10.0,
    ...     start_frame=1000,  # Skip first 1000 frames for equilibration
    ... )
    >>> print(f"Using {len(indices)} independent frames")

    Notes
    -----
    Frame indices returned are 0-indexed (for direct use with MDAnalysis).
    When displaying to users, add 1 for PyMOL convention.

    The spacing is set to ceil(2*τ/timestep) frames. For an exponentially
    decaying ACF this leaves a residual correlation of about
    exp(-2) ≈ 0.14 between consecutive selected frames.
    """
    if start_frame < 0:
        # Negative starts would yield negative indices, which wrap around
        # when used to index a trajectory.
        raise ValueError(f"start_frame must be non-negative, got {start_frame}")
    if n_frames <= start_frame:
        raise ValueError(f"start_frame ({start_frame}) >= n_frames ({n_frames})")
    if not timestep > 0:
        # Zero would divide by zero; a negative value would silently
        # collapse the spacing to a single frame.
        raise ValueError(f"timestep must be positive, got {timestep}")
    if not np.isfinite(correlation_time):
        raise ValueError(f"correlation_time must be finite, got {correlation_time}")

    # Convert correlation time to frame spacing
    # Use 2*τ for good independence (ACF ≈ exp(-2) ≈ 0.14)
    frame_spacing = max(1, int(np.ceil(2.0 * correlation_time / timestep)))

    # Generate indices
    return np.arange(start_frame, n_frames, frame_spacing, dtype=np.int64)
def _find_first_zero_crossing(acf: NDArray[np.float64]) -> int | None:
    """Return the index of the first non-positive ACF value, or None.

    A "zero crossing" here means the first lag at which ``acf <= 0``.
    """
    nonpositive = np.flatnonzero(acf <= 0.0)
    if nonpositive.size:
        return int(nonpositive[0])
    return None


def _fit_exponential_acf(
    lags: NDArray[np.float64],
    acf: NDArray[np.float64],
    zero_crossing_idx: int | None,
) -> float:
    """Fit exponential decay to ACF and extract τ.

    Parameters
    ----------
    lags : NDArray[np.float64]
        Time lags
    acf : NDArray[np.float64]
        Autocorrelation values
    zero_crossing_idx : int | None
        Index of first zero crossing; the fit uses lags before it. When
        None, the first half of the ACF is used.

    Returns
    -------
    float
        Estimated τ, never smaller than ``lags[1]``.

    Raises
    ------
    ValueError
        If fewer than two lags are available.
    """
    n_lags = len(lags)
    if n_lags < 2:
        # lags[1] is used as the lower bound of the estimate below.
        raise ValueError("Need at least 2 ACF lags to fit an exponential decay.")

    # Only fit up to first zero crossing (or halfway)
    if zero_crossing_idx is not None:
        fit_end = zero_crossing_idx
    else:
        fit_end = len(acf) // 2
    fit_end = max(3, fit_end)  # Need at least 3 points

    # Value used when the fit is unusable. fit_end can equal the number of
    # lags for short ACFs, so clamp the index to the last available lag.
    fallback_tau = lags[min(fit_end, n_lags - 1)]

    # ACF = exp(-t/τ) => log(ACF) = -t/τ
    # Linear fit: y = -x/τ where y = log(ACF), x = t
    acf_positive = np.maximum(acf[:fit_end], 1e-10)  # Avoid log(0)
    log_acf = np.log(acf_positive)

    # Linear regression
    try:
        slope, _ = np.polyfit(lags[:fit_end], log_acf, 1)
    except (np.linalg.LinAlgError, ValueError):
        # Fit failed: fall back to the lag at the end of the fit window.
        tau = fallback_tau
    else:
        # A non-negative slope means no decay was detected in the window.
        tau = -1.0 / slope if slope < 0 else fallback_tau

    return float(max(tau, lags[1]))  # At least one timestep


def _integrate_acf(
    lags: NDArray[np.float64],
    acf: NDArray[np.float64],
    zero_crossing_idx: int | None,
    dt: float,
    n_frames: int | None = None,
    use_finite_size_correction: bool = True,
) -> float:
    """Estimate τ by integrating ACF.

    Parameters
    ----------
    lags : NDArray[np.float64]
        Time lags
    acf : NDArray[np.float64]
        Autocorrelation values
    zero_crossing_idx : int | None
        Index of first zero crossing
    dt : float
        Timestep
    n_frames : int | None
        Total number of frames (for finite-size correction)
    use_finite_size_correction : bool
        If True, apply (1-t/N) weighting per Chodera et al. 2007

    Returns
    -------
    float
        Estimated correlation time τ, never smaller than ``dt``.
    """
    # Integrate up to (and including) the first zero crossing
    if zero_crossing_idx is not None:
        int_end = zero_crossing_idx + 1
    else:
        # Find where ACF drops below threshold
        below_threshold = np.flatnonzero(acf < 0.05)
        int_end = int(below_threshold[0]) + 1 if below_threshold.size else len(acf)

    integrand = acf[:int_end]
    # Apply finite-size correction if requested
    if use_finite_size_correction and n_frames is not None and n_frames > 0:
        # Weight ACF by (1 - t/N) per Chodera et al. 2007
        # This accounts for reduced sample size at longer lags
        integrand = integrand * (1.0 - np.arange(int_end) / n_frames)

    # Trapezoidal integration: `trapezoid` in NumPy 2.0+, `trapz` in older
    # versions. Resolve `trapz` lazily; passing it as the getattr default
    # would evaluate it even when `trapezoid` exists, and fail on a NumPy
    # build that no longer provides `trapz`.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    tau = float(integrate(integrand, lags[:int_end]))

    return float(max(tau, dt))  # At least one timestep


# =============================================================================
# Statistical Inefficiency Functions
# =============================================================================
def statistical_inefficiency(
    timeseries: ArrayLike,
    mintime: int = 3,
    fft: bool = True,
) -> float:
    """Compute statistical inefficiency g directly from a timeseries.

    The statistical inefficiency g is the factor by which the variance of
    the sample mean is increased due to correlation:

        Var(mean) = Var(x) * g / N

    It is accumulated as:

        g = 1 + 2 * Σ C(t) * (1 - t/N)

    where C(t) is the normalized autocorrelation function and (1 - t/N) is
    the finite-size correction factor per Chodera et al. (2007). The sum
    stops at the first lag t >= mintime with C(t) <= 0.

    Parameters
    ----------
    timeseries : array_like
        1D array of values (e.g., contact binary array, RMSD over time)
    mintime : int
        Minimum number of lags to accumulate before a non-positive C(t)
        may terminate the sum. Prevents early termination from noise.
        Default is 3.
    fft : bool
        If True, use FFT-based ACF computation (faster). If False, compute
        each lag directly. Default is True.

    Returns
    -------
    float
        Statistical inefficiency g (>= 1.0). The number of effective
        independent samples is N_eff = N / g. Returns 1.0 for series with
        fewer than 3 points or with variance below 1e-10.

    Examples
    --------
    >>> # Binary contact timeseries
    >>> contacts = np.array([0, 1, 1, 1, 0, 0, 1, 1, ...])
    >>> g = statistical_inefficiency(contacts)
    >>> n_eff = len(contacts) / g
    >>> print(f"Effective samples: {n_eff:.1f}")

    >>> # Continuous observable
    >>> rmsd = np.array([1.2, 1.3, 1.25, 1.4, ...])
    >>> g = statistical_inefficiency(rmsd)

    References
    ----------
    Chodera et al. (2007) J. Chem. Theory Comput. 3:26
    """
    samples = np.asarray(timeseries, dtype=np.float64)
    n_obs = len(samples)

    if n_obs < 3:
        logger.warning(f"Timeseries too short ({n_obs} points). Returning g=1.0")
        return 1.0

    variance = np.var(samples)
    if variance < 1e-10:
        # Constant timeseries - no correlation
        return 1.0

    # Fluctuations about the mean
    fluct = samples - np.mean(samples)

    if fft:
        # Zero-padded FFT autocorrelation, normalized by the number of
        # pairs available at each lag (n, n-1, ..., 1) and by the variance.
        padded_len = 2 ** int(np.ceil(np.log2(2 * n_obs - 1)))
        spectrum = np.fft.fft(fluct, padded_len)
        raw_corr = np.fft.ifft(spectrum * np.conj(spectrum)).real[:n_obs]
        corr = raw_corr / (np.arange(n_obs, 0, -1) * variance)
    else:
        # Direct computation (slower but clearer): mean pair product per lag
        corr = np.array(
            [np.mean(fluct[: n_obs - lag] * fluct[lag:]) / variance for lag in range(n_obs)]
        )

    # Lag 0 contributes the leading 1; positive lags are added until the
    # first non-positive C(t) at or beyond mintime.
    g = 1.0
    lag = 1
    while lag < n_obs and not (lag >= mintime and corr[lag] <= 0):
        finite_size_weight = 1.0 - float(lag) / n_obs
        g += 2.0 * corr[lag] * finite_size_weight
        lag += 1

    # Ensure g >= 1
    return float(max(1.0, g))
def statistical_inefficiency_multiple(
    timeseries_list: list[ArrayLike],
    mintime: int = 3,
) -> float:
    """Compute statistical inefficiency from multiple timeseries of different lengths.

    Intended for aggregating replicates with different frame counts. A
    global mean μ and variance are computed across all timeseries; for each
    lag the pair products and pair counts are pooled over every timeseries
    long enough to contribute before the normalized ACF value is formed.

    Parameters
    ----------
    timeseries_list : list[ArrayLike]
        List of 1D timeseries arrays (can have different lengths)
    mintime : int
        Minimum number of lags before a non-positive pooled C(t) may
        terminate the sum. Default is 3.

    Returns
    -------
    float
        Statistical inefficiency g (>= 1.0). Returns 1.0 for an empty
        list, fewer than 3 total samples, or pooled variance below 1e-10.

    Examples
    --------
    >>> # Three replicates with different lengths
    >>> ts1 = np.array([0, 1, 1, 0, 0, 1])  # 6 frames
    >>> ts2 = np.array([1, 1, 0, 0, 0])  # 5 frames
    >>> ts3 = np.array([0, 0, 1, 1, 1, 0, 1])  # 7 frames
    >>> g = statistical_inefficiency_multiple([ts1, ts2, ts3])

    Notes
    -----
    This implementation follows the algorithm from PyMBAR's
    `statistical_inefficiency_multiple()`, adapted without the PyMBAR
    dependency.

    The algorithm:
    1. Compute global mean μ across all timeseries
    2. For each lag t:
       - Sum the (x - μ) pair products over all timeseries where t < N_k
       - Sum the pair counts over the same timeseries
       - Divide to get C(t)
    3. Sum with finite-size correction, using the mean length of the
       contributing timeseries as N

    References
    ----------
    Chodera et al. (2007) J. Chem. Theory Comput. 3:26
    """
    if not timeseries_list:
        return 1.0

    series = [np.asarray(ts, dtype=np.float64) for ts in timeseries_list]
    lengths = np.array([len(s) for s in series])
    n_total = int(np.sum(lengths))
    longest = int(np.max(lengths))

    if n_total < 3:
        logger.warning(f"Total samples too few ({n_total}). Returning g=1.0")
        return 1.0

    # Pooled mean and variance over every sample of every timeseries
    mu = sum(np.sum(s) for s in series) / n_total
    var = sum(np.sum((s - mu) ** 2) for s in series) / n_total

    if var < 1e-10:
        return 1.0

    deltas = [s - mu for s in series]

    g = 1.0
    for t in range(1, longest):
        # Lengths of the timeseries that still have pairs at this lag
        contributing = lengths[lengths > t]
        if len(contributing) == 0:
            break

        # Pool pair products and pair counts across contributing timeseries
        product_total = 0.0
        pair_total = 0.0
        for delta, n_k in zip(deltas, lengths):
            if t >= n_k:
                continue
            product_total += np.sum(delta[: n_k - t] * delta[t:])
            pair_total += n_k - t

        # Normalized ACF at lag t
        c_t = product_total / (pair_total * var)

        # Stop at the first non-positive value once mintime lags are in
        if t >= mintime and c_t <= 0:
            break

        # Finite-size correction with the mean contributing length as N
        g += 2.0 * c_t * (1.0 - float(t) / np.mean(contributing))

    # Ensure g >= 1
    return float(max(1.0, g))
def n_effective(n_samples: int, g: float) -> float:
    """Return the effective number of independent samples, N_eff = N / g.

    Parameters
    ----------
    n_samples : int
        Total number of samples
    g : float
        Statistical inefficiency

    Returns
    -------
    float
        ``n_samples / g``; when ``g`` is zero or negative the sample
        count itself is returned (as a float) instead of dividing.
    """
    # A non-positive g cannot be divided by meaningfully; treat the samples
    # as uncorrelated in that case.
    return n_samples / g if g > 0 else (float(n_samples) if g <= 0 else n_samples / g)
def check_statistical_reliability(
    n_eff: float,
    threshold: int = MIN_RECOMMENDED_N_INDEPENDENT,
) -> tuple[bool, str | None]:
    """Decide whether an effective sample count is large enough to trust.

    Parameters
    ----------
    n_eff : float
        Number of effective independent samples
    threshold : int
        Minimum recommended independent samples. Default is 10.

    Returns
    -------
    is_reliable : bool
        True if n_eff >= threshold
    warning : str | None
        Warning message if not reliable, None otherwise. The same message
        is also emitted through the module logger.

    Examples
    --------
    >>> g = statistical_inefficiency(contacts)
    >>> n_eff = n_effective(len(contacts), g)
    >>> is_ok, warning = check_statistical_reliability(n_eff)
    >>> if not is_ok:
    ...     print(warning)
    """
    reliable = n_eff >= threshold
    message = None
    if not reliable:
        message = (
            f"Low statistical reliability: only {n_eff:.1f} effective independent samples "
            f"(recommended >= {threshold}). Consider: (1) extending simulation time, "
            f"(2) using more independent replicates, or (3) interpreting results "
            f"with caution. See Grossfield et al. (2018) LiveCoMS 1:5067."
        )
        logger.warning(message)
    return (True, None) if reliable else (False, message)