# Source code for polyzymd.analyses._framework.aggregate_validation

"""Aggregate result validation helpers for analysis plugins."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from pydantic import BaseModel

from polyzymd.analyses._framework.cache_identity import settings_fingerprint

if TYPE_CHECKING:
    from polyzymd.analyses._framework.contexts import Condition


# Stored ``config_hash`` values that carry no identity information; the config
# hash check is skipped when the aggregate holds one of these. Frozen because
# the set is a shared module-level constant used only for membership tests.
_UNKNOWN_VALUES = frozenset({None, "", "unknown"})


class AggregateValidationError(ValueError):
    """Raised when an aggregated result is stale for the active context."""


def aggregate_settings_fingerprint(settings: BaseModel | None) -> str | None:
    """Return the default settings fingerprint for aggregate validation.

    Parameters
    ----------
    settings : BaseModel or None
        Analysis settings for the active run.

    Returns
    -------
    str or None
        Short deterministic settings fingerprint, or ``None`` when no settings
        model is available.
    """
    if settings is None:
        return None
    return settings_fingerprint(settings)


def validate_aggregated_result(
    analysis: Any,
    result: Any,
    *,
    condition: Condition | None,
    settings: BaseModel | None,
    equilibration: str,
    source: str | Path | None = None,
    expected_replicates: Sequence[int] | None = None,
    allow_replicate_subset: bool = False,
) -> Any:
    """Validate and coerce an aggregated result for the active context.

    Checks run in a fixed order and the first failure is raised: model
    coercion, config hash, equilibration window, settings fingerprint, and
    finally replicate identity.

    Parameters
    ----------
    analysis : Any
        Analysis plugin instance that owns result model configuration. The
        validators read ``name``, ``min_replicates``, the class attribute
        ``AggregatedResultClass`` and call
        ``aggregate_settings_fingerprint(settings)`` on it.
    result : Any
        Loaded or newly computed aggregate result.
    condition : Condition or None
        Condition providing simulation configuration and label context.
    settings : BaseModel or None
        Active analysis settings.
    equilibration : str
        Active equilibration window.
    source : str or Path or None, optional
        File path or description used in diagnostics.
    expected_replicates : sequence of int or None, optional
        Replicate IDs expected in the aggregate. ``None`` skips the
        replicate checks entirely.
    allow_replicate_subset : bool, optional
        Accept a non-empty subset of ``expected_replicates`` when ``True``.

    Returns
    -------
    Any
        The validated result, coerced through ``AggregatedResultClass`` when
        applicable.

    Raises
    ------
    AggregateValidationError
        If the aggregate cannot prove compatibility with the active context.
    """
    coerced = _coerce_aggregate_result(analysis, result, source=source)
    _validate_config_hash(
        coerced,
        condition=condition,
        analysis_name=analysis.name,
        source=source,
    )
    _validate_equilibration(
        coerced,
        equilibration=equilibration,
        analysis_name=analysis.name,
        source=source,
    )
    _validate_settings_fingerprint(
        analysis,
        coerced,
        settings=settings,
        source=source,
    )
    _validate_replicates(
        analysis,
        coerced,
        expected_replicates=expected_replicates,
        allow_replicate_subset=allow_replicate_subset,
        source=source,
    )
    return coerced


def _coerce_aggregate_result(analysis: Any, result: Any, *, source: str | Path | None) -> Any:
    """Coerce dict aggregate payloads through a plugin result model.

    Non-dict results, and plugins whose ``AggregatedResultClass`` is ``None``,
    are returned unchanged. Any failure while building the model is reported
    as an ``AggregateValidationError`` chained to the original exception.
    """
    result_cls = type(analysis).AggregatedResultClass
    if result_cls is None or not isinstance(result, dict):
        return result
    try:
        # Prefer the pydantic-style validator when the model offers one.
        if hasattr(result_cls, "model_validate"):
            return result_cls.model_validate(result)
        return result_cls(**result)
    except Exception as exc:
        # Deliberately broad: whatever the model raises means the payload is
        # unusable, and the error is re-raised with context, never swallowed.
        raise _validation_error(
            analysis.name,
            source,
            f"could not parse aggregate as {result_cls.__name__}: {exc}",
        ) from exc


def _validate_config_hash(
    result: Any,
    *,
    condition: Condition | None,
    analysis_name: str,
    source: str | Path | None,
) -> None:
    """Validate stored config hash when the aggregate provides one.

    The check is skipped when there is no active condition or when the stored
    hash is missing or one of the ``_UNKNOWN_VALUES`` placeholders.
    """
    if condition is None:
        return
    stored_hash = _field_value(result, "config_hash")
    if stored_hash is None:
        return
    # Only strings are tested against the placeholder set: an unhashable
    # stored value (e.g. a list from a corrupt cache) would make a bare ``in``
    # test raise TypeError instead of producing a validation error.
    if isinstance(stored_hash, str) and stored_hash in _UNKNOWN_VALUES:
        return

    from polyzymd.analyses._framework.cache_identity import validate_config_hash

    if not validate_config_hash(str(stored_hash), condition.sim_config):
        raise _validation_error(
            analysis_name,
            source,
            "config hash does not match the active condition config",
        )


def _validate_equilibration(
    result: Any,
    *,
    equilibration: str,
    analysis_name: str,
    source: str | Path | None,
) -> None:
    """Validate stored equilibration metadata when present.

    Aggregates that store neither ``equilibration_time`` nor
    ``equilibration_unit`` are accepted; storing only one of the two is an
    error. Otherwise both the stored and the requested window are converted
    to picoseconds and compared.
    """
    stored_time = _field_value(result, "equilibration_time")
    stored_unit = _field_value(result, "equilibration_unit")
    if stored_time is None and stored_unit is None:
        return
    if stored_time is None or stored_unit is None:
        raise _validation_error(
            analysis_name,
            source,
            "aggregate has incomplete equilibration metadata",
        )

    from polyzymd.analyses.shared.loader import convert_time, parse_time_string

    try:
        expected_value, expected_unit = parse_time_string(equilibration)
        stored_ps = convert_time(float(stored_time), str(stored_unit), "ps")
        expected_ps = convert_time(float(expected_value), str(expected_unit), "ps")
    except (TypeError, ValueError) as exc:
        raise _validation_error(
            analysis_name,
            source,
            f"aggregate has invalid equilibration metadata: {exc}",
        ) from exc

    # Written as ``not (diff <= tol)`` so that a NaN on either side counts as
    # a mismatch; ``diff > tol`` is False for NaN and would let it through.
    if not abs(stored_ps - expected_ps) <= 1.0e-9:
        raise _validation_error(
            analysis_name,
            source,
            f"equilibration mismatch: stored {float(stored_time):g}{stored_unit}, "
            f"requested {equilibration}",
        )


def _validate_settings_fingerprint(
    analysis: Any,
    result: Any,
    *,
    settings: BaseModel | None,
    source: str | Path | None,
) -> None:
    """Validate aggregate settings identity.

    Skipped when the plugin reports no expected fingerprint. Otherwise the
    aggregate must carry a fingerprint whose string form equals the expected
    one.
    """
    expected_fingerprint = analysis.aggregate_settings_fingerprint(settings)
    if expected_fingerprint is None:
        return
    stored_fingerprint = _settings_fingerprint_from_result(result)
    if stored_fingerprint is None:
        raise _validation_error(
            analysis.name,
            source,
            "aggregate is missing settings fingerprint metadata",
        )
    if str(stored_fingerprint) != str(expected_fingerprint):
        raise _validation_error(
            analysis.name,
            source,
            f"settings fingerprint mismatch: stored {stored_fingerprint}, "
            f"current {expected_fingerprint}",
        )


def _validate_replicates(
    analysis: Any,
    result: Any,
    *,
    expected_replicates: Sequence[int] | None,
    allow_replicate_subset: bool,
    source: str | Path | None,
) -> None:
    """Validate aggregate replicate identity and replicate count.

    Explicit replicate IDs take precedence: when present they are checked
    against ``expected_replicates`` and any declared ``n_replicates`` must
    match the number of stored IDs. When only a count is stored it must equal
    the expected count, or, with ``allow_replicate_subset``, lie between
    ``analysis.min_replicates`` and the expected count.
    """
    if expected_replicates is None:
        return

    expected = tuple(sorted(int(rep) for rep in expected_replicates))
    expected_set = set(expected)
    stored = _replicates_from_result(result)
    stored_count = _field_value(result, "n_replicates")
    if stored_count is None:
        stored_count = _metadata_value(result, "n_replicates")

    if stored is None and stored_count is None:
        raise _validation_error(
            analysis.name,
            source,
            "aggregate is missing replicate identity metadata",
        )

    if stored is not None:
        _validate_stored_replicate_ids(
            analysis,
            stored,
            expected=expected,
            expected_set=expected_set,
            allow_replicate_subset=allow_replicate_subset,
            source=source,
        )
        if stored_count is not None:
            _validate_replicate_count(
                analysis,
                stored_count,
                expected_count=len(stored),
                source=source,
                count_context="stored replicate IDs",
            )
        return

    # From here on only a bare count is available.
    if allow_replicate_subset:
        declared_count = _parse_replicate_count(analysis, stored_count, source=source)
        if declared_count < analysis.min_replicates or declared_count > len(expected):
            raise _validation_error(
                analysis.name,
                source,
                f"n_replicates={declared_count} is not an allowed subset count for "
                f"expected replicates {list(expected)}",
            )
        return

    _validate_replicate_count(
        analysis,
        stored_count,
        expected_count=len(expected),
        source=source,
        count_context="expected replicates",
    )


def _validate_stored_replicate_ids(
    analysis: Any,
    stored_replicates: Sequence[int],
    *,
    expected: tuple[int, ...],
    expected_set: set[int],
    allow_replicate_subset: bool,
    source: str | Path | None,
) -> None:
    """Validate stored replicate IDs against expected IDs.

    The stored IDs must be non-empty, free of duplicates, and either equal to
    ``expected`` or, when ``allow_replicate_subset`` is set, a subset of it
    with at least ``analysis.min_replicates`` members.
    """
    stored = tuple(sorted(int(rep) for rep in stored_replicates))
    if not stored:
        raise _validation_error(analysis.name, source, "aggregate has empty replicate identity")
    if len(set(stored)) != len(stored):
        raise _validation_error(analysis.name, source, "aggregate has duplicate replicate IDs")

    is_allowed_subset = allow_replicate_subset and set(stored).issubset(expected_set)
    if stored != expected and not is_allowed_subset:
        raise _validation_error(
            analysis.name,
            source,
            f"replicate mismatch: stored {list(stored)}, expected {list(expected)}",
        )
    if allow_replicate_subset and len(stored) < analysis.min_replicates:
        raise _validation_error(
            analysis.name,
            source,
            f"stored subset has {len(stored)} replicates below minimum {analysis.min_replicates}",
        )


def _parse_replicate_count(
    analysis: Any,
    stored_count: Any,
    *,
    source: str | Path | None,
) -> int:
    """Parse a declared ``n_replicates`` value into an integer.

    Raises
    ------
    AggregateValidationError
        If the value cannot be converted, or is a float that is not a whole
        number (including NaN and infinity).
    """
    try:
        # ``int()`` would silently truncate 2.5 to 2; a fractional replicate
        # count is corrupt metadata, not a smaller count.
        if isinstance(stored_count, float) and not stored_count.is_integer():
            raise ValueError("replicate count is not a whole number")
        return int(stored_count)
    except (TypeError, ValueError, OverflowError) as exc:
        raise _validation_error(
            analysis.name,
            source,
            f"aggregate has invalid n_replicates={stored_count!r}",
        ) from exc


def _validate_replicate_count(
    analysis: Any,
    stored_count: Any,
    *,
    expected_count: int,
    source: str | Path | None,
    count_context: str,
) -> None:
    """Validate a declared aggregate replicate count.

    ``count_context`` names what ``expected_count`` was derived from and is
    only used in the error message.
    """
    declared_count = _parse_replicate_count(analysis, stored_count, source=source)
    if declared_count != expected_count:
        raise _validation_error(
            analysis.name,
            source,
            f"n_replicates={declared_count} does not match {count_context} count {expected_count}",
        )


def _settings_fingerprint_from_result(result: Any) -> str | None:
    """Return the settings fingerprint stored on an aggregate result.

    Top-level fields are consulted before the ``metadata`` mapping; within
    each, ``settings_fingerprint`` is preferred over ``settings_fp``.
    """
    for lookup in (_field_value, _metadata_value):
        for field_name in ("settings_fingerprint", "settings_fp"):
            value = lookup(result, field_name)
            if value is not None:
                return str(value)
    return None


def _replicates_from_result(result: Any) -> tuple[int, ...] | None:
    """Return aggregate replicate IDs when the result declares them.

    Top-level fields are consulted before the ``metadata`` mapping; within
    each, ``replicates`` is preferred over ``replicate_ids``.

    Raises
    ------
    AggregateValidationError
        If a declared value is text or cannot be converted to integers.
    """
    for lookup, label in ((_field_value, ""), (_metadata_value, " metadata")):
        for field_name in ("replicates", "replicate_ids"):
            value = lookup(result, field_name)
            if value is None:
                continue
            try:
                # Text is iterable, so "12" would otherwise be read as the
                # replicate IDs (1, 2) and bytes as raw byte values.
                if isinstance(value, (str, bytes)):
                    raise TypeError("replicate IDs must be a sequence, not text")
                return tuple(int(rep) for rep in value)
            except (TypeError, ValueError, OverflowError) as exc:
                raise AggregateValidationError(
                    f"Aggregate result{label} has invalid replicate identity. "
                    "Recompute the analysis or clear stale analysis cache files."
                ) from exc
    return None


def _field_value(result: Any, field_name: str) -> Any:
    """Return a top-level result field from models or mappings.

    Dicts are read by key and anything else by attribute; a missing field
    yields ``None`` in both cases.
    """
    if isinstance(result, dict):
        return result.get(field_name)
    return getattr(result, field_name, None)


def _metadata_value(result: Any, field_name: str) -> Any:
    """Return a metadata field from models or mappings.

    Yields ``None`` unless the result's ``metadata`` field is a dict that
    contains ``field_name``.
    """
    metadata = _field_value(result, "metadata")
    if isinstance(metadata, dict):
        return metadata.get(field_name)
    return None


def _validation_error(
    analysis_name: str,
    source: str | Path | None,
    detail: str,
) -> AggregateValidationError:
    """Build a user-actionable aggregate validation error.

    The exception is returned, not raised, so call sites can write
    ``raise _validation_error(...) from exc``.
    """
    source_text = f" at {source}" if source is not None else ""
    return AggregateValidationError(
        f"{analysis_name}: stale or incompatible aggregated result{source_text}: {detail}. "
        "Recompute the analysis with --recompute or clear stale analysis cache files."
    )