"""Replicate-level SLURM orchestration for analysis comparisons.
This module provides a shared DAG submission layer for analysis plugins:
- one replicate worker per (condition, replicate)
- one aggregate worker per condition
- one finalizer worker per analysis comparison
The DAG parallelizes at the per-replicate compute-stage boundary, with one
SLURM job per (condition, replicate) pair. This per-replicate worker is the
analysis lifecycle's atomic unit. Sub-replicate parallelism (for example,
per-run work inside SASA-style calculations) is intentionally handled inside
each plugin's compute path. Plugins can use internal threading/multiprocessing
for that finer-grained work when needed.
"""
from __future__ import annotations
import json
import logging
import os
import re
import shutil
import stat
import subprocess
import time
from datetime import datetime, timezone
from hashlib import sha256
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Literal, Sequence, cast
from pydantic import BaseModel, Field, model_validator
from polyzymd.analyses.base import Analysis
from polyzymd.analyses.orchestrator import prepare_comparison_run
from polyzymd.analyses.shared.paths import sanitize_label
from polyzymd.config.comparison import ComparisonConfig
from polyzymd.utils.templates import render_package_template
LOGGER = logging.getLogger(__name__)
_MEM_PATTERN = re.compile(r"^\d+(?:[KMGT]B?)$", flags=re.IGNORECASE)
_TIME_PATTERN = re.compile(r"^(?:\d+-)?\d{1,3}:\d{2}:\d{2}$")
_SLURM_JOB_ID_PATTERN = re.compile(r"^\d+(?:_\d+)?$")
_SBATCH_PARSE_FAILURE_MARKER = "SBATCH_PARSE_FAILURE"
_WORKFLOW_TEMPLATE_PACKAGE = "polyzymd.workflow"
_ANALYSIS_REPLICATE_TEMPLATE = "analysis_replicate_worker.sh.jinja"
_ANALYSIS_AGGREGATE_TEMPLATE = "analysis_aggregate_worker.sh.jinja"
_ANALYSIS_FINALIZE_TEMPLATE = "analysis_finalize_worker.sh.jinja"
_ANALYSIS_ARRAY_TEMPLATE = "analysis_array_worker.sh.jinja"
[docs]
class AnalysisSlurmResources(BaseModel):
    """Resource requests and notification options for analysis SLURM workers.

    All string-valued fields are screened by ``_sanitize_slurm_value`` after
    standard field validation; ``mem`` and ``time`` must additionally match
    the module-level format patterns.
    """

    pixi_path: str = Field(default="pixi")
    partition: str | None = None
    qos: str | None = None
    account: str | None = None
    ntasks: int = Field(default=1, ge=1, le=256)
    cpus_per_task: int = Field(default=1, ge=1, le=256)
    mem: str = "4G"
    time: str = "01:00:00"
    max_retries: int = Field(default=3, ge=1)
    mail_user: str | None = None
    mail_type: str = "FAIL"

    @model_validator(mode="after")
    def _validate_slurm_values(self) -> AnalysisSlurmResources:
        """Screen every string field that ends up inside a generated script.

        Returns
        -------
        AnalysisSlurmResources
            The same instance, once all checks have passed.

        Raises
        ------
        ValueError
            If a field holds characters rejected by ``_sanitize_slurm_value``
            or if ``mem``/``time`` do not match the expected format.
        """
        # The order decides which error surfaces first when several fields
        # are invalid, so it is kept stable.
        screened_fields = (
            "pixi_path",
            "partition",
            "mem",
            "time",
            "mail_type",
            "qos",
            "account",
            "mail_user",
        )
        for field_name in screened_fields:
            current = getattr(self, field_name)
            if current is None:
                # Optional fields that were not provided need no screening.
                continue
            setattr(self, field_name, _sanitize_slurm_value(current, field_name))

        if _MEM_PATTERN.match(self.mem) is None:
            raise ValueError("Invalid mem format. Expected values like '4G', '8000M', or '16GB'.")
        if _TIME_PATTERN.match(self.time) is None:
            raise ValueError("Invalid time format. Expected 'HH:MM:SS' or 'D-HH:MM:SS'.")
        return self
[docs]
class ReplicateTaskSpec(BaseModel):
    """Identify one (condition, replicate) unit of work.

    ``condition_index`` is the position of the condition in the prepared
    condition list; ``condition_slug`` is the label after ``sanitize_label``.
    """

    condition_index: int
    replicate: int
    condition_label: str
    condition_slug: str
[docs]
class ConditionTaskSpec(BaseModel):
    """Identify one condition together with its replicate task specs.

    ``replicate_specs`` holds one ``ReplicateTaskSpec`` per replicate of the
    condition.
    """

    condition_index: int
    condition_label: str
    condition_slug: str
    replicate_specs: list[ReplicateTaskSpec]
[docs]
class AnalysisJobManifest(BaseModel):
    """Snapshot of inputs needed to run analysis workers.

    The manifest is written once at submission time and read back by worker
    processes, so it is persisted as UTF-8 JSON regardless of the locale of
    the process that writes or reads it.
    """

    analysis_name: str
    comparison_yaml: str
    condition_specs: list[ConditionTaskSpec]
    settings_snapshot: dict[str, Any]
    snapshot_hash: str = Field(min_length=1)
    pipeline_mode: Literal["full", "finalize_only"]
    partial_policy: Literal["strict", "allow_partial"]
    equilibration: str
    recompute: bool
    resources: AnalysisSlurmResources
    created_at: str

    def save(self, path: Path) -> Path:
        """Save manifest as UTF-8 encoded JSON.

        Parameters
        ----------
        path : Path
            Destination file; missing parent directories are created.

        Returns
        -------
        Path
            The path that was written.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: the serialized model may contain non-ASCII text
        # (for example condition labels), and the locale default encoding can
        # differ between the submitting process and the workers.
        path.write_text(self.model_dump_json(indent=2), encoding="utf-8")
        return path

    @classmethod
    def load(cls, path: Path) -> AnalysisJobManifest:
        """Load manifest from UTF-8 encoded JSON.

        Parameters
        ----------
        path : Path
            File previously written by :meth:`save`.

        Returns
        -------
        AnalysisJobManifest
            Validated manifest instance.
        """
        return cls.model_validate_json(path.read_text(encoding="utf-8"))
[docs]
class SubmittedJobGraph(BaseModel):
    """SLURM job IDs recorded for each node of a submitted analysis DAG.

    ``replicate_jobs`` is keyed by ``(condition_index, replicate)`` and
    ``aggregator_jobs`` by condition index. Neither key type is a valid JSON
    object key, so :meth:`save` and :meth:`load` translate them to and from
    strings.
    """

    replicate_jobs: dict[tuple[int, int], str]
    array_jobs: dict[str, str] | None = None
    aggregator_jobs: dict[int, str]
    finalizer_job_id: str

    def save(self, path: Path) -> Path:
        """Write the graph as JSON, encoding tuple/int keys as strings.

        Replicate keys become ``"<condition_index>:<replicate>"`` and
        aggregator keys become the decimal string of the condition index.
        """
        replicate_section: dict[str, str] = {}
        for (cond_idx, rep), job_id in self.replicate_jobs.items():
            replicate_section[f"{cond_idx}:{rep}"] = job_id

        aggregator_section: dict[str, str] = {}
        for cond_idx, job_id in self.aggregator_jobs.items():
            aggregator_section[str(cond_idx)] = job_id

        document = {
            "replicate_jobs": replicate_section,
            "array_jobs": self.array_jobs,
            "aggregator_jobs": aggregator_section,
            "finalizer_job_id": self.finalizer_job_id,
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(document, indent=2))
        return path

    @classmethod
    def load(cls, path: Path) -> SubmittedJobGraph:
        """Read a graph written by :meth:`save`, rebuilding tuple/int keys."""
        document = json.loads(path.read_text())

        decoded_replicates = {}
        for key, job_id in document.get("replicate_jobs", {}).items():
            # "3:1" -> (3, 1)
            decoded_replicates[tuple(int(part) for part in key.split(":"))] = job_id

        decoded_aggregators = {}
        for key, job_id in document.get("aggregator_jobs", {}).items():
            decoded_aggregators[int(key)] = job_id

        return cls(
            replicate_jobs=cast(dict[tuple[int, int], str], decoded_replicates),
            array_jobs=document.get("array_jobs"),
            aggregator_jobs=decoded_aggregators,
            finalizer_job_id=document["finalizer_job_id"],
        )
[docs]
class TaskStatus(BaseModel):
    """On-disk status record for a single task.

    ``last_updated`` is stored as a string timestamp; ``slurm_job_id`` is
    ``None`` when no job ID was recorded.
    """

    state: Literal["pending", "running", "succeeded", "failed", "retrying"]
    attempt_count: int = Field(default=0, ge=0)
    error_message: str | None = None
    last_updated: str
    slurm_job_id: str | None = None
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat()
def _sanitize_slurm_value(value: str, field_name: str) -> str:
"""Validate a SLURM value before writing it into scripts.
Parameters
----------
value : str
User-provided value.
field_name : str
Name of the field being validated.
Returns
-------
str
Sanitized value.
Raises
------
ValueError
If value contains unsafe characters or is too long.
"""
if len(value) > 128:
raise ValueError(f"Unsafe SLURM value for '{field_name}': exceeds 128 characters")
if any(ch.isspace() for ch in value):
raise ValueError(f"Unsafe SLURM value for '{field_name}': whitespace is not allowed")
if any(
token in value
for token in (
"\n",
"\r",
";",
"`",
"$(",
"${",
"$",
"\\",
"'",
'"',
"|",
"<",
">",
)
):
raise ValueError(
f"Unsafe SLURM value for '{field_name}': contains disallowed shell characters"
)
return value
def _sanitize_path_for_script(path: Path) -> str:
"""Validate a filesystem path before script interpolation.
Parameters
----------
path : Path
Filesystem path that will be interpolated into generated script content.
Returns
-------
str
Resolved path string when no unsafe tokens are present.
Raises
------
ValueError
If the resolved path includes shell-unsafe or quote-breaking characters.
"""
raw_path_str = str(path)
if any(0 <= ord(char) <= 0x1F or ord(char) == 0x7F for char in raw_path_str):
raise ValueError(
"Unsafe path for generated scripts: found ASCII control character in "
f"{raw_path_str!r}. Please rename the path to remove control characters"
)
path_str = str(path.resolve())
if any(0 <= ord(char) <= 0x1F or ord(char) == 0x7F for char in path_str):
raise ValueError(
"Unsafe path for generated scripts: found ASCII control character in "
f"{path_str!r}. Please rename the path to remove control characters"
)
disallowed_tokens = ("'", '"', "`", "$(", "${", "$", "\\", "|", "<", ">", ";")
for token in disallowed_tokens:
if token in path_str:
display_token = token.encode("unicode_escape").decode("ascii")
raise ValueError(
"Unsafe path for generated scripts: "
f"found disallowed token '{display_token}' in {path_str!r}. "
"Please rename the path to remove this character sequence"
)
return path_str
def _pixi_run_prefix(resources: AnalysisSlurmResources) -> str:
"""Return the pixi command prefix used in generated scripts."""
pixi = resources.pixi_path
if pixi == "pixi":
detected = shutil.which("pixi")
if detected is not None:
pixi = detected
return f'"{pixi}" run -e build'
def _render_analysis_template(template_name: str, context: dict[str, Any]) -> str:
    """Render one worker-script template from the workflow template package.

    Parameters
    ----------
    template_name : str
        Filename of the template inside ``_WORKFLOW_TEMPLATE_PACKAGE``.
    context : dict[str, Any]
        Template variables. No sanitization happens here, so callers are
        expected to pass values that are already safe to embed.

    Returns
    -------
    str
        The rendered script text.
    """
    rendered = render_package_template(
        _WORKFLOW_TEMPLATE_PACKAGE,
        template_name,
        context,
    )
    return rendered
def _condition_specs_from_conditions(conditions: list[Any]) -> list[ConditionTaskSpec]:
"""Build manifest condition task specs from prepared conditions."""
condition_specs: list[ConditionTaskSpec] = []
for cond_idx, condition in enumerate(conditions):
slug = sanitize_label(condition.label)
reps = [
ReplicateTaskSpec(
condition_index=cond_idx,
replicate=rep,
condition_label=condition.label,
condition_slug=slug,
)
for rep in condition.replicates
]
condition_specs.append(
ConditionTaskSpec(
condition_index=cond_idx,
condition_label=condition.label,
condition_slug=slug,
replicate_specs=reps,
)
)
return condition_specs
[docs]
def compute_manifest_snapshot_hash(
    analysis_name: str,
    settings_snapshot: dict[str, Any],
    condition_specs: list[ConditionTaskSpec],
    equilibration: str,
) -> str:
    """Return a SHA-256 hex digest of the manifest-sensitive inputs.

    The inputs are serialized to compact JSON with sorted keys, so the digest
    does not depend on dictionary insertion order.

    Parameters
    ----------
    analysis_name : str
        Name of the analysis.
    settings_snapshot : dict[str, Any]
        JSON-serializable settings.
    condition_specs : list[ConditionTaskSpec]
        Condition specs; each is serialized with ``model_dump(mode="json")``.
    equilibration : str
        Equilibration value.

    Returns
    -------
    str
        Hexadecimal SHA-256 digest.
    """
    serialized_specs = []
    for spec in condition_specs:
        serialized_specs.append(spec.model_dump(mode="json"))

    canonical = json.dumps(
        {
            "analysis_name": analysis_name,
            "settings_snapshot": settings_snapshot,
            "condition_specs": serialized_specs,
            "equilibration": equilibration,
        },
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
[docs]
def validate_manifest_snapshot(
    manifest: AnalysisJobManifest,
    analysis: Analysis,
    config: ComparisonConfig,
) -> tuple[list[Any], str, Path]:
    """Check that the live comparison inputs still hash to the manifest value.

    The comparison run is prepared again from ``analysis`` and ``config`` using
    the manifest's equilibration, and the snapshot hash is recomputed from the
    result.

    Returns
    -------
    tuple[list[Any], str, Path]
        Prepared conditions, resolved equilibration, and analysis root.

    Raises
    ------
    RuntimeError
        If the recomputed hash differs from ``manifest.snapshot_hash``.
    """
    prepared = prepare_comparison_run(analysis, config, manifest.equilibration)
    conditions = prepared["valid_conditions"]
    settings = prepared["settings"]
    equilibration = prepared["equilibration"]
    root = prepared["analysis_root"]

    # Settings objects without ``model_dump`` contribute an empty snapshot.
    if hasattr(settings, "model_dump"):
        settings_snapshot = settings.model_dump(mode="json")
    else:
        settings_snapshot = {}

    current_hash = compute_manifest_snapshot_hash(
        analysis_name=manifest.analysis_name,
        settings_snapshot=settings_snapshot,
        condition_specs=_condition_specs_from_conditions(conditions),
        equilibration=equilibration,
    )
    if current_hash == manifest.snapshot_hash:
        return conditions, equilibration, root

    raise RuntimeError(
        "Manifest/config drift detected. The current comparison.yaml or plugin settings no "
        "longer match the submitted manifest snapshot. Re-submit the analysis jobs."
    )
def _manifest_path(hpc_dir: Path) -> Path:
return hpc_dir / "manifest.json"
def _graph_path(hpc_dir: Path) -> Path:
return hpc_dir / "job_graph.json"
def _submission_error_path(hpc_dir: Path) -> Path:
return hpc_dir / "submission_error.json"
def _scripts_dir(hpc_dir: Path) -> Path:
return hpc_dir / "scripts"
def _logs_dir(hpc_dir: Path) -> Path:
return hpc_dir / "logs"
def _rep_status_path(hpc_dir: Path, condition_slug: str, replicate: int) -> Path:
return hpc_dir / "status" / "replicates" / condition_slug / f"rep_{replicate}.json"
def _cond_status_path(hpc_dir: Path, condition_slug: str) -> Path:
return hpc_dir / "status" / "conditions" / f"{condition_slug}.json"
def _final_status_path(hpc_dir: Path) -> Path:
return hpc_dir / "status" / "finalize.json"
[docs]
def update_task_status(
    status_path: Path,
    state: Literal["pending", "running", "succeeded", "failed", "retrying"],
    attempt_count: int,
    error_message: str | None = None,
) -> None:
    """Atomically write a task status JSON file.

    The payload is written to a temporary file in the destination directory
    and then moved over ``status_path`` with ``os.replace``, so readers never
    see a partially written file. The SLURM job ID is taken from the
    ``SLURM_JOB_ID`` environment variable and is ``None`` when unset.

    Parameters
    ----------
    status_path : Path
        Destination status file; missing parent directories are created.
    state : {"pending", "running", "succeeded", "failed", "retrying"}
        New task state.
    attempt_count : int
        Attempt counter to record.
    error_message : str | None, optional
        Error description to record, if any.

    Raises
    ------
    OSError
        If the temporary file cannot be written or moved into place. The
        temporary file is removed before the error propagates.
    """
    status_path.parent.mkdir(parents=True, exist_ok=True)
    payload = TaskStatus(
        state=state,
        attempt_count=attempt_count,
        error_message=error_message,
        last_updated=_utc_now(),
        slurm_job_id=os.getenv("SLURM_JOB_ID"),
    )
    tmp_path: Path | None = None
    try:
        with NamedTemporaryFile(
            "w", encoding="utf-8", dir=status_path.parent, delete=False
        ) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(payload.model_dump_json(indent=2))
        os.replace(tmp_path, status_path)
    except BaseException:
        # delete=False means nobody else cleans up: without this, a failed
        # write or replace would leave a stray temp file in the status
        # directory.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
def _status_update_python(
    status_path: Path,
    state: str,
    resources: AnalysisSlurmResources,
    error_expr: str = "None",
) -> str:
    """Build a shell command that records ``state`` in a status file.

    The command runs ``python -c`` through the pixi prefix and calls
    ``update_task_status``. ``$ATTEMPT`` is left for the shell to expand, and
    ``error_expr`` is inserted verbatim as a Python expression.

    Raises
    ------
    ValueError
        If ``status_path`` fails ``_sanitize_path_for_script``.
    """
    safe_path = _sanitize_path_for_script(status_path)
    launcher = _pixi_run_prefix(resources)
    python_source = "; ".join(
        (
            "from pathlib import Path",
            "from polyzymd.workflow.analysis_slurm import update_task_status",
            f"update_task_status(Path(r'{safe_path}'), '{state}', int($ATTEMPT), {error_expr})",
        )
    )
    return f'{launcher} python -c "{python_source}"'
def _status_attempt_python(status_path: Path, resources: AnalysisSlurmResources) -> str:
    """Build a shell command that prints the recorded attempt count.

    The command loads the status JSON and prints ``attempt_count``, or ``0``
    when the key is missing.

    Raises
    ------
    ValueError
        If ``status_path`` fails ``_sanitize_path_for_script``.
    """
    safe_path = _sanitize_path_for_script(status_path)
    launcher = _pixi_run_prefix(resources)
    python_source = "; ".join(
        (
            "import json",
            f"d=json.load(open(r'{safe_path}'))",
            "print(d.get('attempt_count', 0))",
        )
    )
    return f'{launcher} python -c "{python_source}"'
def _slurm_header(resources: AnalysisSlurmResources, job_name: str, log_path: Path) -> str:
    """Assemble the shebang and ``#SBATCH`` directive lines for a worker script.

    Optional directives (partition, qos, account, mail) are emitted only when
    the corresponding resource value is truthy. The mail type is written only
    together with a mail user.

    Raises
    ------
    ValueError
        If ``log_path`` fails ``_sanitize_path_for_script``.
    """
    safe_log = _sanitize_path_for_script(log_path)

    directives = ["#!/bin/bash", "#SBATCH --requeue"]
    if resources.partition:
        # The partition goes right after --requeue, ahead of the job name.
        directives.append(f"#SBATCH --partition={resources.partition}")
    directives += [
        f"#SBATCH --job-name={job_name}",
        f'#SBATCH --output="{safe_log}"',
        f"#SBATCH --ntasks={resources.ntasks}",
        f"#SBATCH --cpus-per-task={resources.cpus_per_task}",
        f"#SBATCH --mem={resources.mem}",
        f"#SBATCH --time={resources.time}",
    ]
    for flag, value in (("qos", resources.qos), ("account", resources.account)):
        if value:
            directives.append(f"#SBATCH --{flag}={value}")
    if resources.mail_user:
        directives.extend(
            (
                f"#SBATCH --mail-type={resources.mail_type}",
                f"#SBATCH --mail-user={resources.mail_user}",
            )
        )
    return "\n".join(directives)
def _ensure_layout(hpc_dir: Path, manifest: AnalysisJobManifest) -> None:
    """Create the HPC directory layout and seed missing status files.

    Script and log directories are always created. In ``"full"`` pipeline mode
    the replicate and condition status directories are created as well, and a
    ``pending`` status is written for every replicate and condition that has
    none yet. The finalizer status is seeded in every mode. Existing status
    files are never overwritten.
    """

    def _seed_pending(path: Path) -> None:
        # Only create the record when absent so earlier progress is kept.
        if not path.exists():
            update_task_status(path, "pending", 0)

    for directory in (_scripts_dir(hpc_dir), _logs_dir(hpc_dir)):
        directory.mkdir(parents=True, exist_ok=True)

    if manifest.pipeline_mode == "full":
        status_root = hpc_dir / "status"
        for section in ("replicates", "conditions"):
            (status_root / section).mkdir(parents=True, exist_ok=True)
        for cond_spec in manifest.condition_specs:
            slug = cond_spec.condition_slug
            for rep_spec in cond_spec.replicate_specs:
                _seed_pending(_rep_status_path(hpc_dir, slug, rep_spec.replicate))
            _seed_pending(_cond_status_path(hpc_dir, slug))

    _seed_pending(_final_status_path(hpc_dir))
[docs]
def build_manifest(
    analysis: Analysis,
    config: ComparisonConfig,
    resources: AnalysisSlurmResources,
    recompute: bool,
    equilibration: str | None,
    allow_partial: bool = False,
) -> AnalysisJobManifest:
    """Assemble the submission manifest for one analysis comparison.

    The comparison run is prepared from ``analysis`` and ``config``; its
    conditions, settings, and resolved equilibration are snapshotted and
    hashed. The pipeline mode is ``"full"`` when the analysis has a compute
    stage and ``"finalize_only"`` otherwise.

    Returns
    -------
    AnalysisJobManifest
        Manifest stamped with the current UTC time.
    """
    prepared = prepare_comparison_run(analysis, config, equilibration)
    conditions = prepared["valid_conditions"]
    settings = prepared["settings"]
    resolved_equilibration = prepared["equilibration"]

    condition_specs = _condition_specs_from_conditions(conditions)
    pipeline_mode: Literal["full", "finalize_only"] = (
        "full" if analysis.has_compute_stage else "finalize_only"
    )

    # Settings objects without ``model_dump`` contribute an empty snapshot.
    if hasattr(settings, "model_dump"):
        settings_snapshot = settings.model_dump(mode="json")
    else:
        settings_snapshot = {}

    snapshot_hash = compute_manifest_snapshot_hash(
        analysis_name=analysis.name,
        settings_snapshot=settings_snapshot,
        condition_specs=condition_specs,
        equilibration=resolved_equilibration,
    )

    if config.source_path:
        comparison_yaml = str(Path(config.source_path).resolve())
    else:
        comparison_yaml = ""
    partial_policy = "allow_partial" if allow_partial else "strict"

    return AnalysisJobManifest(
        analysis_name=analysis.name,
        comparison_yaml=comparison_yaml,
        condition_specs=condition_specs,
        settings_snapshot=settings_snapshot,
        snapshot_hash=snapshot_hash,
        pipeline_mode=pipeline_mode,
        partial_policy=partial_policy,
        equilibration=resolved_equilibration,
        recompute=recompute,
        resources=resources,
        created_at=_utc_now(),
    )
[docs]
def generate_replicate_script(
    manifest: AnalysisJobManifest,
    task_spec: ReplicateTaskSpec,
    resources: AnalysisSlurmResources,
    hpc_dir: Path,
) -> Path:
    """Write the worker script for one (condition, replicate) task.

    The script is rendered from the replicate template, given the SLURM
    header, the worker command, the status-update commands, and
    ``resources.max_retries``. It is written under the scripts directory and
    made executable for its owner.

    Returns
    -------
    Path
        Path of the generated script.

    Raises
    ------
    ValueError
        If any path embedded in the script fails sanitization.
    """
    slug = task_spec.condition_slug
    replicate = task_spec.replicate
    stem = f"replicate__{slug}__r{replicate}"

    script_path = _scripts_dir(hpc_dir) / f"{stem}.sh"
    log_path = _logs_dir(hpc_dir) / f"{stem}.%j.out"
    status_path = _rep_status_path(hpc_dir, slug, replicate)

    status_file = _sanitize_path_for_script(status_path)
    manifest_file = _sanitize_path_for_script(_manifest_path(hpc_dir))
    _sanitize_path_for_script(log_path)

    header = _slurm_header(resources, f"pzmd_r_{slug}_{replicate}", log_path)
    worker_cmd = " ".join(
        (
            f"{_pixi_run_prefix(resources)} polyzymd compare worker-replicate",
            f'--manifest "{manifest_file}"',
            f"--condition-index {task_spec.condition_index}",
            f"--replicate {replicate}",
        )
    )
    # Python expression evaluated in the status one-liner; $EXIT_CODE is
    # left for the shell to expand.
    exit_code_expr = "'worker exit code ' + str($EXIT_CODE)"

    def _status_cmd(state: str, error: str = "None") -> str:
        return _status_update_python(status_path, state, resources, error)

    context = {
        "header": header,
        "status_file": status_file,
        "manifest": manifest_file,
        "max_retries": resources.max_retries,
        "status_attempt_cmd": _status_attempt_python(status_path, resources),
        "status_running_cmd": _status_cmd("running"),
        "worker_cmd": worker_cmd,
        "status_retrying_cmd": _status_cmd("retrying", exit_code_expr),
        "status_failed_cmd": _status_cmd("failed", exit_code_expr),
        "status_succeeded_cmd": _status_cmd("succeeded"),
    }
    rendered = _render_analysis_template(_ANALYSIS_REPLICATE_TEMPLATE, context)

    script_path.parent.mkdir(parents=True, exist_ok=True)
    script_path.write_text(rendered)
    current_mode = script_path.stat().st_mode
    script_path.chmod(current_mode | stat.S_IXUSR)
    return script_path
[docs]
def generate_aggregate_script(
    manifest: AnalysisJobManifest,
    cond_spec: ConditionTaskSpec,
    resources: AnalysisSlurmResources,
    hpc_dir: Path,
) -> Path:
    """Write the aggregate worker script for one condition.

    The script is rendered from the aggregate template, given the SLURM
    header, the worker command, the status-update commands, and
    ``resources.max_retries``. It is written under the scripts directory and
    made executable for its owner.

    Returns
    -------
    Path
        Path of the generated script.

    Raises
    ------
    ValueError
        If any path embedded in the script fails sanitization.
    """
    slug = cond_spec.condition_slug
    stem = f"aggregate__{slug}"

    script_path = _scripts_dir(hpc_dir) / f"{stem}.sh"
    log_path = _logs_dir(hpc_dir) / f"{stem}.%j.out"
    status_path = _cond_status_path(hpc_dir, slug)

    status_file = _sanitize_path_for_script(status_path)
    manifest_file = _sanitize_path_for_script(_manifest_path(hpc_dir))
    _sanitize_path_for_script(log_path)

    header = _slurm_header(resources, f"pzmd_a_{slug}", log_path)
    worker_cmd = " ".join(
        (
            f"{_pixi_run_prefix(resources)} polyzymd compare worker-aggregate",
            f'--manifest "{manifest_file}" --condition-index {cond_spec.condition_index}',
        )
    )
    # Python expression evaluated in the status one-liner; $EXIT_CODE is
    # left for the shell to expand.
    exit_code_expr = "'worker exit code ' + str($EXIT_CODE)"

    def _status_cmd(state: str, error: str = "None") -> str:
        return _status_update_python(status_path, state, resources, error)

    context = {
        "header": header,
        "status_file": status_file,
        "max_retries": resources.max_retries,
        "status_attempt_cmd": _status_attempt_python(status_path, resources),
        "status_running_cmd": _status_cmd("running"),
        "worker_cmd": worker_cmd,
        "status_retrying_cmd": _status_cmd("retrying", exit_code_expr),
        "status_failed_cmd": _status_cmd("failed", exit_code_expr),
        "status_succeeded_cmd": _status_cmd("succeeded"),
    }
    rendered = _render_analysis_template(_ANALYSIS_AGGREGATE_TEMPLATE, context)

    script_path.parent.mkdir(parents=True, exist_ok=True)
    script_path.write_text(rendered)
    current_mode = script_path.stat().st_mode
    script_path.chmod(current_mode | stat.S_IXUSR)
    return script_path
def _array_spec(replicates: list[int]) -> str:
"""Build a SLURM array selector specification.
Parameters
----------
replicates : list[int]
Replicate IDs to include in one array job.
Returns
-------
str
Comma-separated array spec such as ``"1,2,3"``.
Raises
------
ValueError
If no replicate IDs are provided.
"""
ordered = sorted(set(replicates))
if not ordered:
raise ValueError("At least one replicate is required for an array submission")
return ",".join(str(rep) for rep in ordered)
[docs]
def generate_array_script(
    cond_spec: ConditionTaskSpec,
    manifest: AnalysisJobManifest,
    resources: AnalysisSlurmResources,
    replicates: list[int],
    hpc_dir: Path,
) -> Path:
    """Write one array worker script covering the replicates of a condition.

    Parameters
    ----------
    cond_spec : ConditionTaskSpec
        Condition the array job belongs to.
    manifest : AnalysisJobManifest
        Submission manifest (not read here; kept in the signature).
    resources : AnalysisSlurmResources
        Resource settings for the header, pixi prefix, and retry limit.
    replicates : list[int]
        Replicate IDs to include; duplicates are collapsed.
    hpc_dir : Path
        Root directory holding scripts, logs, and status files.

    Returns
    -------
    Path
        Path of the generated script, executable for its owner.

    Raises
    ------
    ValueError
        If ``replicates`` is empty or an embedded path fails sanitization.
    """
    slug = cond_spec.condition_slug
    stem = f"replicate_array__{slug}"

    script_path = _scripts_dir(hpc_dir) / f"{stem}.sh"
    log_path = _logs_dir(hpc_dir) / f"{stem}.%a.log"
    manifest_file = _sanitize_path_for_script(_manifest_path(hpc_dir))
    _sanitize_path_for_script(log_path)

    header = _slurm_header(resources, f"pzmd_ra_{slug}", log_path)
    array_spec = _array_spec(replicates)

    # These strings do not depend on the replicate ($REP is filled in by the
    # shell), so they are built once for all branches.
    exit_code_expr = "'worker exit code ' + str($EXIT_CODE)"
    worker_cmd = " ".join(
        (
            f"{_pixi_run_prefix(resources)} polyzymd compare worker-replicate",
            f'--manifest "{manifest_file}"',
            f"--condition-index {cond_spec.condition_index}",
            "--replicate $REP",
        )
    )

    def _branch_for(replicate: int) -> dict[str, str | int]:
        status_path = _rep_status_path(hpc_dir, slug, replicate)

        def _status_cmd(state: str, error: str = "None") -> str:
            return _status_update_python(status_path, state, resources, error)

        return {
            "replicate": replicate,
            "status_attempt_cmd": _status_attempt_python(status_path, resources),
            "status_running_cmd": _status_cmd("running"),
            "worker_cmd": worker_cmd,
            "status_retrying_cmd": _status_cmd("retrying", exit_code_expr),
            "status_failed_cmd": _status_cmd("failed", exit_code_expr),
            "status_succeeded_cmd": _status_cmd("succeeded"),
        }

    case_branches = [_branch_for(replicate) for replicate in sorted(set(replicates))]

    rendered = _render_analysis_template(
        _ANALYSIS_ARRAY_TEMPLATE,
        {
            "header": header,
            "array_spec": array_spec,
            "max_retries": resources.max_retries,
            "case_branches": case_branches,
        },
    )

    script_path.parent.mkdir(parents=True, exist_ok=True)
    script_path.write_text(rendered)
    current_mode = script_path.stat().st_mode
    script_path.chmod(current_mode | stat.S_IXUSR)
    return script_path
[docs]
def generate_finalize_script(
    manifest: AnalysisJobManifest,
    resources: AnalysisSlurmResources,
    hpc_dir: Path,
) -> Path:
    """Write the finalizer worker script for the analysis comparison.

    The script is rendered from the finalize template, given the SLURM header
    (job name derived from ``manifest.analysis_name``), the worker command,
    the status-update commands, and ``resources.max_retries``. It is written
    under the scripts directory and made executable for its owner.

    Returns
    -------
    Path
        Path of the generated script.

    Raises
    ------
    ValueError
        If any path embedded in the script fails sanitization.
    """
    script_path = _scripts_dir(hpc_dir) / "finalize.sh"
    log_path = _logs_dir(hpc_dir) / "finalize.%j.out"
    status_path = _final_status_path(hpc_dir)

    status_file = _sanitize_path_for_script(status_path)
    manifest_file = _sanitize_path_for_script(_manifest_path(hpc_dir))
    _sanitize_path_for_script(log_path)

    header = _slurm_header(resources, f"pzmd_f_{manifest.analysis_name}", log_path)
    worker_cmd = " ".join(
        (
            f"{_pixi_run_prefix(resources)} polyzymd compare worker-finalize",
            f'--manifest "{manifest_file}"',
        )
    )
    # Python expression evaluated in the status one-liner; $EXIT_CODE is
    # left for the shell to expand.
    exit_code_expr = "'worker exit code ' + str($EXIT_CODE)"

    def _status_cmd(state: str, error: str = "None") -> str:
        return _status_update_python(status_path, state, resources, error)

    context = {
        "header": header,
        "status_file": status_file,
        "max_retries": resources.max_retries,
        "status_attempt_cmd": _status_attempt_python(status_path, resources),
        "status_running_cmd": _status_cmd("running"),
        "worker_cmd": worker_cmd,
        "status_retrying_cmd": _status_cmd("retrying", exit_code_expr),
        "status_failed_cmd": _status_cmd("failed", exit_code_expr),
        "status_succeeded_cmd": _status_cmd("succeeded"),
    }
    rendered = _render_analysis_template(_ANALYSIS_FINALIZE_TEMPLATE, context)

    script_path.parent.mkdir(parents=True, exist_ok=True)
    script_path.write_text(rendered)
    current_mode = script_path.stat().st_mode
    script_path.chmod(current_mode | stat.S_IXUSR)
    return script_path
def _submit_sbatch(script_path: Path, dependency: str | None = None) -> str:
cmd = ["sbatch"]
if dependency:
cmd.extend(["--dependency", dependency])
cmd.append(str(script_path))
try:
completed = subprocess.run(cmd, capture_output=True, text=True, check=True)
except FileNotFoundError as exc:
raise RuntimeError(
"SLURM is not available: 'sbatch' not found on PATH. The HPC submission "
"commands require a SLURM cluster. Run analysis locally with "
"'polyzymd compare run' instead."
) from exc
except subprocess.CalledProcessError as exc:
stderr = (exc.stderr or "").strip()
details = stderr if stderr else "No stderr output from sbatch"
raise RuntimeError(f"SLURM submission failed for {script_path}: {details}") from exc
output = completed.stdout.strip()
match = re.search(r"Submitted batch job\s+(\d+)", output)
if match is None:
raise RuntimeError(
f"{_SBATCH_PARSE_FAILURE_MARKER}: Could not parse job id from sbatch output. "
f"Raw stdout: {output!r}"
)
return match.group(1)
def _cancel_jobs(job_ids: list[str]) -> dict[str, dict[str, Any]]:
"""Attempt to cancel submitted SLURM jobs.
Parameters
----------
job_ids : list[str]
Job IDs to cancel.
Returns
-------
dict[str, dict[str, Any]]
Per-job cancellation outcomes keyed by job ID with retry metadata.
Notes
-----
Cancellation is best-effort. Failures are logged and suppressed.
"""
results: dict[str, dict[str, Any]] = {}
if not job_ids:
return results
for job_id in job_ids:
if not _SLURM_JOB_ID_PATTERN.fullmatch(job_id):
LOGGER.warning("Skipping invalid SLURM job id for cancellation: %r", job_id)
results[job_id] = {
"attempted": False,
"cancelled": False,
"attempts": 0,
"error": "invalid_job_id",
}
continue
attempts = 0
cancelled = False
error: str | None = None
for attempt in (1, 2):
attempts = attempt
LOGGER.info("Cancelling SLURM job %s (attempt %d/2)", job_id, attempt)
try:
subprocess.run(["scancel", job_id], capture_output=True, text=True, check=True)
except FileNotFoundError as exc:
error = f"scancel_not_found: {exc}"
LOGGER.warning("Could not cancel submitted jobs: 'scancel' not found (%s)", exc)
break
except subprocess.CalledProcessError as exc:
stderr = (exc.stderr or "").strip()
error = stderr if stderr else "No stderr output from scancel"
LOGGER.warning(
"Failed to cancel submitted job %s on attempt %d/2: %s",
job_id,
attempt,
error,
)
if attempt == 1:
time.sleep(2)
continue
else:
cancelled = True
error = None
break
break
results[job_id] = {
"attempted": True,
"cancelled": cancelled,
"attempts": attempts,
"error": error,
}
return results
def _query_sacct(job_ids: list[str]) -> dict[str, str]:
"""Query sacct for base SLURM job states.
Parameters
----------
job_ids : list[str]
SLURM job IDs to query.
Returns
-------
dict[str, str]
Mapping of base job ID to SLURM state string.
Notes
-----
This parser ignores sub-step rows such as ``12345.batch`` and only keeps
top-level job records.
If ``sacct`` emits duplicate rows for the same top-level job ID, conflict
resolution is deterministic:
- terminal states are preferred over non-terminal states
- among terminal states, the most recent terminal row is preferred
"""
if not job_ids:
return {}
try:
completed = subprocess.run(
[
"sacct",
"-j",
",".join(job_ids),
"--format=JobIDRaw,State",
"--noheader",
"--parsable2",
],
capture_output=True,
text=True,
check=True,
)
except FileNotFoundError as exc:
LOGGER.warning("Could not reconcile status with sacct: command not found (%s)", exc)
return {}
except subprocess.CalledProcessError as exc:
stderr = (exc.stderr or "").strip()
details = stderr if stderr else "No stderr output from sacct"
LOGGER.warning("Could not reconcile status with sacct: %s", details)
return {}
states: dict[str, str] = {}
for raw_line in completed.stdout.splitlines():
line = raw_line.strip()
if not line:
continue
parts = line.split("|")
if len(parts) != 2:
continue
job_id, state = parts[0].strip(), parts[1].strip()
if not job_id or not state:
continue
if "." in job_id:
continue
existing = states.get(job_id)
if existing is None:
states[job_id] = state
continue
existing_terminal = _is_terminal_slurm_state(existing)
current_terminal = _is_terminal_slurm_state(state)
if current_terminal and not existing_terminal:
states[job_id] = state
continue
if current_terminal and existing_terminal:
states[job_id] = state
continue
if not current_terminal and not existing_terminal:
states[job_id] = state
return states
def _is_terminal_slurm_state(slurm_state: str) -> bool:
"""Return whether a SLURM state string is terminal.
Parameters
----------
slurm_state : str
Raw SLURM state string from ``sacct``.
Returns
-------
bool
``True`` when the state is terminal, otherwise ``False``.
"""
normalized = slurm_state.strip().upper().split(maxsplit=1)[0]
if normalized.startswith("CANCELLED"):
return True
return normalized in {
"COMPLETED",
"FAILED",
"TIMEOUT",
"OUT_OF_MEMORY",
"NODE_FAIL",
"BOOT_FAIL",
"DEADLINE",
"REVOKED",
"PREEMPTED",
}
def _map_slurm_state(slurm_state: str) -> str | None:
"""Map SLURM states to local task states.
Parameters
----------
slurm_state : str
Raw SLURM state as returned by ``sacct``.
Returns
-------
str | None
New local state if a terminal transition is needed, otherwise ``None``.
"""
normalized = slurm_state.strip().upper()
normalized = normalized.split(maxsplit=1)[0]
if normalized == "COMPLETED":
return "succeeded"
if normalized in {
"FAILED",
"OUT_OF_MEMORY",
"TIMEOUT",
"NODE_FAIL",
"BOOT_FAIL",
"DEADLINE",
"REVOKED",
}:
return "failed"
if normalized.startswith("CANCELLED") or normalized == "PREEMPTED":
return "failed"
if normalized in {"RUNNING", "PENDING", "SUSPENDED", "REQUEUED"}:
return None
LOGGER.debug("Unknown SLURM state %r during reconciliation", slurm_state)
return None
[docs]
def reconcile_status_with_slurm(hpc_dir: Path) -> dict[str, Any]:
    """Reconcile local status files with live SLURM accounting state.

    Status files under ``status/replicates/<condition>/rep_*.json`` and
    ``status/conditions/*.json``, plus the finalizer status file, are scanned.
    Files whose local state is ``running``, ``pending`` or ``retrying`` and
    that carry a non-empty ``slurm_job_id`` are looked up through
    ``_query_sacct``. When the reported state maps to a terminal local state,
    the file is re-read, re-validated, and rewritten through a temporary file
    in the same directory followed by ``os.replace``.

    Parameters
    ----------
    hpc_dir : Path
        Root HPC artifact directory for one analysis submission.

    Returns
    -------
    dict[str, Any]
        Summary with ``checked`` (number of actionable status files),
        ``updated`` (number of files rewritten) and ``changes`` (one record per
        rewritten file).
    """
    rep_root = hpc_dir / "status" / "replicates"
    cond_root = hpc_dir / "status" / "conditions"
    status_paths: list[Path] = []
    if rep_root.exists():
        for cond_dir in sorted(p for p in rep_root.iterdir() if p.is_dir()):
            status_paths.extend(sorted(cond_dir.glob("rep_*.json")))
    if cond_root.exists():
        status_paths.extend(sorted(cond_root.glob("*.json")))
    finalize_path = _final_status_path(hpc_dir)
    if finalize_path.exists():
        status_paths.append(finalize_path)

    # First pass: collect status files that are still "in flight" locally and
    # have a job ID that can be looked up.
    actionable: list[tuple[Path, dict[str, Any], str]] = []
    for status_path in status_paths:
        try:
            payload = json.loads(status_path.read_text())
            if not isinstance(payload, dict):
                raise ValueError("Status payload must be a JSON object")
            local_state_raw = payload["state"]
            job_id_raw = payload["slurm_job_id"]
            if not isinstance(local_state_raw, str):
                raise ValueError("Status payload field 'state' must be a string")
            if job_id_raw is not None and not isinstance(job_id_raw, str):
                raise ValueError("Status payload field 'slurm_job_id' must be a string or null")
        except (
            json.JSONDecodeError,
            KeyError,
            ValueError,
            FileNotFoundError,
            PermissionError,
        ) as exc:
            LOGGER.warning(
                "Skipping unreadable status file during reconciliation: %s (%s)", status_path, exc
            )
            continue
        local_state = local_state_raw.strip().lower()
        job_id = job_id_raw
        if (
            local_state in {"running", "pending", "retrying"}
            and isinstance(job_id, str)
            and job_id.strip()
        ):
            actionable.append((status_path, payload, job_id.strip()))

    if not actionable:
        return {"checked": 0, "updated": 0, "changes": []}

    # One accounting query for all distinct job IDs.
    unique_job_ids = sorted({job_id for _, _, job_id in actionable})
    sacct_states = _query_sacct(unique_job_ids)

    changes: list[dict[str, str]] = []
    for status_path, payload, job_id in actionable:
        slurm_state = sacct_states.get(job_id)
        if slurm_state is None:
            continue
        mapped_state = _map_slurm_state(slurm_state)
        if mapped_state is None:
            continue
        old_state = str(payload.get("state", "")).strip().lower()
        if mapped_state == old_state:
            continue

        # Re-read just before writing: the file may have been rewritten since
        # the first pass, and a newer state or job ID must not be clobbered.
        try:
            latest_payload = json.loads(status_path.read_text())
            if not isinstance(latest_payload, dict):
                raise ValueError("Status payload must be a JSON object")
            latest_state_raw = latest_payload["state"]
            if not isinstance(latest_state_raw, str):
                raise ValueError("Status payload field 'state' must be a string")
            latest_state = latest_state_raw.strip().lower()
        except (
            json.JSONDecodeError,
            KeyError,
            ValueError,
            FileNotFoundError,
            PermissionError,
        ) as exc:
            LOGGER.warning(
                "Skipping unreadable status file during reconciliation: %s (%s)", status_path, exc
            )
            continue
        if latest_state != old_state:
            LOGGER.debug(
                "Skipping status reconciliation for %s because state changed from %s to %s",
                status_path,
                old_state,
                latest_state,
            )
            continue
        latest_job_id = latest_payload.get("slurm_job_id")
        if latest_job_id != job_id:
            LOGGER.debug(
                "Skipping status reconciliation for %s because slurm_job_id changed from %s to %s",
                status_path,
                job_id,
                latest_job_id,
            )
            continue

        latest_payload["state"] = mapped_state
        latest_payload["reconciled_from"] = slurm_state
        latest_payload["reconciled_at"] = _utc_now()
        if mapped_state == "failed" and (
            slurm_state.upper().startswith("CANCELLED")
            or slurm_state.upper().startswith("PREEMPTED")
        ):
            latest_payload["error_message"] = f"Reconciled from SLURM state {slurm_state}"

        # Bound before the try so the handler never sees an undefined name,
        # even when creating the temporary file itself fails.
        tmp_path: Path | None = None
        try:
            with NamedTemporaryFile(
                "w", encoding="utf-8", dir=status_path.parent, delete=False
            ) as tmp:
                # Record the path before writing so a failed write is cleaned up too.
                tmp_path = Path(tmp.name)
                tmp.write(json.dumps(latest_payload, indent=2))
            os.replace(tmp_path, status_path)
        except OSError as write_exc:
            LOGGER.warning("Failed to write reconciled status for %s: %s", status_path, write_exc)
            if tmp_path is not None:
                try:
                    tmp_path.unlink(missing_ok=True)
                except OSError as cleanup_exc:
                    # A leftover temp file must not stop the remaining files
                    # from being reconciled.
                    LOGGER.debug(
                        "Failed to remove temporary status file %s: %s", tmp_path, cleanup_exc
                    )
            continue

        changes.append(
            {
                "job_id": job_id,
                "path": str(status_path),
                "old_state": old_state,
                "new_state": mapped_state,
                "slurm_state": slurm_state,
            }
        )

    return {
        "checked": len(actionable),
        "updated": len(changes),
        "changes": changes,
    }
[docs]
def submit_analysis_graph(
    manifest: AnalysisJobManifest,
    resources: AnalysisSlurmResources,
    hpc_dir: Path,
    root_dependencies: Sequence[str] = (),
) -> SubmittedJobGraph:
    """Submit replicate, aggregate, and finalizer jobs with dependencies.

    In ``finalize_only`` mode only the finalizer is submitted. Otherwise one
    job per replicate is submitted, then one aggregator per condition with an
    ``afterany`` dependency on that condition's replicate jobs, then the
    finalizer with an ``afterany`` dependency on all aggregators.

    If a submission step fails, the job IDs submitted so far are passed to
    ``_cancel_jobs``, an error sidecar JSON is written on a best-effort basis,
    and the original exception is re-raised.

    Parameters
    ----------
    manifest : AnalysisJobManifest
        Submission manifest describing all condition and replicate tasks.
    resources : AnalysisSlurmResources
        SLURM resource settings used for generated scripts.
    hpc_dir : Path
        Root directory where scripts, logs, and submission metadata are stored.
    root_dependencies : Sequence[str], optional
        Job IDs joined into an ``afterok`` dependency for the finalizer. Only
        consulted when ``manifest.pipeline_mode == "finalize_only"``; in the
        full pipeline, replicate jobs are submitted without a dependency.

    Returns
    -------
    SubmittedJobGraph
        Graph of submitted replicate, aggregate, and finalizer job IDs.

    Raises
    ------
    RuntimeError
        Propagated if any ``sbatch`` submission fails.
    """
    _ensure_layout(hpc_dir, manifest)
    submission_error_path = _submission_error_path(hpc_dir)
    # Drop a stale sidecar from an earlier failed attempt without a
    # check-then-delete race.
    submission_error_path.unlink(missing_ok=True)
    manifest.save(_manifest_path(hpc_dir))

    replicate_jobs: dict[tuple[int, int], str] = {}
    aggregator_jobs: dict[int, str] = {}
    # Every ID returned by sbatch so far, in submission order, for rollback.
    submitted_job_ids: list[str] = []
    try:
        if manifest.pipeline_mode == "finalize_only":
            finalize_script = generate_finalize_script(manifest, resources, hpc_dir)
            dependency = None
            if root_dependencies:
                dependency = "afterok:" + ":".join(root_dependencies)
            finalizer_job_id = _submit_sbatch(finalize_script, dependency=dependency)
            submitted_job_ids.append(finalizer_job_id)
        else:
            for cond_spec in manifest.condition_specs:
                for rep_spec in cond_spec.replicate_specs:
                    script = generate_replicate_script(manifest, rep_spec, resources, hpc_dir)
                    job_id = _submit_sbatch(script)
                    replicate_jobs[(rep_spec.condition_index, rep_spec.replicate)] = job_id
                    submitted_job_ids.append(job_id)
            for cond_spec in manifest.condition_specs:
                replicate_ids = [
                    replicate_jobs[(cond_spec.condition_index, rep_spec.replicate)]
                    for rep_spec in cond_spec.replicate_specs
                ]
                dependency = "afterany:" + ":".join(replicate_ids)
                script = generate_aggregate_script(manifest, cond_spec, resources, hpc_dir)
                job_id = _submit_sbatch(script, dependency=dependency)
                aggregator_jobs[cond_spec.condition_index] = job_id
                submitted_job_ids.append(job_id)
            aggregate_dependency = "afterany:" + ":".join(
                aggregator_jobs[idx] for idx in sorted(aggregator_jobs.keys())
            )
            finalize_script = generate_finalize_script(manifest, resources, hpc_dir)
            finalizer_job_id = _submit_sbatch(finalize_script, dependency=aggregate_dependency)
            submitted_job_ids.append(finalizer_job_id)
    except (RuntimeError, subprocess.SubprocessError, OSError) as exc:
        # Roll back whatever was already queued so no partial DAG keeps running.
        cancel_results: dict[str, dict[str, Any]] = {}
        if submitted_job_ids:
            cancel_results = _cancel_jobs(submitted_job_ids)
            cancelled_count = sum(
                1 for result in cancel_results.values() if result.get("cancelled") is True
            )
            LOGGER.warning(
                "Analysis DAG submission failed and cancellation succeeded for %d/%d tracked jobs",
                cancelled_count,
                len(submitted_job_ids),
            )
        error_text = str(exc)
        raw_sbatch_stdout: str | None = None
        if _SBATCH_PARSE_FAILURE_MARKER in error_text:
            LOGGER.warning(
                "Submission failed while parsing sbatch output; one or more jobs may not have "
                "been tracked for rollback"
            )
            # Keep the raw sbatch output embedded in the message, if present.
            marker = "Raw stdout:"
            if marker in error_text:
                raw_sbatch_stdout = error_text.split(marker, maxsplit=1)[1].strip()
        error_payload = {
            "error": error_text,
            "cancelled_job_ids": submitted_job_ids,
            "cancellation_results": cancel_results,
            "raw_sbatch_stdout": raw_sbatch_stdout,
            "timestamp": _utc_now(),
        }
        try:
            submission_error_path.write_text(json.dumps(error_payload, indent=2))
        except (OSError, TypeError, ValueError) as sidecar_exc:
            # The sidecar is diagnostic only; it must not mask the original error.
            LOGGER.warning(
                "Failed to write submission error sidecar at %s: %s",
                submission_error_path,
                sidecar_exc,
            )
        raise

    graph = SubmittedJobGraph(
        replicate_jobs=replicate_jobs,
        aggregator_jobs=aggregator_jobs,
        finalizer_job_id=finalizer_job_id,
    )
    graph.save(_graph_path(hpc_dir))
    return graph
[docs]
def submit_analysis_graph_with_arrays(
    manifest: AnalysisJobManifest,
    resources: AnalysisSlurmResources,
    hpc_dir: Path,
) -> SubmittedJobGraph:
    """Submit one array per condition plus aggregate and finalizer jobs.

    One array script is submitted per condition, then one aggregator per
    condition with an ``afterany`` dependency on that condition's array job,
    then the finalizer with an ``afterany`` dependency on all aggregators.

    If a submission step fails, the job IDs submitted so far are passed to
    ``_cancel_jobs``, an error sidecar JSON is written on a best-effort basis,
    and the original exception is re-raised.

    Parameters
    ----------
    manifest : AnalysisJobManifest
        Submission manifest describing condition and replicate tasks.
    resources : AnalysisSlurmResources
        SLURM resource settings used for generated scripts.
    hpc_dir : Path
        Root directory where scripts, logs, and submission metadata are stored.

    Returns
    -------
    SubmittedJobGraph
        Graph of submitted array, aggregate, and finalizer job IDs.

    Raises
    ------
    RuntimeError
        Propagated if any ``sbatch`` submission fails.
    """
    _ensure_layout(hpc_dir, manifest)
    submission_error_path = _submission_error_path(hpc_dir)
    # Drop a stale sidecar from an earlier failed attempt without a
    # check-then-delete race.
    submission_error_path.unlink(missing_ok=True)
    manifest.save(_manifest_path(hpc_dir))

    # Array job IDs are keyed by condition slug; aggregators by condition index.
    array_jobs: dict[str, str] = {}
    aggregator_jobs: dict[int, str] = {}
    # Every ID returned by sbatch so far, in submission order, for rollback.
    submitted_job_ids: list[str] = []
    try:
        for cond_spec in manifest.condition_specs:
            replicates = [rep_spec.replicate for rep_spec in cond_spec.replicate_specs]
            script = generate_array_script(cond_spec, manifest, resources, replicates, hpc_dir)
            job_id = _submit_sbatch(script)
            array_jobs[cond_spec.condition_slug] = job_id
            submitted_job_ids.append(job_id)
        for cond_spec in manifest.condition_specs:
            parent_array_job_id = array_jobs[cond_spec.condition_slug]
            dependency = f"afterany:{parent_array_job_id}"
            script = generate_aggregate_script(manifest, cond_spec, resources, hpc_dir)
            job_id = _submit_sbatch(script, dependency=dependency)
            aggregator_jobs[cond_spec.condition_index] = job_id
            submitted_job_ids.append(job_id)
        aggregate_dependency = "afterany:" + ":".join(
            aggregator_jobs[idx] for idx in sorted(aggregator_jobs.keys())
        )
        finalize_script = generate_finalize_script(manifest, resources, hpc_dir)
        finalizer_job_id = _submit_sbatch(finalize_script, dependency=aggregate_dependency)
        submitted_job_ids.append(finalizer_job_id)
    except (RuntimeError, subprocess.SubprocessError, OSError) as exc:
        # Roll back whatever was already queued so no partial DAG keeps running.
        cancel_results: dict[str, dict[str, Any]] = {}
        if submitted_job_ids:
            cancel_results = _cancel_jobs(submitted_job_ids)
            cancelled_count = sum(
                1 for result in cancel_results.values() if result.get("cancelled") is True
            )
            LOGGER.warning(
                "Analysis DAG array submission failed and cancellation succeeded for %d/%d tracked jobs",
                cancelled_count,
                len(submitted_job_ids),
            )
        error_text = str(exc)
        raw_sbatch_stdout: str | None = None
        if _SBATCH_PARSE_FAILURE_MARKER in error_text:
            LOGGER.warning(
                "Submission failed while parsing sbatch output; one or more jobs may not have "
                "been tracked for rollback"
            )
            # Keep the raw sbatch output embedded in the message, if present.
            marker = "Raw stdout:"
            if marker in error_text:
                raw_sbatch_stdout = error_text.split(marker, maxsplit=1)[1].strip()
        error_payload = {
            "error": error_text,
            "cancelled_job_ids": submitted_job_ids,
            "cancellation_results": cancel_results,
            "raw_sbatch_stdout": raw_sbatch_stdout,
            "timestamp": _utc_now(),
        }
        try:
            submission_error_path.write_text(json.dumps(error_payload, indent=2))
        except (OSError, TypeError, ValueError) as sidecar_exc:
            # The sidecar is diagnostic only; it must not mask the original error.
            LOGGER.warning(
                "Failed to write submission error sidecar at %s: %s",
                submission_error_path,
                sidecar_exc,
            )
        raise

    graph = SubmittedJobGraph(
        replicate_jobs={},
        array_jobs=array_jobs,
        aggregator_jobs=aggregator_jobs,
        finalizer_job_id=finalizer_job_id,
    )
    graph.save(_graph_path(hpc_dir))
    return graph
[docs]
def read_analysis_status(hpc_dir: Path) -> dict[str, Any]:
    """Read all status files for one analysis HPC run.

    Parameters
    ----------
    hpc_dir : Path
        Root HPC artifact directory for one analysis submission.

    Returns
    -------
    dict[str, Any]
        Mapping with the keys ``replicates`` (condition directory name ->
        status file stem -> status dict), ``conditions`` (status file stem ->
        status dict), ``finalize`` (status dict, or ``None`` when no finalizer
        status file exists), ``counts`` (number of entries per known state)
        and ``warnings`` (one message per status file that failed to load).
    """
    status_root = hpc_dir / "status"
    rep_root = status_root / "replicates"
    cond_root = status_root / "conditions"
    warnings: list[str] = []

    replicate_summary: dict[str, dict[str, dict[str, Any]]] = {}
    if rep_root.exists():
        for cond_dir in sorted(p for p in rep_root.iterdir() if p.is_dir()):
            replicate_summary[cond_dir.name] = {
                status_file.stem: _read_task_status_or_placeholder(status_file, warnings)
                for status_file in sorted(cond_dir.glob("rep_*.json"))
            }

    condition_summary: dict[str, dict[str, Any]] = {}
    if cond_root.exists():
        for status_file in sorted(cond_root.glob("*.json")):
            condition_summary[status_file.stem] = _read_task_status_or_placeholder(
                status_file, warnings
            )

    finalize = None
    finalize_path = _final_status_path(hpc_dir)
    if finalize_path.exists():
        finalize = _read_task_status_or_placeholder(finalize_path, warnings)

    # Flatten every collected state so the per-state totals cover replicates,
    # conditions, and the finalizer together.
    states: list[str] = []
    for cond_map in replicate_summary.values():
        states.extend(item["state"] for item in cond_map.values())
    states.extend(item["state"] for item in condition_summary.values())
    if finalize is not None:
        states.append(finalize["state"])
    counts = {
        state: states.count(state)
        for state in [
            "pending",
            "running",
            "retrying",
            "succeeded",
            "failed",
            "unknown",
        ]
    }
    return {
        "replicates": replicate_summary,
        "conditions": condition_summary,
        "finalize": finalize,
        "counts": counts,
        "warnings": warnings,
    }


def _read_task_status_or_placeholder(status_file: Path, warnings: list[str]) -> dict[str, Any]:
    """Load one status file, or return an ``unknown`` placeholder on failure.

    Parameters
    ----------
    status_file : Path
        Status JSON file to parse with ``TaskStatus``.
    warnings : list[str]
        Accumulator that receives one message when the file cannot be loaded.

    Returns
    -------
    dict[str, Any]
        The dumped ``TaskStatus`` fields, or a placeholder dict whose
        ``state`` is ``"unknown"`` and whose ``error_message`` is the warning.
    """
    try:
        return TaskStatus.model_validate_json(status_file.read_text()).model_dump()
    except (
        json.JSONDecodeError,
        KeyError,
        ValueError,
        FileNotFoundError,
        PermissionError,
    ) as exc:
        warning = f"Corrupted status file: {status_file} ({type(exc).__name__}: {exc})"
        LOGGER.warning(warning)
        warnings.append(warning)
        # A fresh dict per failure so callers can never share mutable state.
        return {
            "state": "unknown",
            "attempt_count": 0,
            "error_message": warning,
            "last_updated": None,
            "slurm_job_id": None,
        }