diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..38a4596
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,29 @@
+name: lint
+
+on:
+  push:
+    branches: [master, main, v2]
+  pull_request:
+    branches: [master, main, v2]
+
+jobs:
+  ruff:
+    name: Lint (ruff)
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v6.0.2
+
+      - name: Set up Python
+        uses: actions/setup-python@v6.2.0
+        with:
+          python-version: "3.12"
+
+      - name: Install ruff
+        run: pip install ruff
+
+      - name: Run ruff check
+        run: ruff check batch/ tests/
+
+      - name: Run ruff format check
+        run: ruff format --check batch/ tests/
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..6c1e3b1
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,41 @@
+name: tests
+
+on:
+  push:
+    branches: [master, main, v2]
+  pull_request:
+    branches: [master, main, v2]
+
+jobs:
+  unit-tests:
+    name: Unit tests (Python ${{ matrix.python-version }})
+    runs-on: ubuntu-latest
+
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v6.0.2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v6.2.0
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install package and dev dependencies
+        run: pip install -e ".[dev]"
+
+      - name: Run unit tests
+        run: pytest tests/ -v --tb=short
+
+      - name: Upload test results on failure
+        if: failure()
+        uses: actions/upload-artifact@v7
+        with:
+          name: test-results-py${{ matrix.python-version }}
+          path: .pytest_cache/
+          # .pytest_cache is a dot-directory; hidden files are excluded from
+          # artifacts unless this is enabled, which would leave the upload empty.
+          include-hidden-files: true
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..4166074
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,95 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+Format follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+ +## [Unreleased] + +### Added +- `Job.post_run(result)`: hook called after successful job completion. + No-op default; override for per-job plotting or data conversion. + Fires concurrently with remaining jobs in ParallelExecutor. + Exceptions in post_run are logged and do not abort the batch. +- `batch.plot.plot_job`: runs plotclaw as a subprocess, capturing all + output (including C-level I/O) to the job log file. Callable setplot + falls back to in-process with a logged warning. + +## [2.0.0] — breaking API change + +Tagged as v2. The v1 API (original `batch.py`, `stampede.py`) is preserved +on the `v1.0.0` tag. + +### Added +- `JobPaths` dataclass: typed, named filesystem layout replacing the `dict` + returned by the old `run()`. +- `JobResult` dataclass: carries `job`, `paths`, `returncode`, and `job_id`. +- `ClobberPolicy` enum: `OVERWRITE` (default), `ERROR`, `SKIP`. + `SKIP` gives free batch resumability — re-run the same script and only + unfinished jobs are submitted. +- `BatchController.experiment` attribute: replaces the `job.type` / `job.name` + two-level grouping with a single experiment subdirectory set once on the + controller. +- `BatchController.setup()`: writes `.data` files without running the solver. + Replaces the `run(only_write_data=True)` flag. +- `Job.build(paths)`: hook for per-job compilation before submission. + No-op default; override for jobs that compile Fortran source. +- `Executor` protocol: defines `submit()` and `wait_all()`. New schedulers + are added by implementing this protocol rather than subclassing + `BatchController`. +- `SerialExecutor`: sequential local runner. +- `ParallelExecutor`: concurrent local runner replacing the hand-rolled process + queue. Fixes the modify-list-while-iterating bug and propagates + `returncode`. +- `SLURMExecutor`: submits via `sbatch --parsable`, captures job ID, + polls `squeue` in `wait_all`. Replaces `StampedeBatchController`. +- `SLURMResources` dataclass: typed SLURM resource request. 
Per-job override + by attaching `job.slurm_resources`. +- `render_slurm_script()`: pure function for SLURM script generation — + independently testable without a cluster. +- `batch.sweep.product_sweep()`, `zip_sweep()`: build job lists from parameter + grids. +- `pyproject.toml` (PEP 517/518, hatchling backend). +- pytest test suite covering all public components without requiring a Clawpack + installation or a running scheduler. + +### Changed +- `Job.write_data_objects()` now accepts an explicit `path: Path` argument and + calls `rundata.write(out_dir=path)`. The `os.chdir` pattern is eliminated. +- `Job.restart` is now a first-class attribute on `Job`, not accessed through + `job.rundata.clawdata.restart` in the controller. +- `BatchController.run()` now defaults to `wait=True` (blocking). The old + default of `wait=False` silently killed background subprocesses when the + calling script exited. +- `max_processes` no longer defaults from `$OMP_NUM_THREADS`. Use + `$BATCH_MAX_JOBS` or pass `max_workers` explicitly to `ParallelExecutor`. +- Flattened directory layout: data files, solver output, and log all share one + directory (`OUTPUT_PATH/experiment/prefix/`). Only plots get a subdirectory. +- `OUTPUT_PATH` is the only environment-variable default for output location. + `DATA_PATH` is no longer used. +- All `subprocess` calls use explicit argument lists (`shell=False`). + +### Removed +- `Job.type`, `Job.name` (replaced by `BatchController.experiment`). +- `Job.output_path`, `Job.data_path`, `Job.log_path` (dead attributes). +- `BatchController.parallel`, `BatchController.terminal_output`, + `BatchController.runclaw_cmd`, `BatchController.plotclaw_cmd`, + `BatchController.max_processes`, `BatchController.poll_interval` — all + moved into the executor or removed. +- `StampedeBatchController` and `StampedeJob` — superseded by `SLURMExecutor` + and `SLURMResources`. +- `from __future__ import` statements. +- Python 2 `super(ClassName, self)` style. 
+ +### Fixed +- Modify-list-while-iterating in the parallel process drain loop caused every + other completed process to be silently skipped. +- Log file handle was never closed in parallel mode. +- `OMP_NUM_THREADS` was incorrectly used as the number of parallel jobs. +- `#SBATCH -t` in `StampedeBatchController` was hardcoded to `9:00:00`, + ignoring `job.time`. +- Missing `\n` in the Stampede MIC environment export line. + +## [1.0.0] + +Original implementation. Tagged for historical reference. +See `batch.py` and `stampede.py` on the `v1.0.0` tag. diff --git a/README.md b/README.md index 545ef40..c626536 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,304 @@ -batch -===== +# batch + +Utilities for running [Clawpack](https://www.clawpack.org) / GeoClaw batch jobs. + +`batch` manages the directory layout, data file generation, and job submission +for parameter sweeps and ensemble simulations. Execution backends are +pluggable: the same job definition runs locally (serial or parallel) or on +a SLURM cluster without changing any application code. + +> **v2 breaking change.** The v1 API (`batch.Job`, `batch.BatchController` +> with scheduler-specific subclasses) is preserved on the `v1.0.0` tag. +> See [CHANGELOG](CHANGELOG.md) for the full migration guide. + +--- + +## Installation + +```bash +pip install -e . +``` + +Requires Python ≥ 3.10. Clawpack must be importable at runtime but is not +listed as a hard dependency (it is assumed to be present in the environment). 
+ +--- + +## Core concepts + +| Class / Function | Role | +|---|---| +| `Job` | Describes one simulation: `prefix`, `executable`, `rundata`, optional `build()` / `post_run()` overrides | +| `BatchController` | Orchestrates directory setup, data writing, and dispatch | +| `Executor` | Protocol implemented by `SerialExecutor`, `ParallelExecutor`, `SLURMExecutor` | +| `JobPaths` | Typed paths for one job's directory, plots, and log | +| `JobResult` | One entry of the list returned by `run()`: job, paths, returncode, scheduler job ID | +| `ClobberPolicy` | Controls what happens when output already exists: `OVERWRITE`, `ERROR`, `SKIP` | +| `plot_job` | Runs plotclaw as a subprocess after job completion, capturing its output in the job log; handles missing visclaw gracefully | + +--- + +## Quick start + +### 1. Define a job + +Subclass `Job`, populate `rundata`, and set `prefix`: + +```python +from pathlib import Path +import importlib.util +from batch import Job + +import clawpack.clawutil.util as clawutil + +class MyGeoClawJob(Job): + def __init__(self, manning: float) -> None: + super().__init__() + self.prefix = f"n{manning:.3f}" + self.executable = "xgeoclaw" + self.manning = manning + + # Load base configuration from a local setrun.py - requires clawpack + setrun = clawutil.fullpath_import(setrun_path) + self.rundata = setrun.setrun() + + # Apply parameter override + self.rundata.geo_data.manning_coefficient = manning +``` + +### 2. Run locally + +```python +from batch import BatchController, ParallelExecutor + +jobs = [MyGeoClawJob(manning=n) for n in [0.020, 0.025, 0.030]] + +ctrl = BatchController( + jobs=jobs, + executor=ParallelExecutor(max_workers=3), + experiment="manning_sensitivity", +) +results = ctrl.run() + +for r in results: + status = "ok" if r.success else f"FAILED (rc={r.returncode})" + print(f" {r.job.prefix} {status} -> {r.paths.job}") +``` + +Output layout: + +``` +OUTPUT_PATH/ + manning_sensitivity/ + n0.020/ + n0.020_log.txt + *.data + fort.* + plots/ + n0.025/ + ... +``` + +### 3. 
Run on SLURM + +```python +from batch import BatchController, SLURMExecutor, SLURMResources + +executor = SLURMExecutor( + default_resources=SLURMResources( + partition="main", + nodes=1, + cpus_per_task=8, + time="06:00:00", + account="MY_ALLOCATION", + env_vars={"OMP_NUM_THREADS": "8"}, + modules=["ncarenv/23.09", "python/3.11.4"], + ), +) + +ctrl = BatchController( + jobs=jobs, + executor=executor, + experiment="manning_sensitivity", +) +# Returns immediately after sbatch submission; job IDs in results +results = ctrl.run(wait=False) +for r in results: + print(f" {r.job.prefix} -> SLURM job {r.job_id}") +``` + +Per-job resource overrides are supported without subclassing — attach a +`SLURMResources` instance directly to the job: + +```python +job.slurm_resources = SLURMResources(partition="gpu", time="12:00:00") +``` + +### 4. Inspect scripts without submitting + +```python +executor = SLURMExecutor(default_resources=resources, dry_run=True) +ctrl = BatchController(jobs=jobs, executor=executor, experiment="test") +ctrl.run(wait=False) +# Scripts written to each job directory; sbatch not called +``` + +--- + +## Parameter sweeps + +### Cartesian product + +```python +from batch.sweep import product_sweep + +jobs = product_sweep( + factory=lambda manning, level: MyGeoClawJob(manning, max_level=level), + namer=lambda p: f"n{p['manning']:.3f}_l{p['level']}", + manning=[0.020, 0.025, 0.030], + level=[4, 5], +) +# 6 jobs: 3 Manning values x 2 refinement levels +``` + +### Paired sweep + +```python +from batch.sweep import zip_sweep + +jobs = zip_sweep( + factory=lambda storm_id, intensity: StormJob(storm_id, intensity), + namer=lambda p: f"{p['storm_id']}_{p['intensity']}", + storm_id=["katrina", "ike", "harvey"], + intensity=["low", "mid", "high"], +) +# 3 jobs: one per (storm, intensity) pair +``` + +--- + +## Resuming a partial batch + +Use `ClobberPolicy.SKIP` to skip jobs whose output directory already exists. 
+Re-run the same script after a walltime kill and only unfinished jobs are +submitted: + +```python +from batch import ClobberPolicy + +ctrl = BatchController( + jobs=jobs, + executor=executor, + experiment="my_ensemble", + clobber=ClobberPolicy.SKIP, +) +``` + +--- + +## Per-job compilation + +Override `build()` when a job requires compiling the executable before +submission. The no-op default is used when all jobs share a pre-built binary. + +```python +import shutil +import subprocess +from batch import Job, JobPaths + +class CompiledJob(Job): + def __init__(self, source_path): + super().__init__() + self.source_path = source_path + + def build(self, paths: JobPaths) -> None: + subprocess.run(["make", ".exe"], cwd=self.source_path, check=True) + shutil.move(self.source_path / self.executable, paths.job) + self.executable = paths.job / self.executable +``` + +The controller calls `job.build(paths)` after writing data files and before +calling `executor.submit()`. For SLURM this means compilation happens on +the login node, which is the correct behavior. + +--- + +## Per-job postprocessing + +Override `post_run(result)` to run plotting, data conversion, or any other +work immediately after a job completes successfully. The default is a no-op. +`post_run` receives a `JobResult` giving access to `result.paths` and +`result.returncode`. For `ParallelExecutor` it fires as each job is +harvested in `_drain`, so postprocessing for a finished job runs concurrently +with jobs still in flight. For `SLURMExecutor` it fires as each job leaves +the queue in `wait_all`. Exceptions raised inside `post_run` are logged and +swallowed — a failing postprocessing step never aborts the batch loop. + +Use `plot_job` for the common case of running plotclaw after a job completes. 
+It runs plotclaw as a subprocess so all output — including C-level output from +matplotlib — is captured to the job's log file rather than the terminal; a +`--- plotclaw ---` separator is written to the log between solver and plotting +output. It resolves relative setplot paths against the job directory and +returns `False` gracefully when visclaw is not installed rather than raising. + +```python +from pathlib import Path +from batch import Job +from batch import plot_job + +class MyJob(Job): + def post_run(self, result) -> None: + plot_job(result, setplot=Path(__file__).parent / "setplot.py") +``` + +For cross-run analysis, use the `results` list returned by `ctrl.run()`. +Each element is a `JobResult` with `.success`, `.paths`, and `.job`; filter to +`r.success` and iterate to load output files, compute statistics, or produce +comparison plots spanning the full ensemble. This is the right place for +anything that needs data from more than one job at once. + +```python +results = ctrl.run(wait=True) +successful = [r for r in results if r.success] +# load fort.gauge, compute metrics, write ensemble_comparison.png … +``` + +--- + +## Environment variables + +| Variable | Effect | +|---|---| +| `OUTPUT_PATH` | Base directory for all job output (default: cwd) | +| `BATCH_MAX_JOBS` | Default `max_workers` for `ParallelExecutor` (default: 4) | +| `OMP_NUM_THREADS` | Number of OpenMP threads each job may use (default: the value already set in the environment, otherwise 1) | + +**Note:** The product `BATCH_MAX_JOBS` × `OMP_NUM_THREADS` should not exceed the physical core count, or you may run into swapping/contention problems. For example, on a 16-core machine one could set `BATCH_MAX_JOBS=2` and `OMP_NUM_THREADS=8`. For the `SLURMExecutor` or any other HPC environment this will not be an issue and the maximum number of cores available should be used. + +--- +## Examples + +- [`examples/local_ensemble/`](examples/local_ensemble/) — Manning's n + sensitivity sweep run locally with `ParallelExecutor`. 
+- [`examples/storm_surge/`](examples/storm_surge/) — 100-member storm + ensemble submitted to SLURM. + +--- + +## Running the tests + +```bash +pytest tests/ -v +``` + +The test suite has no dependency on an installed Clawpack or a running +scheduler. All executor and scheduler behavior is tested via mocks. + +Integration tests that exercise the actual solver are marked +`@pytest.mark.integration` and are skipped by default. + +--- + +## License + +MIT — see [LICENSE](LICENSE). diff --git a/batch/__init__.py b/batch/__init__.py index 53d37d7..49e4d8c 100644 --- a/batch/__init__.py +++ b/batch/__init__.py @@ -1,5 +1,32 @@ -from __future__ import print_function -from __future__ import absolute_import +"""batch — utilities for running Clawpack/GeoClaw batch jobs. -from .batch import Job -from .batch import BatchController \ No newline at end of file +Public API +---------- +The most commonly used names are re-exported here for convenience:: + + from batch import Job, BatchController, ClobberPolicy + from batch import SerialExecutor, ParallelExecutor + from batch import SLURMExecutor, SLURMResources + from batch.sweep import product_sweep, zip_sweep +""" + +from batch.controller import BatchController +from batch.executors.local import ParallelExecutor, SerialExecutor +from batch.executors.slurm import SLURMExecutor, SLURMResources +from batch.job import ClobberPolicy, Job, JobPaths, JobResult +from batch.plot import plot_job + +__version__ = "2.0.0" + +__all__ = [ + "Job", + "JobPaths", + "JobResult", + "ClobberPolicy", + "BatchController", + "SerialExecutor", + "ParallelExecutor", + "SLURMExecutor", + "SLURMResources", + "plot_job", +] diff --git a/batch/analysis.py b/batch/analysis.py new file mode 100644 index 0000000..428ca82 --- /dev/null +++ b/batch/analysis.py @@ -0,0 +1,3 @@ +"""Analysis Tools for Batch""" + +raise NotImplementedError("This module is not yet implemented.") diff --git a/batch/batch.py b/batch/batch.py deleted file mode 100644 index 
26f8a49..0000000 --- a/batch/batch.py +++ /dev/null @@ -1,386 +0,0 @@ -r"""Simple controller for runs with GeoClaw - -Includes support for multiple runs at the same time - -""" -# ============================================================================ -# Copyright (C) 2013 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# ============================================================================ - -from __future__ import print_function -from __future__ import absolute_import - -import subprocess -import os -import time -import glob - - -class Job(object): - r"""Base object for all jobs - - The ``type``, ``name``, and ``prefix`` attributes are used by the - :class:`BatchController` to create the path to the output files along with - the name of the output directories and log file. The pattern is - ``base_path/type/name/prefix*``. See :class:`BatchController` for more - information on how these are created. - - .. attribute:: type - - (string) - The top most directory that the batch output will be - located in. ``default = ""``. - - .. attribute:: name - - (string) - The second top most directory that the batch output will - be located in. ``default = ""``. - - .. attribute:: prefix - - (string) - The prefix applied to the data directory, the output - directory, and the log file. ``default = None``. - - .. attribute:: executable - - (string) - Name of the binary executable. ``default = "xclaw"``. - - .. attribute:: setplot - - (string) - Name of the module containing the `setplot` - function. ``default = "setplot"``. - - .. attribute:: rundata - - (clawpack.clawutil.data.ClawRunData) - The data object containing all - data objects. By default all data objects inside of this object will - be written out. This attribute must be instantiated by any subclass - and if not will raise a ValueError exception when asked to write out. - - :Initialization: - - Output: - - (:class:`Job`) - Initialized Job object. 
- """ - - def __init__(self): - r""" - Initialize a Job object - - See :class:`Job` for full documentation - """ - - super(Job, self).__init__() - - # Base job traits - self.type = "" - self.name = "" - self.setplot = "setplot" - self.prefix = None - self.executable = 'xclaw' - - self.rundata = None - - def __str__(self): - output = "Job %s: %s\n" % (self.name, self.prefix) - output += " Setplot: %s\n" % self.setplot - return output - - def write_data_objects(self): - r""" - Write out data objects contained in *rundata* - - Raises ValueError if *rundata* has not been set. - - """ - - if self.rundata is None: - raise ValueError("Must set rundata to a ClawRunData object.") - self.rundata.write() - - -class BatchController(object): - r"""Controller for Clawpack batch runs. - - Controller object that will run the set of jobs provided with the - parameters set in the object including plotting, path creation, and - simple process parallelism. - - .. attribute:: jobs - - (list) - List of :class:`Job` objects that will be run. - - .. attribute:: plot - - (bool) - If True each job will be plotted after it has run. - ``default = True`` - - .. attribute:: tar - - (bool) - If True will tar and gzip the plots directory. - ``default = False`` - - .. attribute:: verbose - - (bool) - If True will print to stdout the remaining jobs - waiting to be run and how many are currently in the process queue. - ``default = False``. - - .. attribute:: terminal_output - - (bool) - If ``paralllel`` is False, this controls where - the output is sent. If True then it will simply use stdout, if False - then the usual log file will be used. ``default = False``. - - .. attribute:: base_path - - (path) - The base path to put all output. If the - environment variable ``DATA_PATH`` is set than the ``base_path`` will - be set to that. Otherwise the current working directory (returned by - ``os.getcwd()``) will be used. - - .. attribute:: parallel - - (bool) - If True, jobs will be run in parallel. 
This means - that jobs will be run concurrently with other jobs up to a maximum at - one time of ``max_processes``. Once a process completes a new one is - started. ``default = True``. - - .. attribute:: wait - - (bool) - If True, the method waits to return until the last job - has completed. If False then the method returns immediately once the - last job has been added to the process queue. Default is `False`. - - .. attribute:: poll_interval - - (float) - Interval to poll for the status of each - process. Default is `5.0` seconds. - - .. attribute:: max_processes - - (int) - The maximum number of processes that can be run - at one time. If the environment variable `OMP_NUM_THREADS` is set then - this defaults to that number. Otherwise `4` is used. - - .. attribute:: runclaw_cmd - - (string) - The string that stores the base command for the - run command. - ``default = "python $CLAW/clawutil/src/python/clawutil/runclaw.py"``. - - .. attribute:: plotclaw_cmd - - (string) - The string that stores the base command for the - plotting command. - ``default = "python $CLAW/visclaw/src/visclaw/plotclaw.py"``. - - :Initialization: - - Input: - - *jobs* - (list) List of :class:`Job` objects to be run. - - Output: - - (:class:`BatchController`) Initialized BatchController object - - """ - - def __init__(self, jobs=[]): - r"""Initialize a BatchController object. 
- - See :class:`BatchController` for full documentation - - """ - - super(BatchController, self).__init__() - - # Establish controller default parameters - # Execution controls - self.plot = True - self.tar = False - self.verbose = False - self.terminal_output = False - - # Path controls - if 'DATA_PATH' in os.environ.keys(): - self.base_path = os.environ['DATA_PATH'] - else: - self.base_path = os.getcwd() - self.base_path = os.path.expanduser(self.base_path) - - # Parallel run controls - self.parallel = True - self.wait = False - self.poll_interval = 5.0 - if 'OMP_NUM_THREADS' in os.environ.keys(): - self.max_processes = int(os.environ['OMP_NUM_THREADS']) - else: - self.max_processes = 4 - self._process_queue = [] - - # Default commands for running and plotting - self.runclaw_cmd = "python $CLAW/clawutil/src/python/clawutil/runclaw.py" - self.plotclaw_cmd = "python $CLAW/visclaw/src/python/visclaw/plotclaw.py" - - # Add the initial jobs to the jobs list - if not isinstance(jobs, list) and not isinstance(jobs, tuple): - raise ValueError("Jobs must be a list or tuple.") - self.jobs = [] - for job in jobs: - if isinstance(job, Job): - self.jobs.append(job) - else: - raise ValueError("Elements of jobs must be a Job.") - - def __str__(self): - output = "" - for (i, job) in enumerate(self.jobs): - output += "====== Job #%s ============================\n" % (i) - output += str(job) + "\n" - return output - - def run(self, only_write_data=False): - r"""Run jobs from controller's *jobs* list. - - For each :class:`Job` object in *jobs* create a set of paths, directory - structures, and log files in preperation for running the commands - constructed. If *parallel* is True then jobs are started and added - to a queue with a maximum of *maximum_processes*. If *parallel* is - False each job is run to completion before continuing. The *wait* - parameter controls whether the function waits for the last job to run - before returning. 
- - Output: - - *paths* - (list) List of dictionaries containing paths to the data - constructed for each job. The dictionary has keys 'job', 'data', - 'output', 'plots', and 'log' which respectively stores the base - directory of the job, the data, output, and plot directories, and - the log file. - - """ - - # Run jobs - paths = [] - for (i, job) in enumerate(self.jobs): - # Create output directory - data_dirname = ''.join((job.prefix, '_data')) - output_dirname = ''.join((job.prefix, "_output")) - plots_dirname = ''.join((job.prefix, "_plots")) - log_name = ''.join((job.prefix, "_log.txt")) - - if len(job.type) > 0: - job_path = os.path.join(self.base_path, job.type, job.name) - else: - job_path = os.path.join(self.base_path, job.name) - job_path = os.path.abspath(job_path) - data_path = os.path.join(job_path, data_dirname) - output_path = os.path.join(job_path, output_dirname) - plots_path = os.path.join(job_path, plots_dirname) - log_path = os.path.join(job_path, log_name) - paths.append({'job': job_path, 'data': data_path, - 'output': output_path, 'plots': plots_path, - 'log': log_path}) - - # Create job directory if not present - if not os.path.exists(job_path): - os.makedirs(job_path) - - # Clobber old data directory - if os.path.exists(data_path): - if not job.rundata.clawdata.restart: - data_files = glob.glob(os.path.join(data_path, "*.data")) - for data_file in data_files: - os.remove(data_file) - else: - os.mkdir(data_path) - - # Open and start log file - log_file = open(log_path, 'w') - tm = time.localtime() - year = str(tm[0]).zfill(4) - month = str(tm[1]).zfill(2) - day = str(tm[2]).zfill(2) - hour = str(tm[3]).zfill(2) - minute = str(tm[4]).zfill(2) - second = str(tm[5]).zfill(2) - date = 'Started %s/%s/%s-%s:%s.%s' % (year, month, day, hour, - minute, second) - log_file.write(date) - - # Write out data - temp_path = os.getcwd() - os.chdir(data_path) - job.write_data_objects() - os.chdir(temp_path) - - if only_write_data: - continue - - # Handle 
restart requests - if job.rundata.clawdata.restart: - restart = "T" - overwrite = "F" - else: - restart = "F" - overwrite = "T" - - # Construct string commands - run_cmd = "%s %s %s %s %s %s True" % (self.runclaw_cmd, - job.executable, - output_path, - overwrite, - restart, - data_path) - if self.plot: - plot_cmd = "%s %s %s %s" % (self.plotclaw_cmd, - output_path, - plots_path, - job.setplot) - tar_cmd = "tar -cvzf %s.tgz -C %s/.. %s" % (plots_path, - plots_path, - os.path.basename( - plots_path)) - cmd = run_cmd - if self.plot: - cmd = ";".join((cmd, plot_cmd)) - if self.tar: - cmd = ";".join((cmd, tar_cmd)) - - # Run jobs - if self.parallel: - while len(self._process_queue) == self.max_processes: - if self.verbose: - print("Number of processes currently:", - len(self._process_queue)) - for process in self._process_queue: - if process.poll() is not None: - self._process_queue.remove(process) - time.sleep(self.poll_interval) - self._process_queue.append(subprocess.Popen(cmd, shell=True, - stdout=log_file, - stderr=log_file)) - - else: - if self.terminal_output: - log_file.write("Outputting to terminal...") - subprocess.Popen(cmd, shell=True).wait() - log_file.write("Command completed.") - else: - subprocess.Popen(cmd, shell=True, stdout=log_file, - stderr=log_file).wait() - - # -- All jobs have been started -- - - # Wait to exit while processes are still going - if self.wait: - while len(self._process_queue) > 0: - time.sleep(self.poll_interval) - for process in self._process_queue: - if process.poll() is not None: - self._process_queue.remove(process) - print("Number of processes currently: %s" % - len(self._process_queue)) - - return paths
diff --git a/batch/controller.py b/batch/controller.py
new file mode 100644
index 0000000..9cf809a
--- /dev/null
+++ b/batch/controller.py
@@ -0,0 +1,271 @@
+"""BatchController: orchestrates path setup, data writing, and job dispatch."""
+
+from __future__ import annotations
+
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+
+from batch.executors import Executor
+from batch.job import ClobberPolicy, Job, JobPaths, JobResult
+
+logger = logging.getLogger(__name__)
+
+
+class BatchController:
+    """Orchestrate setup and submission for a list of jobs.
+
+    The controller is responsible for:
+
+    1. Computing the canonical directory layout for each job.
+    2. Creating directories and applying the clobber policy.
+    3. Writing a log-file header.
+    4. Calling ``job.write_data_objects()`` to produce ``.data`` files.
+    5. Calling ``job.build()`` to compile the executable (no-op by default).
+    6. Dispatching to the executor.
+
+    Step 1 is carried out for *every* job before any directory is touched or
+    any job is submitted, so a job without a prefix (or, under
+    ``ClobberPolicy.ERROR``, an already-existing job directory) aborts the
+    batch before earlier jobs have been started.
+
+    None of this is scheduler-specific; all scheduler differences live in the
+    :mod:`batch.executors` implementations.
+
+    Parameters
+    ----------
+    jobs:
+        Jobs to run. May also be added later by appending to ``self.jobs``.
+    executor:
+        Backend that actually runs or queues jobs. Defaults to
+        ``ParallelExecutor()`` with that executor's own default settings.
+    base_path:
+        Root output directory. Falls back to the ``OUTPUT_PATH`` environment
+        variable, then the current working directory.
+    experiment:
+        Subdirectory under *base_path* grouping all jobs in this batch.
+        Typically the name of the experiment or storm (e.g. ``"hurricane_ike"``).
+        Leave empty to write directly under *base_path*.
+    clobber:
+        Policy for pre-existing job directories. See :class:`ClobberPolicy`.
+    """
+
+    def __init__(
+        self,
+        jobs: list[Job] | None = None,
+        executor: Executor | None = None,
+        base_path: Path | str | None = None,
+        experiment: str = "",
+        clobber: ClobberPolicy = ClobberPolicy.OVERWRITE,
+    ) -> None:
+        # Copy so that appending to self.jobs never mutates the caller's list.
+        self.jobs: list[Job] = list(jobs) if jobs else []
+
+        if executor is None:
+            # Imported here, only when the default backend is actually needed.
+            from batch.executors.local import ParallelExecutor
+
+            executor = ParallelExecutor()
+        self.executor: Executor = executor
+
+        if base_path is None:
+            base_path = os.environ.get("OUTPUT_PATH", os.getcwd())
+        self.base_path = Path(base_path).expanduser().resolve()
+        self.experiment = experiment
+        self.clobber = clobber
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _root(self) -> Path:
+        """Base path with optional experiment subdirectory."""
+        if self.experiment:
+            return self.base_path / self.experiment
+        return self.base_path
+
+    def _make_paths(self, job: Job) -> JobPaths:
+        """Compute the canonical directory layout for *job*.
+
+        Raises
+        ------
+        ValueError
+            If ``job.prefix`` is empty or unset.
+        """
+        if not job.prefix:
+            raise ValueError(
+                f"Job {job!r} has no prefix set. "
+                "Assign job.prefix before adding it to the controller."
+            )
+        job_dir = self._root() / job.prefix
+        return JobPaths(
+            job=job_dir,
+            plots=job_dir / "plots",
+            log=job_dir / f"{job.prefix}_log.txt",
+        )
+
+    @staticmethod
+    def _exists_error(paths: JobPaths) -> FileExistsError:
+        """Build the error raised for an existing directory under ``ERROR``."""
+        return FileExistsError(
+            f"Job directory already exists: {paths.job}\n"
+            "Use ClobberPolicy.OVERWRITE to allow re-running or "
+            "ClobberPolicy.SKIP to resume a partial batch."
+        )
+
+    def _plan(self) -> list[tuple[Job, JobPaths]]:
+        """Resolve paths for every job and validate the batch up front.
+
+        Nothing on disk is modified here, so a failure leaves no partially
+        set-up or already-submitted jobs behind.
+
+        Raises
+        ------
+        ValueError
+            If any job has no prefix.
+        FileExistsError
+            Under ``ClobberPolicy.ERROR``, if any job directory already exists.
+        """
+        plan = [(job, self._make_paths(job)) for job in self.jobs]
+        if self.clobber is ClobberPolicy.ERROR:
+            for _, paths in plan:
+                if paths.job.exists():
+                    raise self._exists_error(paths)
+        return plan
+
+    def _setup_job_dir(self, job: Job, paths: JobPaths) -> bool:
+        """Create the job directory, applying the clobber policy.
+
+        Returns
+        -------
+        bool
+            True if the job should proceed; False if it should be skipped
+            (``ClobberPolicy.SKIP`` with an existing directory).
+
+        Raises
+        ------
+        FileExistsError
+            When ``ClobberPolicy.ERROR`` and the directory already exists.
+        """
+        if paths.job.exists():
+            if self.clobber is ClobberPolicy.ERROR:
+                # Normally caught earlier by _plan(); kept for a directory that
+                # appears between planning and setup.
+                raise self._exists_error(paths)
+            if self.clobber is ClobberPolicy.SKIP:
+                logger.info("Skipping job %s (directory exists)", job.prefix)
+                return False
+            # OVERWRITE: remove stale .data files unless restarting
+            if not job.restart:
+                for f in paths.job.glob("*.data"):
+                    f.unlink()
+                    logger.debug("Removed stale data file: %s", f)
+        else:
+            paths.job.mkdir(parents=True, exist_ok=True)
+        return True
+
+    @staticmethod
+    def _write_log_header(paths: JobPaths) -> None:
+        """Start a fresh log file containing a timestamp and a rule line."""
+        # Mode "w" truncates any log left over from a previous run.
+        with open(paths.log, "w") as fh:
+            fh.write(f"Started {datetime.now().isoformat()}\n")
+            fh.write("-" * 60 + "\n")
+
+    def _stage(self, job: Job, paths: JobPaths) -> bool:
+        """Create the job directory, start its log, and write ``.data`` files.
+
+        Returns
+        -------
+        bool
+            False if the clobber policy says to skip this job, else True.
+        """
+        if not self._setup_job_dir(job, paths):
+            return False
+        self._write_log_header(paths)
+        job.write_data_objects(paths.job)
+        return True
+
+    # ------------------------------------------------------------------
+    # Public interface
+    # ------------------------------------------------------------------
+
+    def setup(self) -> list[JobPaths]:
+        """Write ``.data`` files for all jobs without running them.
+
+        Useful for staging a batch before submission, or for inspecting
+        what would be written.
+
+        Returns
+        -------
+        list[JobPaths]
+            Paths for every job that was set up (skipped jobs are omitted).
+
+        Raises
+        ------
+        ValueError
+            If any job has no prefix (raised before anything is written).
+        FileExistsError
+            Under ``ClobberPolicy.ERROR`` when a job directory already exists.
+        """
+        all_paths: list[JobPaths] = []
+        for job, paths in self._plan():
+            if not self._stage(job, paths):
+                continue
+            job.paths = paths
+            all_paths.append(paths)
+            logger.info("Setup complete for job %s → %s", job.prefix, paths.job)
+        return all_paths
+
+    def run(self, wait: bool = True) -> list[JobResult]:
+        """Set up, optionally build, and submit all jobs.
+
+        Parameters
+        ----------
+        wait:
+            If True (default), block until all jobs complete. Set False for
+            SLURM/PBS backends if you want to return immediately after
+            submission and check results later.
+
+        Returns
+        -------
+        list[JobResult]
+            One result per submitted job. Skipped jobs are omitted.
+            ``result.returncode`` is None for scheduler-submitted jobs when
+            ``wait=False``.
+
+        Raises
+        ------
+        ValueError
+            If any job has no prefix (raised before anything is submitted).
+        FileExistsError
+            Under ``ClobberPolicy.ERROR`` when a job directory already exists.
+        """
+        results: list[JobResult] = []
+        for job, paths in self._plan():
+            if not self._stage(job, paths):
+                continue
+
+            job.build(paths)
+            job.paths = paths
+
+            results.append(self.executor.submit(job, paths))
+
+        if wait:
+            # NOTE(review): the return value of wait_all() is not used; this
+            # relies on the executor updating these JobResult objects in
+            # place -- confirm against the executor implementations.
+            self.executor.wait_all(results)
+
+        # A returncode of None (no final status yet) is not counted as failed.
+        failures = [
+            r for r in results if r.returncode is not None and r.returncode != 0
+        ]
+        if failures:
+            logger.warning(
+                "%d job(s) failed: %s",
+                len(failures),
+                [r.job.prefix for r in failures],
+            )
+        return results
diff --git a/batch/executors/__init__.py b/batch/executors/__init__.py
new file mode 100644
index 0000000..98c9d00
--- /dev/null
+++ b/batch/executors/__init__.py
@@ -0,0 +1,25 @@
+"""Executor protocol and re-exports for batch execution backends."""
+
+from __future__ import annotations
+
+from typing import Protocol
+
+from batch.job import Job, JobPaths, JobResult
+
+
+class Executor(Protocol):
+    """Interface that all execution backends must satisfy.
+
+    :class:`~batch.executors.local.SerialExecutor`,
+    :class:`~batch.executors.local.ParallelExecutor`, and
+    :class:`~batch.executors.slurm.SLURMExecutor` all implement this protocol,
+    as does any custom executor the caller provides.
+    """
+
+    def submit(self, job: Job, paths: JobPaths) -> JobResult:
+        """Start or queue one job and return its result object."""
+        ...
+
+    def wait_all(self, results: list[JobResult]) -> list[JobResult]:
+        """Block until every result in *results* has a final returncode."""
+        ...
diff --git a/batch/executors/local.py b/batch/executors/local.py new file mode 100644 index 0000000..a84e842 --- /dev/null +++ b/batch/executors/local.py @@ -0,0 +1,185 @@ +"""Local executors: serial and parallel subprocess-based runners.""" + +from __future__ import annotations + +import logging +import os +import subprocess +import sys +import time + +from batch.job import Job, JobPaths, JobResult + +logger = logging.getLogger(__name__) + + +def _build_run_args(job: Job, paths: JobPaths) -> list[str]: + """Build the argument list for invoking runclaw. + + Uses ``python -m clawpack.clawutil.runclaw`` so the invocation works + wherever clawpack is importable without requiring ``$CLAW`` to be set. + The runclaw positional interface is:: + + runclaw.py + + With the flattened directory layout, ``outdir`` and ``rundir`` are both + ``paths.job``. + """ + return [ + sys.executable, + "-m", + "clawpack.clawutil.runclaw", + str(job.executable), + str(paths.job), # outdir + "F" if job.restart else "T", # overwrite + "T" if job.restart else "F", # restart + str(paths.job), # rundir (same directory) + "True", # verbose + ] + + +class SerialExecutor: + """Run jobs one at a time, blocking until each finishes. + + This is the simplest executor and the right choice for interactive or + debugging runs. The calling process blocks until every job completes, + so ``wait_all`` is a no-op. + + Parameters + ---------- + extra_args: + Additional arguments appended to the runclaw invocation. Rarely + needed but provided as an escape hatch. + env: + Additional environment variables to set for each job. Useful for + example to set ``OMP_NUM_THREADS`` for OpenMP-based executables. 
+ """ + + def __init__( + self, extra_args: list[str] | None = None, env: dict[str, str] | None = None + ) -> None: + self.extra_args = extra_args or [] + self.env = env or {} + + def submit(self, job: Job, paths: JobPaths) -> JobResult: + """Run the job synchronously and return its result.""" + args = _build_run_args(job, paths) + self.extra_args + run_env = os.environ.copy() + run_env.update(self.env) + logger.info("Running job %s: %s", job.prefix, " ".join(args)) + with open(paths.log, "a") as log: + proc = subprocess.run(args, stdout=log, stderr=log, env=run_env) + if proc.returncode != 0: + logger.error("Job %s failed (returncode=%d)", job.prefix, proc.returncode) + result = JobResult(job=job, paths=paths, returncode=proc.returncode) + if result.returncode == 0: + try: + result.job.post_run(result) + except Exception: + logger.exception("post_run failed for job %s", job.prefix) + return result + + def wait_all(self, results: list[JobResult]) -> list[JobResult]: + """No-op — all jobs already completed in ``submit``.""" + return results + + +class ParallelExecutor: + """Run up to *max_workers* jobs concurrently as subprocesses. + + Jobs are submitted as soon as a slot is free. ``wait_all`` drains the + remaining queue before returning. + + Parameters + ---------- + max_workers: + Maximum number of simultaneous subprocesses. Defaults to the value + of the ``BATCH_MAX_JOBS`` environment variable, or 4 if that is not + set. Set this to the number of independent jobs you want in flight + at once, not to the number of OpenMP threads per job. + poll_interval: + Seconds between queue drain checks. Default 5.0. + extra_args: + Additional arguments appended to every runclaw invocation. + env: + Additional environment variables to set for each job. Useful for + example to set ``OMP_NUM_THREADS`` for OpenMP-based executables. 
+ """ + + def __init__( + self, + max_workers: int = int(os.environ.get("BATCH_MAX_JOBS", 4)), + poll_interval: float = 5.0, + extra_args: list[str] | None = None, + env: dict[str, str] | None = None, + ) -> None: + self.max_workers = max_workers + self.poll_interval = poll_interval + self.extra_args = extra_args or [] + self.env = env or {} + # Each entry: (Popen, JobResult, open log file handle) + self._active: list[tuple[subprocess.Popen, JobResult, object]] = [] + + def submit(self, job: Job, paths: JobPaths) -> JobResult: + """Start the job, blocking only if the worker pool is full.""" + self._drain() + while len(self._active) >= self.max_workers: + time.sleep(self.poll_interval) + self._drain() + + args = _build_run_args(job, paths) + self.extra_args + run_env = os.environ.copy() + run_env.update(self.env) + log_fh = open(paths.log, "a") + proc = subprocess.Popen(args, stdout=log_fh, stderr=log_fh, env=run_env) + result = JobResult(job=job, paths=paths, returncode=None) + self._active.append((proc, result, log_fh)) + logger.info("Started job %s (pid=%d)", job.prefix, proc.pid) + return result + + def _drain(self) -> None: + """Harvest completed processes. + + Rebuilds ``_active`` via list comprehension to avoid the + modify-while-iterating pitfall of the original implementation. 
+ """ + still_running = [] + for proc, result, log_fh in self._active: + rc = proc.poll() + if rc is not None: + result.returncode = rc + log_fh.close() + if rc != 0: + logger.error( + "Job %s failed (rc=%d) — last 10 lines of %s:", + result.job.prefix, + rc, + result.paths.log, + ) + # Emit the tail of the log so failures are visible + try: + lines = result.paths.log.read_text().splitlines() + for line in lines[-10:]: + logger.error(" %s", line) + except OSError: + pass + else: + logger.info("Job %s complete", result.job.prefix) + try: + result.job.post_run(result) + except Exception: + logger.exception( + "post_run failed for job %s", result.job.prefix + ) + else: + still_running.append((proc, result, log_fh)) + self._active = still_running + + def wait_all(self, results: list[JobResult]) -> list[JobResult]: + """Block until all in-flight jobs finish.""" + while self._active: + time.sleep(self.poll_interval) + self._drain() + if self._active: + logger.info("%d job(s) still running", len(self._active)) + return results diff --git a/batch/executors/slurm.py b/batch/executors/slurm.py new file mode 100644 index 0000000..84d3dd0 --- /dev/null +++ b/batch/executors/slurm.py @@ -0,0 +1,231 @@ +"""SLURM executor: submit jobs to a SLURM scheduler via sbatch.""" + +from __future__ import annotations + +import logging +import subprocess +import time +from dataclasses import dataclass, field + +from batch.executors.local import _build_run_args +from batch.job import Job, JobPaths, JobResult + +logger = logging.getLogger(__name__) + + +@dataclass +class SLURMResources: + """SLURM resource request for a single job. + + Maps directly to ``#SBATCH`` directives. Attach an instance to + ``job.slurm_resources`` to override the executor's defaults on a per-job + basis. + + Parameters + ---------- + partition: + SLURM partition (queue) name. + nodes: + Number of nodes to request. + ntasks_per_node: + MPI tasks per node. For pure-OpenMP GeoClaw runs this should be 1. 
+ cpus_per_task: + CPUs (hardware threads) per task. Set this to ``OMP_NUM_THREADS``. + time: + Walltime limit in ``HH:MM:SS`` format. + memory: + Memory per node, e.g. ``"4G"``. Empty string uses the partition + default. + account: + Allocation account (``-A``). Required on most HPC allocations. + constraint: + Node feature constraint, e.g. ``"cpu"`` on Derecho or ``"knl"`` on + older Stampede partitions. + modules: + List of module names to load (``module load ``). + env_vars: + Environment variables to export in the job script. The canonical use + is ``{"OMP_NUM_THREADS": "8"}``. + email: + Email address for job notifications. Empty string disables mail. + mail_type: + Comma-separated SLURM mail event types. Default ``"END,FAIL"``. + extra_directives: + Raw ``#SBATCH`` lines appended after the standard directives. Use for + anything not covered above (GRES, licenses, heterogeneous jobs, etc.). + """ + + partition: str = "main" + nodes: int = 1 + ntasks_per_node: int = 1 + cpus_per_task: int = 1 + time: str = "01:00:00" + memory: str = "" + account: str = "" + constraint: str = "" + modules: list[str] = field(default_factory=list) + env_vars: dict[str, str] = field(default_factory=dict) + email: str = "" + mail_type: str = "END,FAIL" + extra_directives: list[str] = field(default_factory=list) + + +def render_slurm_script( + job: Job, + paths: JobPaths, + resources: SLURMResources, +) -> str: + """Generate a self-contained sbatch script for one job. + + This is a pure function — it does not touch the filesystem or call any + external processes, which makes it straightforward to unit-test without + a cluster. + + Parameters + ---------- + job: + The job being submitted. + paths: + Pre-computed filesystem layout. + resources: + SLURM resource requests. + + Returns + ------- + str + Complete bash script text, ready to write to a ``.sh`` file. 
+ """ + run_cmd = " ".join(str(a) for a in _build_run_args(job, paths)) + + # Standard directives — always present + directives = [ + f"#SBATCH -J {job.prefix}", + f"#SBATCH -o {paths.log}", + f"#SBATCH -e {paths.log}", + f"#SBATCH -p {resources.partition}", + f"#SBATCH -N {resources.nodes}", + f"#SBATCH --ntasks-per-node={resources.ntasks_per_node}", + f"#SBATCH --cpus-per-task={resources.cpus_per_task}", + f"#SBATCH -t {resources.time}", + ] + + # Optional directives + if resources.memory: + directives.append(f"#SBATCH --mem={resources.memory}") + if resources.account: + directives.append(f"#SBATCH -A {resources.account}") + if resources.constraint: + directives.append(f"#SBATCH --constraint={resources.constraint}") + if resources.email: + directives.append(f"#SBATCH --mail-user={resources.email}") + directives.append(f"#SBATCH --mail-type={resources.mail_type}") + directives.extend(resources.extra_directives) + + lines: list[str] = ["#!/bin/bash"] + directives + [""] + + if resources.modules: + lines.extend(f"module load {m}" for m in resources.modules) + lines.append("") + + if resources.env_vars: + lines.extend(f"export {k}={v}" for k, v in resources.env_vars.items()) + lines.append("") + + lines.append(run_cmd) + lines.append("") # ensure trailing newline + + return "\n".join(lines) + + +class SLURMExecutor: + """Submit jobs to SLURM via ``sbatch``. + + ``submit`` returns immediately after queuing; ``wait_all`` polls + ``squeue`` until all submitted jobs leave the queue. + + Per-job resource overrides are supported by attaching a + :class:`SLURMResources` instance as ``job.slurm_resources``. Jobs + without that attribute use ``default_resources``. + + Parameters + ---------- + default_resources: + Resource defaults applied to every job that does not carry its own + ``slurm_resources`` attribute. + dry_run: + If True, write the submission script but do not call ``sbatch``. + Useful for inspecting what would be submitted. 
+ poll_interval: + Seconds between ``squeue`` polls in ``wait_all``. Default 30.0. + """ + + def __init__( + self, + default_resources: SLURMResources | None = None, + dry_run: bool = False, + poll_interval: float = 30.0, + ) -> None: + self.default_resources = default_resources or SLURMResources() + self.dry_run = dry_run + self.poll_interval = poll_interval + + def submit(self, job: Job, paths: JobPaths) -> JobResult: + resources: SLURMResources = getattr( + job, "slurm_resources", self.default_resources + ) + script = render_slurm_script(job, paths, resources) + + script_path = paths.job / f"{job.prefix}_run.sh" + script_path.write_text(script) + logger.debug("Wrote submission script: %s", script_path) + + if self.dry_run: + logger.info("[dry-run] Would submit: %s", script_path) + return JobResult(job=job, paths=paths, returncode=None, job_id="dry-run") + + proc = subprocess.run( + ["sbatch", "--parsable", str(script_path)], + capture_output=True, + text=True, + check=True, + ) + # --parsable output: "" or ";" + job_id = proc.stdout.strip().split(";")[0] + logger.info("Submitted job %s → SLURM job ID %s", job.prefix, job_id) + return JobResult(job=job, paths=paths, returncode=None, job_id=job_id) + + def wait_all(self, results: list[JobResult]) -> list[JobResult]: + """Poll squeue until all submitted jobs leave the queue.""" + pending = {r.job_id: r for r in results if r.job_id and r.job_id != "dry-run"} + while pending: + time.sleep(self.poll_interval) + completed = [] + for job_id in list(pending): + proc = subprocess.run( + ["squeue", "--job", job_id, "--noheader"], + capture_output=True, + text=True, + ) + if not proc.stdout.strip(): + # Job no longer in queue — finished (success or failure). + # squeue exit code is non-zero for unknown job IDs on some + # clusters so we key on empty stdout rather than returncode. 
+ pending[job_id].returncode = 0 + try: + pending[job_id].job.post_run(pending[job_id]) + except Exception: + logger.exception( + "post_run failed for job %s", + pending[job_id].job.prefix, + ) + logger.info( + "Job %s (SLURM %s) left the queue", + pending[job_id].job.prefix, + job_id, + ) + completed.append(job_id) + for job_id in completed: + del pending[job_id] + if pending: + logger.info("%d job(s) still in queue", len(pending)) + return results diff --git a/batch/gauge_error_analysis.py b/batch/gauge_error_analysis.py deleted file mode 100644 index 5fbe1a5..0000000 --- a/batch/gauge_error_analysis.py +++ /dev/null @@ -1,410 +0,0 @@ -""" -Script to read gauge data and calculate errors relative to the baseline run for the batch runs -""" -# ============================================================================ -# Copyright (C) 2013 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# -# (post process scripts by Akshay Sriapda , 2017) -# ============================================================================ - - -from collections import defaultdict -import numpy -import matplotlib.pyplot as plt -import os -import glob - - - -def extract_level_data(data,var,level): - # get gauge data at particular levels - lev_data = [] - for i in data: - if i[0] == level: - lev_data.append([i[1],i[var]]) - return numpy.array(lev_data) - -def plot_data(test,base,gauge,level,dir): - # plots all the h,hu,hv, Eta vs time - directory = dir+'/Plots/' - if not os.path.exists(directory): - os.makedirs(directory) - colors = ['or','ob','oy','om','oc','ok','or','ob','-g'] - labels = ['Baseline','Level '] - marker_size = 7 - width = 3 - plt.ioff() - fig, ((ax1, ax2,ax3, ax4)) = plt.subplots(4) - - ax1.plot(base[:,1],base[:,2],'-k',linewidth = width,markerfacecolor = 'w',markersize = marker_size) - for k in range(1,level+1): - lev_data = extract_level_data(test,2,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - 
else: - c = colors[k-1] - ax1.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax1.set_title('h') - ax1.set_xlabel("Time") - ax1.set_ylabel("Distance") - - ax2.plot(base[:,1],base[:,3],'-k',linewidth = width,markerfacecolor = 'w',markersize = marker_size) - for k in range(1,level+1): - lev_data = extract_level_data(test,3,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax2.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax2.set_title('hu') - ax2.set_xlabel("Time") - ax2.set_ylabel("Momentum") - - ax3.plot(base[:,1],base[:,4],'-k',linewidth = width,markerfacecolor = 'w',markersize = marker_size) - for k in range(1,level+1): - lev_data = extract_level_data(test,4,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax3.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax3.set_title('hv') - ax3.set_xlabel("Time") - ax3.set_ylabel("Momentum") - - ax4.plot(base[:,1],base[:,5],'-k',linewidth = width,markerfacecolor = 'w',markersize = marker_size,label= labels[0]) - for k in range(1,level+1): - lev_data = extract_level_data(test,5,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax4.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size,label= labels[1]+str(k)) - ax4.set_title('Eta') - ax4.set_xlabel("Time") - ax4.set_ylabel("Distance") - - plt.legend(bbox_to_anchor=(0, -0.2, 1., -0.2), loc=2,ncol=4, mode="expand", borderaxespad=0.) 
- - plt.tight_layout() - plt.savefig(directory+str(gauge)+'_Data.png') - -def plot_error(gauge,error,level,dir): - # plots error relative to basseline for the four unkowns - directory = dir+'/Plots/' - if not os.path.exists(directory): - os.makedirs(directory) - colors = ['or','ob','oy','om','oc','ok','or','ob','-g'] - labels = ['Baseline','Level '] - marker_size = 7 - width = 3 - plt.ioff() - fig, ((ax1, ax2,ax3, ax4)) = plt.subplots(4) - - for k in range(1,level+1): - lev_data = extract_level_data(error,2,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax1.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax1.set_title('h') - ax1.set_xlabel("Time") - ax1.set_ylabel("Error") - - for k in range(1,level+1): - lev_data = extract_level_data(error,3,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax2.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax2.set_title('hu') - ax2.set_xlabel("Time") - ax2.set_ylabel("Error") - - for k in range(1,level+1): - lev_data = extract_level_data(error,4,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax3.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size) - ax3.set_title('hv') - ax3.set_xlabel("Time") - ax3.set_ylabel("Error") - - for k in range(1,level+1): - lev_data = extract_level_data(error,5,k) - if len(lev_data) > 0: - if k == level: - c = colors[-1] - else: - c = colors[k-1] - ax4.plot(lev_data[:,0],lev_data[:,1],c,linewidth = width,markersize = marker_size,label= labels[1]+str(k)) - ax4.set_title('Eta') - ax4.set_xlabel("Time") - ax4.set_ylabel("Error") - - plt.legend(bbox_to_anchor=(0, -0.2, 1., -0.2), loc=2,ncol=4, mode="expand", borderaxespad=0.) 
- - plt.tight_layout() - plt.savefig(directory+str(gauge)+'_Error.png') - - -def norm_calc(error,level,max_quants,debug_file): - # calculates error norm value which is normalized by the largest value in the dataset - level_error = numpy.empty([4,3]) - - if level == 0: - for i in range(2,6): - level_error[i-2,0] = numpy.linalg.norm(error[:,i]/max_quants[i-2], ord=1) - level_error[i-2,1] = numpy.linalg.norm(error[:,i]/max_quants[i-2], ord=2) - level_error[i-2,2] = numpy.linalg.norm(error[:,i]/max_quants[i-2], ord=numpy.inf) - - else: - - for i in range(2,6): - lev_error = extract_level_data(error,i,level) - - if len(lev_error) == 0: - #print 'No data found for level '+str(level)+ ' ' - level_error[i-2,0] = 0.0 - level_error[i-2,1] = 0.0 - level_error[i-2,2] = 0.0 - else: - level_error[i-2,0] = numpy.linalg.norm(lev_error[:,1]/max_quants[i-2], ord=1) - level_error[i-2,1] = numpy.linalg.norm(lev_error[:,1]/max_quants[i-2], ord=2) - level_error[i-2,2] = numpy.linalg.norm(lev_error[:,1]/max_quants[i-2], ord=numpy.inf) - - return [level_error,debug_file] - -def interpolate(test,base_0,base_1): - # interpolates data points in the baseline for time steps not recorded - x = test[1] - x_0 = base_0[1] - x_1 = base_1[1] - - x_ratio = (x - x_0)/(x_1 - x_0) - base = [0,0,0,0] - for i in range(2,6): - base[i-2] = (x_ratio*(base_1[i] - base_0[i])) + base_0[i] - - return base - -def error_calc(test,base,gauge,number_of_levels,dir,summary_file,output_file,debug_file): - # Actual error calculations done for each timestep the data is recorded - for i in range(0,len(test)): - if test[i,1] > base[-1,1]: - cutoff_location = i-1 - break - else: - cutoff_location = i - - error = numpy.empty(test.shape) - interpolate_base = numpy.empty([cutoff_location,6]) - - for i in range(0,cutoff_location): - for j in range(0,len(base)): - if test[i,1] == base[j,1]: - interpolate_base[i] = base[j] - break - elif base[j,1] > test[i,1]: - interpolate_base[i,0:2] = test[i,0:2] - interpolate_base[i,2:6] = 
interpolate(test[i,:],base[j-1,:],base[j,:]) - break - - error = numpy.empty(interpolate_base.shape) - error[:,0:2] = interpolate_base[:,0:2] - error[:,2:6] = numpy.abs(test[:cutoff_location,2:6] - interpolate_base[:,2:6]) - - - max_quants = [] - for k in range(2,6): - max_quants.append(numpy.max(base[:,k])) - - for i in range(0,int(number_of_levels)+1): - - [norm_error,debug_file] = norm_calc(error,i,max_quants,debug_file) - - if i == 0: - output_file.write('All levels L1 Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,0])+ '\n') - output_file.write('All levels L2 Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,1])+ '\n') - output_file.write('All levels Infinity Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,2]) + '\n') - summary_file.write(str(norm_error[3,0])+' '+str(norm_error[3,1])+' '+str(norm_error[3,2])) - else: - output_file.write('Level '+str(i)+' L1 Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,0])+ '\n') - output_file.write('Level '+str(i)+' L2 Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,1])+ '\n') - output_file.write('Level '+str(i)+' Infinity Norm error for gauge '+str(gauge)+' is '+str(norm_error[:,2]) + '\n') - - return [output_file,summary_file,debug_file] - -def get_num_of_levels(root_path,sweep_data): - # Uses code from Clawpack file plot_num_grids.py to calculate the number of cells used at each output time step - for k in range(0,len(sweep_data)): - - output_path = root_path+'sweep_'+str(k)+'_output/fort.q*' - num_levels = sweep_data[k][2] - file_list = glob.glob(output_path) - time = numpy.empty(len(file_list), dtype=float) - num_grids = numpy.zeros((time.shape[0], num_levels), dtype=int) - num_cells = numpy.zeros((time.shape[0], num_levels), dtype=int) - - for (n,path) in enumerate(file_list): - t_path = path[:-5] + "t" + path[-4:] - t_file = open(t_path, 'r') - time[n] = float(t_file.readline().split()[0]) - t_file.readline() - t_file_num_grids = int(t_file.readline().split()[0]) - 
t_file.close() - q_file = open(path, 'r') - line = "\n" - while line != "": - line = q_file.readline() - if "grid_number" in line: - level = int(q_file.readline().split()[0]) - num_grids[n,level-1] += 1 - mx = int(q_file.readline().split()[0]) - my = int(q_file.readline().split()[0]) - num_cells[n,level-1 ] += mx * my - q_file.close() - if numpy.sum(num_grids[n,:]) != t_file_num_grids: - raise Exception("Number of grids in fort.t* file and fort.q* file do not match.") - - numpy.savetxt('post-process-data/num_cells_run_'+str(k)+'.txt',num_cells) - numpy.savetxt('post-process-data/output-times.txt',time) - -if __name__ == "__main__": - # point output_path to the output folder from the scenarios - output_path = '../../scratch/Tohoku-hawaii/tohoku-hawaii-scenario/all-scenarios-1/' - - # point run_data_path to where the run details text file is located. - # If there is no text file with the details, create an array with the deatils - run_data_path = '../../scratch/Tohoku-hawaii/' - sweep_data = numpy.loadtxt(run_data_path+'run-data.txt') - - # point log_path to where all the run log files are located - log_path = '../../scratch/Tohoku-hawaii/' - - # All the post processed data will be stored in a new folder names post-process-data - if not os.path.exists('post-process-data'): - os.makedirs('post-process-data') - - # summary file contains error norm and run time data, used for visualizing - summary_file = open('post-process-data/summary-data.txt','w') - # output file logs all the erros for each gauge for each sweep - output_file = open('post-process-data/output-data.txt','w') - # debug file can be used to log the parameters passed through the various functions - debug_file = open('post-process-data/debug-data.txt','w') - - gauges_data = open(output_path+'sweep_0_output/gauges.data') - get_num_of_levels(output_path,sweep_data) - gauge_list = [] - count = 0 - - for i in gauges_data: - l = i.strip().split(' ') - count += 1 - if count >= 8 and l == ['']: - break - elif 
count>= 8: - gauge_list.append(int(l[0])) - - gauge_list.sort() - baseline_gauges = defaultdict(list) - - for i in gauge_list: - - l = len(str(i)) - if l < 5: - t = '0'*(5-l)+str(i) - else: - t = str(i) - baseline_gauges[i] = numpy.loadtxt(output_path+'sweep_0_output/gauge' + t +'.txt') - baseline_gauges[i] = numpy.array(baseline_gauges[i]) - - log_file = open(output_path + 'sweep_0_log.txt') - run_number = log_file.readline().strip().split('.')[0] - base_log_data = open(log_path+'sweep_0.o'+run_number) - - number_of_tests = len(sweep_data) - 1 - - for line in base_log_data: - l = line.strip().split() - if not l: - continue - elif l[0] == 'total': - base_total_cells = l[3] - output_file.write('Total cells for baseline run '+base_total_cells+ '\n') - elif l[0] == 'Total' and l[1] == 'time:': - base_total_time = l[2] - output_file.write('Total wall time for baseline run '+base_total_time+ '\n') - elif l[0] == 'Regridding': - base_regridding_time = l[1] - output_file.write('Regridding time for baseline run '+base_regridding_time+ '\n') - - summary_file.write(str(number_of_tests)+' '+str(len(gauge_list))+' '+ str(base_total_cells) +'\n') - - for i in range(0,number_of_tests): - dir = 'post-process-data/test-no-'+str(i+1) - if not os.path.exists(dir): - os.makedirs(dir) - - max_level = sweep_data[i+1][2] - - output_file.write('The Sweep number '+str(i+1)+' had the following data: '+ '\n') - output_file.write('Grid s` ize: '+str(sweep_data[i+1][0])+' by '+ str(sweep_data[i+1][1])+ '\n') - output_file.write('The max AMR Level is '+ str(max_level)+ '\n') - output_file.write('The AMR Levels used are '+ str(sweep_data[i+1][3:]) +'\n') - - log_file = open(output_path + 'sweep_'+str(i+1)+'_log.txt') - run_number = log_file.readline().strip().split('.')[0] - test_log_data = open(log_path+'sweep_'+str(i+1)+'.o'+run_number) - - for line in test_log_data: - l = line.strip().split() - if not l: - continue - elif l[0] == 'total': - test_total_cells = l[3] - output_file.write('Total 
cells for test run '+base_total_cells+ '\n') - elif l[0] == 'Total' and l[1] == 'time:': - test_total_time = l[2] - summary_file.write('0 '+str(test_total_time)+' '+str(test_regridding_time)+' '+str(test_total_cells)+' '+str(i+1)+'\n') - output_file.write('Total wall time for test run '+test_total_time+ '\n') - - elif l[0] == 'Regridding': - test_regridding_time = l[1] - output_file.write('Regridding time for test run '+test_regridding_time +'\n') - - test_gauges = defaultdict(list) - - for j in gauge_list: - test_path = output_path+'sweep_'+str(i+1)+'_output/' - l = len(str(j)) - if l < 5: - t = '0'*(5-l)+str(j) - else: - t = str(j) - - test_gauges[j] = numpy.loadtxt(test_path+'gauge' + t +'.txt') - - - for j in gauge_list: - test_gauges[j] = numpy.array(test_gauges[j]) - summary_file.write('g ') - [output_file,summary_file,debug_file] = error_calc(test_gauges[j],baseline_gauges[j],j,max_level,dir,summary_file,output_file,debug_file) - summary_file.write('\n') - - summary_file.write('b '+str(base_regridding_time)+' '+str(base_total_time)+'\n\n') - print 'Error Analysis done' diff --git a/batch/habanero.py b/batch/habanero.py deleted file mode 100644 index a54840d..0000000 --- a/batch/habanero.py +++ /dev/null @@ -1,185 +0,0 @@ -r"""Batch sub-classes for runs on the Columbia Habanero machine (SLURM)""" - -# ============================================================================ -# Copyright (C) 2018 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# ============================================================================ - -from __future__ import print_function -from __future__ import absolute_import - -import os -import glob -import subprocess - -import batch - - -class HabaneroJob(batch.Job): - r""" - Modifications to the basic :class:`batch.Job` class for Habanero runs - - """ - - def __init__(self): - r""" - Initialize Habanero job - - See :class:`HabaneroJob` for full documentation - """ - - 
super(HabaneroJob, self).__init__() - - # Add extra job parameters - self.omp_num_threads = 1 - # self.mic_omp_num_threads = 1 - # self.mic_affinity = "none" - self.time = "12:00:00" - # TODO: check to see what the queue should be - self.queue = None - - -class HabaneroBatchController(batch.BatchController): - r""" - Modifications to the basic batch controller for Habanero runs - - - :Ignored Attributes: - - Due to the system setup, the following controller attributes are ignored: - - *plot*, *terminal_output*, *wait*, *poll_interval*, *plotclaw_cmd* - """ - - def __init__(self, jobs=[]): - r""" - Initialize Habanero batch controller - - See :class:`HabaneroBatchController` for full documentation - """ - - super(HabaneroBatchController, self).__init__(jobs) - - # Habanero specific execution controls - self.email = None - - def run(self): - r"""Run Habanero jobs from controller's *jobs* list. - - This run function is modified to run jobs through the slurm queue - system and provides controls for running serial jobs (OpenMP only). - - Unless otherwise noted, the behavior of this function is identical to - the base class :class:`BatchController`'s function. 
- """ - - # Run jobs - paths = [] - for (i, job) in enumerate(self.jobs): - # Create output directory - data_dirname = ''.join((job.prefix, '_data')) - output_dirname = ''.join((job.prefix, "_output")) - plots_dirname = ''.join((job.prefix, "_plots")) - run_script_name = ''.join((job.prefix, "_run.sh")) - log_name = ''.join((job.prefix, "_log.txt")) - - if len(job.type) > 0: - job_path = os.path.join(self.base_path, job.type, job.name) - else: - job_path = os.path.join(self.base_path, job.name) - job_path = os.path.abspath(job_path) - data_path = os.path.join(job_path, data_dirname) - output_path = os.path.join(job_path, output_dirname) - plots_path = os.path.join(job_path, plots_dirname) - log_path = os.path.join(job_path, log_name) - run_script_path = os.path.join(job_path, run_script_name) - paths.append({'job': job_path, 'data': data_path, - 'output': output_path, 'plots': plots_path, - 'log': log_path}) - - # Create job directory if not present - if not os.path.exists(job_path): - os.makedirs(job_path) - - # Clobber old data directory - if os.path.exists(data_path): - if not job.rundata.clawdata.restart: - data_files = glob.glob(os.path.join(data_path, '*.data')) - for data_file in data_files: - os.remove(data_file) - else: - os.mkdir(data_path) - - # Write out data - temp_path = os.getcwd() - os.chdir(data_path) - job.write_data_objects() - os.chdir(temp_path) - - # Handle restart requests - if job.rundata.clawdata.restart: - restart = "T" - overwrite = "F" - else: - restart = "F" - overwrite = "T" - - # Construct string commands - run_cmd = "%s %s %s %s %s %s True\n" % (self.runclaw_cmd, - job.executable, - output_path, - overwrite, - restart, - data_path) - - if self.plot: - plot_cmd = "%s %s %s %s" % (self.plotclaw_cmd, output_path, - plots_path, job.setplot) - - cmd = run_cmd - if self.plot: - cmd = ";".join((cmd, plot_cmd)) - if self.tar: - cmd = ";".join((cmd, tar_cmd)) - - # Write slurm run script - run_script = open(run_script_path, 'w') - - 
run_script.write("#!/bin/sh\n") - run_script.write("#SBATCH --account=%s" % job.account) - run_script.write("#SBATCH -J %s # Job name\n" % job.prefix) - run_script.write("#SBATCH -o %s # Job name\n" % log_path) - run_script.write("#SBATCH -n 1 # Total number of MPI " - "tasks requested\n") - run_script.write("#SBATCH -N 1 # Total number of MPI " - "tasks requested\n") - if job.queue is not None: - run_script.write("#SBATCH -p %s # queue\n" - % job.queue) - run_script.write("#SBATCH -t %s # run time " - "(hh:mm:ss)\n" % job.time) - if self.email is not None: - run_script.write("#SBATCH --mail-user=%s" % self.email) - run_script.write("#SBATCH --mail-type=begin # email me" - " when the job starts\n") - run_script.write("#SBATCH --mail-type=end # email me" - " when the job finishes\n") - run_script.write("\n") - run_script.write("# OpenMP controls\n") - run_script.write("export OMP_NUM_THREADS=%s\n" - % job.omp_num_threads) - run_script.write("\n") - run_script.write("# Run command\n") - run_script.write(cmd) - - run_script.close() - - # Submit job to queue - subprocess.Popen("sbatch %s > %s" % (run_script_path, log_path), - shell=True).wait() - - # -- All jobs have been started -- - - return paths diff --git a/batch/job.py b/batch/job.py new file mode 100644 index 0000000..aeabdac --- /dev/null +++ b/batch/job.py @@ -0,0 +1,182 @@ +"""Core data types for batch job description and results.""" + +from __future__ import annotations + +import enum +import logging +from dataclasses import dataclass +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class ClobberPolicy(enum.Enum): + """Controls behavior when a job's output directory already exists. + + OVERWRITE + Remove stale ``.data`` files and re-run. Existing output (``fort.*``) + is left in place and will be overwritten by the solver. This is the + default and matches the original batch behavior. + ERROR + Raise ``FileExistsError`` immediately. 
Use this when you want a hard + guarantee that you are not accidentally stomping a previous run. + SKIP + Skip any job whose output directory already exists. Together with a + sentinel file produced by the solver this gives free resumability: run + the same batch script again after a walltime kill and only the jobs + that did not finish will be resubmitted. + """ + + OVERWRITE = "overwrite" + ERROR = "error" + SKIP = "skip" + + +@dataclass +class JobPaths: + """Filesystem layout for one job. + + All data files, solver output (``fort.*``), and the run log share the same + root directory ``job``. Plots are kept in a subdirectory so they can be + tarred or discarded independently. + """ + + job: Path # root directory — data files and fort.* output go here + plots: Path # plots subdirectory + log: Path # per-job log file + + +@dataclass +class JobResult: + """Outcome record for a single submitted job. + + Attributes + ---------- + job: + The job that was submitted. + paths: + Filesystem layout computed by the controller for this job. + returncode: + Process exit code. ``None`` until the job completes; this is the + normal state for scheduler-submitted jobs when ``wait=False``. + job_id: + Scheduler-assigned job identifier (e.g. a SLURM job number). + ``None`` for local executors. + """ + + job: Job + paths: JobPaths + returncode: int | None + job_id: str | None = None + + @property + def success(self) -> bool: + """True only when returncode is known and zero.""" + return self.returncode == 0 + + @property + def pending(self) -> bool: + """True for scheduler-submitted jobs whose result is not yet known.""" + return self.returncode is None + + +class Job: + """Base class for all Clawpack batch jobs. + + Subclass this to define a concrete simulation. At minimum you must: + + - Set ``self.prefix`` — a unique string used to name the job directory. + - Populate ``self.rundata`` with a ``ClawRunData`` object (e.g. from + ``setrun.setrun()``). 
+ + Optionally override: + + - ``write_data_objects(path)`` — if you need to write auxiliary files + beyond what ``rundata.write()`` produces. + - ``build(paths)`` — to compile the executable before submission. + + Attributes + ---------- + prefix : str | None + Unique identifier for this job. Becomes the job directory name. + Must be set before the job is submitted. + executable : str | Path + Name or path of the compiled binary. A bare name (``"xgeoclaw"``) is + resolved against the job directory after ``build()`` runs; an absolute + path is used as-is. + setplot : str + Module name passed to ``plotclaw`` if plotting is requested. + restart : bool + If True, the controller will not clobber existing ``.data`` files and + will pass the restart flag to ``runclaw``. + paths : JobPaths | None + Populated by ``BatchController`` before the job is submitted. + Available for use in postprocessing. + rundata : ClawRunData | None + Clawpack run-data object. Must be set by the subclass. + """ + + def __init__(self) -> None: + self.prefix: str | None = None + self.executable: str | Path = "xgeoclaw" + self.setplot: str = "setplot" + self.restart: bool = False + self.paths: JobPaths | None = None + self.rundata = None + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(prefix={self.prefix!r})" + + def write_data_objects(self, path: Path) -> None: + """Write Clawpack ``.data`` files into *path*. + + The default implementation calls ``self.rundata.write(out_dir=path)``. + Override to write additional auxiliary files, calling ``super()`` first. + + Parameters + ---------- + path: + Destination directory. Always ``paths.job`` — the same directory + that will receive solver output. + """ + if self.rundata is None: + raise ValueError( + f"Job {self.prefix!r}: rundata is not set. " + "Assign a ClawRunData object before running." + ) + self.rundata.write(out_dir=path) + + def build(self, paths: JobPaths) -> None: + """Compile the executable before job submission. 
+ + The default is a no-op. Override when each job requires a fresh + build — for example, a parameter that is compiled into the Fortran + source rather than read from a data file. + + The compiled executable should be placed at ``paths.job / self.executable`` + (or ``self.executable`` should be updated to an absolute path) so that + the executor can locate it. + + Parameters + ---------- + paths: + Paths object for this job, provided by the controller. + """ + pass + + def post_run(self, result: "JobResult") -> None: + """Called after a job completes successfully. No-op default. + + Override to run plotting, data conversion, or any per-job + postprocessing. For ParallelExecutor this fires as each job is + harvested in _drain, concurrent with remaining running jobs. + For SLURM it fires as each job leaves the queue in wait_all. + Only called when result.success is True. + + Parameters + ---------- + result: + The completed job's result object, giving access to paths + and returncode. + """ + pass diff --git a/batch/plot.py b/batch/plot.py new file mode 100644 index 0000000..4433779 --- /dev/null +++ b/batch/plot.py @@ -0,0 +1,114 @@ +"""Plotting utilities for post-run analysis.""" + +from __future__ import annotations + +import logging +import subprocess +import sys +from pathlib import Path + +from batch.job import JobResult + +logger = logging.getLogger(__name__) + + +def _plot_inprocess(result: JobResult, setplot, format: str) -> bool: + """In-process fallback used only when setplot is callable.""" + try: + from clawpack.visclaw.plotclaw import plotclaw + except ImportError: + logger.warning( + "clawpack.visclaw not importable; skipping plot for %s", + result.job.prefix, + ) + return False + + try: + plotclaw( + outdir=str(result.paths.job), + plotdir=str(result.paths.plots), + setplot=setplot, + format=format, + ) + logger.info("Plots written to %s", result.paths.plots) + return True + except Exception: + logger.exception("Plotting failed for job %s", 
result.job.prefix) + return False + + +def plot_job( + result: JobResult, + setplot: str | Path = "setplot.py", + format: str = "ascii", + verbose: bool = False, +) -> bool: + """Run plotclaw on a completed job's output. + + Runs plotclaw as a subprocess, capturing all output (including C-level + output from matplotlib) to the job's log file. A ``--- plotclaw ---`` + separator is written to the log before the subprocess call so solver + and plotting output are visually distinct. + + Parameters + ---------- + result: + Completed job result. Uses result.paths.job as outdir and + result.paths.plots as plotdir. + setplot: + File path (str or Path) or callable. A relative string is resolved + against result.paths.job if that file exists; a Path is resolved to + an absolute path. A callable cannot cross the subprocess boundary + and triggers an in-process fallback with a logged warning. + format: + Clawpack output format, passed to plotclaw. Default 'ascii'. + verbose: + When True, log the full args list at INFO level before running. + + Returns + ------- + bool + True on success, False on failure. 
+ """ + if callable(setplot): + logger.warning( + "setplot is callable; falling back to in-process plotting for %s " + "(output will not be captured to log)", + result.job.prefix, + ) + return _plot_inprocess(result, setplot, format) + + if isinstance(setplot, Path): + setplot_arg = str(setplot.resolve()) + else: + candidate = result.paths.job / setplot + setplot_arg = str(candidate) if candidate.exists() else str(setplot) + + args = [ + sys.executable, + "-m", + "clawpack.visclaw.plotclaw", + str(result.paths.job), + str(result.paths.plots), + setplot_arg, + ] + + if verbose: + logger.info("plotclaw args: %s", args) + + with open(result.paths.log, "a") as log_fh: + log_fh.write("\n--- plotclaw ---\n") + log_fh.flush() + proc = subprocess.run(args, stdout=log_fh, stderr=log_fh) + + if proc.returncode != 0: + logger.warning( + "plotclaw exited with returncode %d for job %s; see %s", + proc.returncode, + result.job.prefix, + result.paths.log, + ) + return False + + logger.info("Plots written to %s", result.paths.plots) + return True diff --git a/batch/post_process_data_analysis.py b/batch/post_process_data_analysis.py deleted file mode 100644 index 4273650..0000000 --- a/batch/post_process_data_analysis.py +++ /dev/null @@ -1,624 +0,0 @@ -""" -Script to analyze and visualze gauge errors for the batch runs -""" -# ============================================================================ -# Copyright (C) 2013 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# -# (post process scripts by Akshay Sriapda , 2017) -# ============================================================================ - -from collections import defaultdict -import numpy -import matplotlib.pyplot as plt -from matplotlib import cm -import os -import math -from mpl_toolkits.mplot3d import Axes3D - -def plot_summary(time,L1_error,L2_error,Inf_error,no_of_gauges,no_of_sweeps,regrid_time): - # plots all the three norm errors for all gauges in 
three subplots. - # Colors represnt different sweeps - # Markers represent different gauges - #Marker size is relative to regridding time - sweep_labels = ['r','b','g','c','m','y','k'] - gauge_labesls = ['o','s','d','^','v','>','+','p','h','o','s'] - - regrid_marker_fac = 5 - plot_count = 0 - fig, (ax1,ax2,ax3) = plt.subplots(3) - #fig, (ax1) = plt.subplots(1) - - for i in range(0,no_of_gauges): - for j in range(0,no_of_sweeps): - if (i==0): - ax1.plot(time[j],L1_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax2.plot(time[j],L2_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax3.plot(time[j],Inf_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],label='Scenario '+str(j+1),markersize=5+(regrid_time[j]*regrid_marker_fac)) - - elif (j==0): - ax1.plot(time[j],L1_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax2.plot(time[j],L2_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax3.plot(time[j],Inf_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],label='Gauge '+str(i+1),markersize=5+(regrid_time[j]*regrid_marker_fac)) - else: - ax1.plot(time[j],L1_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax2.plot(time[j],L2_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - ax3.plot(time[j],Inf_error[j,i],sweep_labels[j%len(sweep_labels)]+gauge_labesls[i%len(gauge_labesls)],markersize=5+(regrid_time[j]*regrid_marker_fac)) - - ax1.set_title('L1 Norm Error') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - 
ax1.set_ylabel("L1 Norm Error") - ax2.set_title('L2 Norm Error') - ax2.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax2.set_ylabel("L2 Norm Error") - ax3.set_title('Inf Norm Error') - ax3.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax3.set_ylabel("Inf Norm Error") - - plt.legend(bbox_to_anchor=(0, -0.2, 1., -0.2), loc=2,ncol=3, mode="expand", borderaxespad=0.) - plt.tight_layout() - plt.savefig(plot_path+'tohoku-error.png',bbox_inches='tight') - -def get_avg_bound(L1_error,L2_error,Inf_error,no_of_sweeps): - # calcualtes the max and min errors for each sweep and which gauge the max and min occurs at. - # Also the average error over all the gauges for each sweep - L1_info = numpy.empty([6,no_of_sweeps]) - L2_info = numpy.empty([6,no_of_sweeps]) - Inf_info = numpy.empty([6,no_of_sweeps]) - - for j in range(0,no_of_sweeps): - L1_info[0,j] = numpy.mean(L1_error[j,:]) - L1_info[1,j] = numpy.std(L1_error[j,:]) - L1_info[2,j] = numpy.amax(L1_error[j,:]) - L1_info[3,j] = L1_error[j,:].tolist().index(numpy.amax(L1_error[j,:])) - L1_info[4,j] = numpy.amin(L1_error[j,:]) - L1_info[5,j] = L1_error[j,:].tolist().index(numpy.amin(L1_error[j,:])) - - L2_info[0,j] = numpy.mean(L2_error[j,:]) - L2_info[1,j] = numpy.std(L2_error[j,:]) - L2_info[2,j] = numpy.amax(L2_error[j,:]) - L2_info[3,j] = L2_error[j,:].tolist().index(numpy.amax(L2_error[j,:])) - L2_info[4,j] = numpy.amin(L2_error[j,:]) - L2_info[5,j] = L2_error[j,:].tolist().index(numpy.amin(L2_error[j,:])) - - Inf_info[0,j] = numpy.mean(Inf_error[j,:]) - Inf_info[1,j] = numpy.std(Inf_error[j,:]) - Inf_info[2,j] = numpy.amax(Inf_error[j,:]) - Inf_info[3,j] = Inf_error[j,:].tolist().index(numpy.amax(Inf_error[j,:])) - Inf_info[4,j] = numpy.amin(Inf_error[j,:]) - Inf_info[5,j] = Inf_error[j,:].tolist().index(numpy.amin(Inf_error[j,:])) - - highest_error_sweep = L1_info[0,:].tolist().index(numpy.amax(L1_info[0,:])) - lowest_error_sweep = L1_info[0,:].tolist().index(numpy.amin(L1_info[0,:])) - return 
L1_info,L2_info,Inf_info,lowest_error_sweep,highest_error_sweep - -def plot_error_bar(time,l1_data,l2_data,inf_data,no_of_sweeps,error_bars): - # plots the max and min errors for each sweep as a error bar plot with gauge nos - #fig, (ax1,ax2,ax3) = plt.subplots(3) - fig, (ax1) = plt.subplots(1) - - if(error_bars): - ax1.errorbar(time,l1_data[0,:],yerr=[numpy.subtract(l1_data[0,:],l1_data[4,:]),numpy.subtract(l1_data[2,:],l1_data[0,:])],ls='',marker='o',color='r',markersize=5) - #ax2.errorbar(time,l2_data[0,:],yerr=[numpy.subtract(l2_data[0,:],l2_data[4,:]),numpy.subtract(l2_data[2,:],l2_data[0,:])],ls='',marker='o',color='r',markersize=5) - #ax3.errorbar(time,inf_data[0,:],yerr=[numpy.subtract(inf_data[0,:],inf_data[4,:]),numpy.subtract(inf_data[2,:],inf_data[0,:])],ls='',marker='o',color='r',markersize=5) - else: - ax1.plot(time,l1_data[0,:],'w',markersize=3) - #ax2.plot(time,l2_data[0,:],'w',markersize=3) - #ax3.plot(time,inf_data[0,:],'w',markersize=3) - - for i in range(0,no_of_sweeps): - ax1.text(time[i]+0.2,l1_data[0,i],str(i+1),fontsize = 10) - ax1.text(time[i],l1_data[2,i]+0.01,str(l1_data[3,i]+1),fontsize = 10) - ax1.text(time[i],l1_data[4,i]-0.01,str(l1_data[5,i]+1),fontsize = 10) - - #ax2.text(time[i]+0.2,l2_data[0,i],str(i+1),fontsize = 10) - #ax2.text(time[i],l2_data[2,i]+0.1,str(l2_data[3,i]+1),fontsize = 10) - #ax2.text(time[i],l2_data[4,i]-0.1,str(l2_data[5,i]+1),fontsize = 10) - - #ax3.text(time[i]+0.2,inf_data[0,i],str(i+1),fontsize = 10) - #ax3.text(time[i],inf_data[2,i]+0.1,str(inf_data[3,i]+1),fontsize = 10) - #ax3.text(time[i],inf_data[4,i]-0.1,str(inf_data[5,i]+1),fontsize = 10) - - ax1.set_title('Error bars with gauge IDs') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel("L1 Norm Error") - ax1.set_ylim([-0.1,1]) - #ax2.set_title('L2 Norm Error') - #ax2.set_xlabel("Total Wall Time (%, Baseline = 100%)") - #ax2.set_ylabel("L2 Norm Error") - #ax2.set_ylim([-0.1,1]) - #ax3.set_title('Inf Norm Error') - 
#ax3.set_xlabel("Total Wall Time (%, Baseline = 100%)") - #ax3.set_ylabel("Inf Norm Error") - #ax3.set_ylim([-0.1,1]) - - plt.tight_layout() - if(error_bars): - plt.savefig(plot_path+'gauge_details_with_error_bar.png',bbox_inches='tight') - else: - plt.savefig(plot_path+'gauge_details_no_error_bar.png',bbox_inches='tight') - -def plot_error_average(time,data,no_of_sweeps,max_level,error_type): - # plots the error average of all gauges for each sweep. error_type can handle different norms - colour = ['r','b','g','c','m','y'] - count = defaultdict(list) - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - - counter = 0 - - fig, (ax1) = plt.subplots(1) - - for i in range(no_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.plot(time[i],data[i],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(time[i],data[i],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - counter = counter + 1 - - if(counter%2 == 0 ): - ax1.text(time[i],data[i]+0.05,str(i+1),fontsize = 10) - else: - ax1.text(time[i],data[i]-0.075,str(i+1),fontsize = 10) - - ax1.set_title(error_type+" Norm Error Averages") - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel(error_type+" Norm Error") - ax1.set_ylim([-0.1,1.1]) - plt.tight_layout() - ax1.legend(loc=1) - plt.savefig(plot_path+error_type+'-average.png',bbox_inches='tight') - -def plot_error_std(time,average,std,lower,upper,no_of_sweeps,error_type): - # plots the standard deviation of errors of all gauges for each sweep. 
error_type can handle different norms - fig, (ax1) = plt.subplots(1) - ax1.errorbar(time,average,yerr=[numpy.subtract(average,lower),numpy.subtract(upper,average)],ls='',marker='o',color='r',markersize=5) - - for i in range(0,no_of_sweeps): - ax1.text(time[i],average[i],str(round(std[i],3)),fontsize=10) - - ax1.set_title(error_type+" Norm error bars") - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel(error_type+" Norm Error") - ax1.set_ylim([-0.1,1.1]) - - plt.tight_layout() - plt.savefig(plot_path+error_type+'-standard-deviation.png',bbox_inches='tight') - - -def plot_number_cells(time,data,no_of_sweeps): - # plots the total number of cells used for each sweep - fig, (ax1) = plt.subplots(1) - - scale = len(str(int(max(data)))) - 1 - count = 0 - for i in range(0,no_of_sweeps): - ax1.plot(time[i],data[i],'ow',markersize = 0.5) - count = count + 1 - if(count%2 == 0 ): - ax1.text(time[i],data[i]+0.005*count*10**scale,str(i+1),fontsize = 10) - else: - ax1.text(time[i],data[i]-0.005*count*10**scale,str(i+1),fontsize = 10) - - ax1.set_title('Number of Cells') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel("Number of Cells") - ax1.set_ylim([min(-0.1,min(data)-0.1),max(1.1,max(data)+0.1)]) - plt.tight_layout() - plt.savefig(plot_path+'Number-of-Cells.png',bbox_inches='tight') - - -def plot_regrid_time(time,data,no_of_sweeps,error_type): - #plots the regridding time used for each sweep - fig, (ax1) = plt.subplots(1) - - count = 0 - for i in range(0,no_of_sweeps): - ax1.plot(time[i],data[i],'ow',markersize = 0.5) - count = count + 1 - if(count%2 == 0 ): - ax1.text(time[i],data[i],str(i+1),fontsize = 10) - else: - ax1.text(time[i],data[i],str(i+1),fontsize = 10) - - ax1.set_title(error_type+' Norm Error vs Regridding time') - ax1.set_xlabel("Regridding time") - ax1.set_ylabel(error_type+" Norm Error") - #ax1.set_ylim([0.1,1.1]) - - plt.tight_layout() - plt.savefig(plot_path+error_type+'-regrid-time.png',bbox_inches='tight') - 
-def plot_features(time,l1_data,l2_data,inf_data,no_of_gauges,no_of_sweeps,regrid_time,number_of_total_cell_updates,max_level): - - #All the error types with gauge numbers for the largest and smallest error for each sweep - # error_bars = True, plots with error bars, if false no error bars. - plot_error_bar(time,l1_data,l2_data,inf_data,no_of_sweeps,error_bars = True) - - # Plotting the error avergaes with sweep numbers - plot_error_average(time,l1_data[0,:],no_of_sweeps,max_level,'L1') - plot_error_average(time,l2_data[0,:],no_of_sweeps,max_level,'L2') - plot_error_average(time,inf_data[0,:],no_of_sweeps,max_level,'Inf') - - # Plotting error avergaes with standard deviation and error bars - plot_error_std(time,l1_data[0,:],l1_data[1,:],l1_data[2,:],l1_data[4,:],no_of_sweeps,'L1') - plot_error_std(time,l2_data[0,:],l2_data[1,:],l2_data[2,:],l2_data[4,:],no_of_sweeps,'L2') - plot_error_std(time,inf_data[0,:],inf_data[1,:],inf_data[2,:],inf_data[4,:],no_of_sweeps,'Inf') - - # Plotting errors vs the regirding times for each sweep - plot_regrid_time(regrid_time,l1_data[0,:],no_of_sweeps,'L1') - plot_regrid_time(regrid_time,l2_data[0,:],no_of_sweeps,'L2') - plot_regrid_time(regrid_time,inf_data[0,:],no_of_sweeps,'Inf') - - # Plotting the number of cells used for each sweep - plot_number_cells(time,number_of_total_cell_updates,no_of_sweeps) - - -def plot_cell_ratios(output_time,number_of_sweeps,number_of_cells,refinemnt_ratios,max_level,grid_size,lowest_error_sweep,highest_error_sweep,highest_cost_sweep,lowest_cost_sweep,data,cost,time): - # calculates the cell ratios and plots them vs various other quantities - cell_number_ratio_sum = numpy.zeros([number_of_sweeps,len(output_time)]) - cell_number_ratio_avg = numpy.zeros([number_of_sweeps,1]) - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - count = defaultdict(list) - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - - for i in range(0,number_of_sweeps): - for j in 
range(0,len(output_time)): - ratio = 0 - for k in range(0,int(max_level[0,i])-1): - - if (number_of_cells[i+1][j,k+1] == 0 and number_of_cells[i+1][j,k] == 0): - ratio += 0.0 - else: - ratio += (number_of_cells[i+1][j,k+1]) / (number_of_cells[i+1][j,k] * float(refinemnt_ratios[i,k])**2) - cell_number_ratio_sum[i,j] = ratio - - if (count[max_level[0,i]] == 0): - ax1.plot(output_time/3600,cell_number_ratio_sum[i,:],'-'+colour[int(max_level[0,i]-numpy.amin(max_level))],linewidth=2,label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(output_time/3600,cell_number_ratio_sum[i,:],'-'+colour[int(max_level[0,i]-numpy.amin(max_level))],linewidth=2) - - cell_number_ratio_avg[i,0] = numpy.mean(cell_number_ratio_sum[i,:]) - numpy.savetxt(plot_path+'cell_uti_ratios.txt',cell_number_ratio_sum) - ax1.plot(output_time/3600,cell_number_ratio_sum[lowest_error_sweep,:],'wo',markersize=5,linewidth=3,label='lowest error sweep '+str(lowest_error_sweep+1)) - ax1.plot(output_time/3600,cell_number_ratio_sum[highest_error_sweep,:],'ko',markersize=5,linewidth=3,label='highest error sweep '+str(highest_error_sweep+1)) - ax1.plot(output_time/3600,cell_number_ratio_sum[lowest_cost_sweep,:],'ws',markersize=5,linewidth=3,label='lowest cost sweep '+str(lowest_cost_sweep+1)) - ax1.plot(output_time/3600,cell_number_ratio_sum[highest_cost_sweep,:],'ks',markersize=5,linewidth=3,label='highest cost sweep '+str(highest_cost_sweep+1)) - - ax1.set_title('Cell Utilization Ratios over Scenarios') - ax1.set_xlabel("Output Times (hrs)") - ax1.set_ylabel("Cell Utilization Ratios") - #ax1.legend(loc=1) - ax1.legend(bbox_to_anchor=(0, -0.05, 1., -0), loc=2,ncol=3, mode="expand", borderaxespad=0.) 
- plt.tight_layout() - plt.savefig(plot_path+'cell-ratios.png',bbox_inches='tight') - - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - for i in range(0,number_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.plot(time[i],numpy.mean(cell_number_ratio_sum[i,:]),colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(time[i],numpy.mean(cell_number_ratio_sum[i,:]),colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - ax1.text(time[i],numpy.mean(cell_number_ratio_sum[i,:])+0.2,str(i+1)) - ax1.set_title('Average Cell Utilization Ratios vs Wall Time') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel("Cell Utilization Ratio") - #ax1.legend(loc=1) - ax1.legend(loc=1) - plt.tight_layout() - plt.savefig(plot_path+'time-vs-cellratios.png',bbox_inches='tight') - - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - for i in range(0,number_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.plot(data[0,i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(data[0,i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - - ax1.text(data[0,i],numpy.mean(cell_number_ratio_sum[i,:])+0.2,str(i+1)) - ax1.set_title('Cell Utilization Ratios vs L1 Error') - ax1.set_xlabel("L1 Error") - ax1.set_ylabel("Cell Utilization Ratios") - #ax1.legend(loc=1) - ax1.legend(loc=1) - plt.tight_layout() - plt.savefig(plot_path+'error-vs-cellratios.png',bbox_inches='tight') - - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - for i in range(0,number_of_sweeps): - ax.scatter(time[i], data[0,i], cell_number_ratio_avg[i,0],c=colour[int(max_level[0,i]-numpy.amin(max_level))]) - 
ax.set_xlabel('Time') - ax.set_ylabel(' Error') - ax.set_zlabel('Cell Ratio') - plt.savefig(plot_path+'cell-ratios-3d.pdf',bbox_inches='tight') - - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - for i in range(0,number_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.plot(cost[i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(cost[i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - - ax1.text(cost[i],numpy.mean(cell_number_ratio_sum[i,:])+0.2,str(i+1)) - ax1.set_title('Cost vs Cell Utilization Ratios') - ax1.set_xlabel("Cost") - ax1.set_ylabel("Cell Utilization Ratios") - #ax1.legend(loc=1) - ax1.legend(loc=1) - plt.tight_layout() - plt.savefig(plot_path+'cost-vs-cellratios.png',bbox_inches='tight') - - - - for i in range(number_of_sweeps): - count[max_level[0,i]] = 0 - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - for i in range(0,number_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.loglog(time[i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.loglog(time[i],cell_number_ratio_avg[i,0],colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - - #ax1.text(cost[i],numpy.mean(cell_number_ratio_sum[i,:])+0.2,str(i+1)) - ax1.set_title('Log Log Plot for Cell Utilization Ratios vs Wall Time') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel("Cell Utilization Ratios") - #ax1.legend(loc=1) - ax1.legend(loc=1) - plt.tight_layout() - plt.savefig(plot_path+'time-vs-cellratios-loglog.png',bbox_inches='tight') - - p = numpy.polyfit(numpy.log(time),numpy.log(cell_number_ratio_avg[:,0].tolist()),1) - print p - line_fit = lambda t: p[0]*t + p[1] - for i in range(number_of_sweeps): 
- count[max_level[0,i]] = 0 - fig, (ax1) = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - time_range = numpy.linspace(1.0,max(numpy.log(time)+1.0),30) - ax1.plot(time_range,line_fit(time_range),'-k',label = str(round(p[0],2))+'t + '+str(round(p[1],2))) - - for i in range(0,number_of_sweeps): - if (count[max_level[0,i]] == 0): - ax1.plot(numpy.log(time[i]),numpy.log(cell_number_ratio_avg[i,0]),colour[int(max_level[0,i]-numpy.amin(max_level))]+'o',label='Level '+str(max_level[0,i])) - count[max_level[0,i]] += 1 - else: - ax1.plot(numpy.log(time[i]),numpy.log(cell_number_ratio_avg[i,0]),colour[int(max_level[0,i]-numpy.amin(max_level))]+'o') - #ax1.text(cost[i],numpy.mean(cell_number_ratio_sum[i,:])+0.2,str(i+1)) - ax1.set_title('Line fit plot') - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)") - ax1.set_ylabel("Cell Utilization Ratios") - #ax1.legend(loc=1) - ax1.legend(loc=1) - plt.tight_layout() - plt.savefig(plot_path+'time-vs-cellratios-log.png',bbox_inches='tight') - - -def plot_level_group(time,data,distinct_max_level,sweep_indices,error_type,refinemnt_ratios,pair_indices): - # plots the sweeps using the same number of levels in different subplot - count = 0 - - nrow = int(math.ceil(len(distinct_max_level)/2.0)) ; ncol = int(math.floor(len(distinct_max_level)/2.0)); - fig, axs = plt.subplots(nrows=nrow, ncols=ncol) - axs = numpy.array(axs) - count = 0 - for ax in axs.reshape(-1): - for i in sweep_indices[distinct_max_level[count]]: - #ax.errorbar(time[i],data[0,i],yerr=[numpy.subtract(data[0,i],data[4,i]),numpy.subtract(data[2,i],data[0,i])],ls='',marker='o',color='r',markersize=5) - - ax.text(time[i],data[0,i],str(i+1),fontsize = 10) - #ax.text(time[i],data[2,i],str(data[3,i]),fontsize = 8) - #ax.text(time[i],data[4,i],str(data[5,i]),fontsize = 8) - #ax.text(time[i],data[0,i],str(round(data[1,i],3)),fontsize = 8) - - - ax.set_title("Level "+str(int(distinct_max_level[count]))+" Groupings",fontsize = 8) - ax.set_xlabel("Total Wall Time (%, 
Baseline = 100%)",fontsize = 8) - ax.set_ylabel(error_type+" Error",fontsize = 8) - ax.set_ylim([-0.1,1.1]) - ax.set_xlim([0,max(time)+5]) - count = count + 1 - if(count >= len(distinct_max_level)): - break - plt.tight_layout() - #plt.suptitle("Level "+str(distinct_max_level[count])+" Groupings") - plt.savefig(plot_path+error_type+'-grouping.png',bbox_inches='tight') - - # plots grouping with levels represnted by colors and also plots an arrow for the pairs of sweeps. - # Pairs are the sweeps with same refinement ratios but different arrangements - fig, ax1 = plt.subplots(1) - colour = ['r','b','g','c','m','y'] - - for i in range(0,len(distinct_max_level)): - count = 0 - pair = [] - for j in sweep_indices[distinct_max_level[i]]: - - if len(sweep_indices[distinct_max_level[i]]) >1: - if(count == 0): - count += 1 - - if pair_indices[j+1]: - if refinemnt_ratios[j,0] > refinemnt_ratios[pair_indices[j+1][0]-1,0]: - ax1.plot(time[j],data[0,j],colour[i]+'s',label='Level '+str(int(distinct_max_level[i]))) - ax1.text( time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1]+0.01,str(refinemnt_ratios[j,0]/refinemnt_ratios[j,1])) - ax1.arrow(time[j],data[0,j], time[pair_indices[j+1][0]-1]-time[j],data[0,pair_indices[j+1][0]-1]-data[0,j], head_width=0.03, head_length=0.25, fc='k', ec='k') - else: - ax1.plot(time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1],colour[i]+'s',label='Level '+str(int(distinct_max_level[i]))) - ax1.arrow(time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1],time[j]-time[pair_indices[j+1][0]-1],data[0,j]-data[0,pair_indices[j+1][0]-1], head_width=0.03, head_length=0.25, fc='k', ec='k') - ax1.text( time[j],data[0,j]+0.01,str(refinemnt_ratios[j,0]/refinemnt_ratios[j,1])) - - pair.append(pair_indices[j+1][0] - 1) - - elif(j not in pair): - if pair_indices[j+1]: - if refinemnt_ratios[j,0] > refinemnt_ratios[pair_indices[j+1][0]-1,0]: - ax1.plot(time[j],data[0,j],colour[i]+'s') - ax1.text( 
time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1]+0.01,str(refinemnt_ratios[j,0]/refinemnt_ratios[j,1])) - ax1.arrow(time[j],data[0,j],time[pair_indices[j+1][0]-1]-time[j],data[0,pair_indices[j+1][0]-1]-data[0,j], head_width=0.03, head_length=0.25, fc='k', ec='k') - else: - ax1.plot(time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1],colour[i]+'s',label='Level '+str(int(distinct_max_level[i]))) - ax1.arrow(time[pair_indices[j+1][0]-1],data[0,pair_indices[j+1][0]-1],time[j]-time[pair_indices[j+1][0]-1],data[0,j]-data[0,pair_indices[j+1][0]-1], head_width=0.03, head_length=0.25, fc='k', ec='k') - ax1.text( time[j],data[0,j]+0.01,str(refinemnt_ratios[pair_indices[j+1][0]-1,0]/refinemnt_ratios[pair_indices[j+1][0],1])) - - pair.append(pair_indices[j+1][0] - 1) - else: - ax1.plot(time[j],data[0,j],colour[i]+'^') - else: - ax1.plot(time[j],data[0,j],colour[i]+'s',label='Level '+str(int(distinct_max_level[i]))) -#break - ax1.set_title("Grouping by level") - ax1.set_xlabel("Total Wall Time (%, Baseline = 100%)",fontsize = 12) - ax1.set_ylabel(error_type+" Error",fontsize = 12) - ax1.legend(loc=2) - ax1.set_ylim([-0.1,1.1]) - ax1.set_xlim([0,max(time)+5]) - plt.tight_layout() - plt.savefig(plot_path+error_type+'-grouping-1.png',bbox_inches='tight') - - -def plot_cost_objective(time,data,error_weight,time_weight,error_type): - # calculates and pltos a cost plot. 
The cost calculations depend on the weights sent to the functions - cost = numpy.zeros(numpy.shape(time)) - fig, ax1 = plt.subplots(1) - for i in range(0,len(time)): - cost[i] = error_weight*data[0,i] + time_weight*time[i]/100.0 - ax1.plot(i+1,cost[i],'ro',markersize=8) - - ax1.set_title("Cost Plot with error weight "+str(error_weight)+" and time weight "+str(time_weight)) - ax1.set_xlabel("Sweep Number") - ax1.set_ylabel("Cost normalized to 1") - plt.tight_layout() - plt.savefig(plot_path+error_type+'-cost-plot'+str(error_weight)+','+str(time_weight)+'.png',bbox_inches='tight') - - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - ax.scatter(time/100, data[0,:], cost) - - ax.set_xlabel('Time') - ax.set_ylabel(error_type+' Error') - ax.set_zlabel('Cost') - plt.savefig(plot_path+error_type+'-cost-plot-3d.pdf',bbox_inches='tight') - #plt.show() - - return cost,cost.tolist().index(max(cost)),cost.tolist().index(min(cost)) - -if __name__ == "__main__": - # point run_data_path to where the run details text file is located. 
- # If there is no text file with the details, create an array with the deatils - run_data_path = '../../scratch/Tohoku-hawaii/' - run_data = numpy.loadtxt(run_data_path+'run-data.txt') - - path = 'post-process-data/' - data = open(path+'summary-data.txt') - output_time = numpy.loadtxt(path+'output-times.txt') - if not os.path.exists(path+'plots'): - os.makedirs(path+'plots') - plot_path = path+'plots/' - - first_line = data.readline().strip().split() - number_of_gauges = int(first_line[1]) - number_of_sweeps = int(first_line[0]) - time_data = numpy.empty([number_of_sweeps,2]) - number_of_total_cell_updates = numpy.empty([number_of_sweeps,1]) - L1_error_data = numpy.empty([number_of_sweeps,number_of_gauges]) - L2_error_data = numpy.empty([number_of_sweeps,number_of_gauges]) - Inf_error_data = numpy.empty([number_of_sweeps,number_of_gauges]) - - number_of_cells = {} - for i in range(number_of_sweeps): - number_of_cells[i+1] = numpy.loadtxt(path+'num_cells_run_'+str(i+1)+'.txt') - - sweep_count = -1 - gauge_count = 0 - - for l in data: - line = l.strip().split() - if len(line) == 0: - break - - elif line[0] == '0': - sweep_count+=1 - time_data[sweep_count,0] = float(line[1]) - time_data[sweep_count,1] = float(line[2]) - number_of_total_cell_updates[sweep_count,0] = float(line[3]) - - elif line[0] == 'g': - L1_error_data[sweep_count,gauge_count] = float(line[1]) - L2_error_data[sweep_count,gauge_count] = float(line[2]) - Inf_error_data[sweep_count,gauge_count] = float(line[3]) - gauge_count+=1 - if gauge_count >= number_of_gauges: - gauge_count = 0 - - elif line[0] == 'b': - basline_total_time = float(line[2]) - basline_regridding_time = float(line[1]) - - # Normalizing the data - time = (time_data[:,0]*100)/basline_total_time - regridding_time = (time_data[:,1])/numpy.amax(time_data[:,1]) - L1_error_data = L1_error_data/numpy.amax(L1_error_data) - L2_error_data = L2_error_data/numpy.amax(L2_error_data) - Inf_error_data = Inf_error_data/numpy.amax(Inf_error_data) - - 
grid_size = numpy.empty([1,number_of_sweeps]) - max_level = numpy.empty([1,number_of_sweeps]) - refinemnt_ratios = [] - count = 0 - for i in range(1,number_of_sweeps+1): - - grid_size[0,count] = run_data[i,0] - max_level[0,count] = run_data[i,2] - refinemnt_ratios.append(run_data[i,3:]) - - count += 1 - - refinemnt_ratios = numpy.array(refinemnt_ratios) - - l1_data,l2_data,inf_data,lowest_error_sweep,highest_error_sweep = get_avg_bound(L1_error_data,L2_error_data,Inf_error_data,number_of_sweeps) - plot_summary(time,L1_error_data,L2_error_data,Inf_error_data,number_of_gauges,number_of_sweeps,regridding_time) - plot_features(time,l1_data,l2_data,inf_data,number_of_gauges,number_of_sweeps,regridding_time,number_of_total_cell_updates,max_level) - cost,highest_cost_sweep,lowest_cost_sweep = plot_cost_objective(time,l1_data,0.5,0.5,'L1') - - plot_cell_ratios(output_time,number_of_sweeps,number_of_cells,refinemnt_ratios,max_level,grid_size,lowest_error_sweep,highest_error_sweep,highest_cost_sweep,lowest_cost_sweep,l1_data,cost,time) - - distinct_max_level = [] - sweep_indices = defaultdict(list) - pair_indices = defaultdict(list) - for i in range(0,number_of_sweeps): - sweep_indices[max_level[0,i]].append(i) - if max_level[0,i] not in distinct_max_level: - distinct_max_level.append(max_level[0,i]) - - for i in range(0,number_of_sweeps): - for j in range(i+1,number_of_sweeps): - - if(numpy.array_equal(refinemnt_ratios[i,:max_level[0,i]-1],refinemnt_ratios[j,:max_level[0,j]-1][::-1])): - pair_indices[i+1].append(j+1) - plot_level_group(time,l1_data,distinct_max_level,sweep_indices,'L1',refinemnt_ratios,pair_indices) - - - plot_cost_objective(time,l1_data,0.75,0.25,'L1') - plot_cost_objective(time,l1_data,0.25,0.75,'L1') \ No newline at end of file diff --git a/batch/stampede.py b/batch/stampede.py deleted file mode 100644 index 91dcdd2..0000000 --- a/batch/stampede.py +++ /dev/null @@ -1,163 +0,0 @@ -r"""Batch sub-classes for runs on the TACC Stampede machine""" - -# 
============================================================================ -# Copyright (C) 2013 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# ============================================================================ - -from __future__ import print_function -from __future__ import absolute_import - -import os -import glob -import subprocess - -import batch - -class StampedeJob(batch.Job): - r""" - Modifications to the basic :class:`batch.Job` class for Stampede runs - - """ - - def __init__(self): - r""" - Initialize Stampede job - - See :class:`StampedeJob` for full documentation - """ - - super(StampedeJob, self).__init__() - - # Add extra job parameters - self.omp_num_threads = 1 - self.mic_omp_num_threads = 1 - self.mic_affinity = "none" - self.time = "12:00:00" - self.queue = "serial" - - -class StampedeBatchController(batch.BatchController): - r""" - Modifications to the basic batch controller for Stampede runs - - - :Ignored Attributes: - - Due to the system setup, the following controller attributes are ignored: - - *plot*, *terminal_output*, *wait*, *poll_interval*, *plotclaw_cmd* - """ - - def __init__(self, jobs=[]): - r""" - Initialize Stampede batch controller - - See :class:`StampedeBatchController` for full documentation - """ - - super(StampedeBatchController, self).__init__(jobs) - - # Stampede specific execution controls - self.email = None - - def run(self): - r"""Run Stampede jobs from controller's *jobs* list. - - This run function is modified to run jobs through the slurm queue system - and provides controls for running serial jobs (OpenMP only). - - Unless otherwise noted, the behavior of this function is identical to - the base class :class:`BatchController`'s function. 
- """ - - # Run jobs - paths = [] - for (i,job) in enumerate(self.jobs): - # Create output directory - data_dirname = ''.join((job.prefix,'_data')) - output_dirname = ''.join((job.prefix,"_output")) - plots_dirname = ''.join((job.prefix,"_plots")) - run_script_name = ''.join((job.prefix,"_run.sh")) - log_name = ''.join((job.prefix,"_log.txt")) - - - if len(job.type) > 0: - job_path = os.path.join(self.base_path,job.type,job.name) - else: - job_path = os.path.join(self.base_path,job.name) - job_path = os.path.abspath(job_path) - data_path = os.path.join(job_path,data_dirname) - output_path = os.path.join(job_path,output_dirname) - plots_path = os.path.join(job_path,plots_dirname) - log_path = os.path.join(job_path,log_name) - run_script_path = os.path.join(job_path,run_script_name) - paths.append({'job':job_path, 'data':data_path, - 'output':output_path, 'plots':plots_path, - 'log':log_path}) - - # Create job directory if not present - if not os.path.exists(job_path): - os.makedirs(job_path) - - # Clobber old data directory - if os.path.exists(data_path): - if not job.rundata.clawdata.restart: - data_files = glob.glob(os.path.join(data_path,'*.data')) - for data_file in data_files: - os.remove(data_file) - else: - os.mkdir(data_path) - - # Write out data - temp_path = os.getcwd() - os.chdir(data_path) - job.write_data_objects() - os.chdir(temp_path) - - # Handle restart requests - if job.rundata.clawdata.restart: - restart = "T" - overwrite = "F" - else: - restart = "F" - overwrite = "T" - - # Construct string commands - run_cmd = "%s %s %s %s %s %s True\n" % (self.runclaw_cmd, job.executable, output_path, - overwrite, restart, data_path) - - # Write slurm run script - run_script = open(run_script_path, 'w') - - run_script.write("#!/bin/sh\n") - run_script.write("#SBATCH -J %s # Job name\n" % job.prefix) - run_script.write("#SBATCH -o %s # Job name\n" % log_path) - run_script.write("#SBATCH -n 1 # Total number of MPI tasks requested\n") - run_script.write("#SBATCH 
-N 1 # Total number of MPI tasks requested\n") - run_script.write("#SBATCH -p %s # queue\n" % job.queue) - run_script.write("#SBATCH -t 9:00:00 # run time (hh:mm:ss)\n") - if self.email is not None: - run_script.write("#SBATCH --mail-user=%s" % self.email) - run_script.write("#SBATCH --mail-type=begin # email me when the job starts\n") - run_script.write("#SBATCH --mail-type=end # email me when the job finishes\n") - run_script.write("\n") - run_script.write("# OpenMP controls\n") - run_script.write("export OMP_NUM_THREADS=%s\n" % job.omp_num_threads) - run_script.write("export MIC_ENV_PREFIX=MIC") - run_script.write("export MIC_OMP_NUM_THREADS=%s\n" % job.mic_omp_num_threads) - run_script.write("export MIC_KMP_AFFINITY=%s\n" % job.mic_affinity) - run_script.write("\n") - run_script.write("# Run command\n") - run_script.write(run_cmd) - - run_script.close() - - # Submit job to queue - subprocess.Popen("sbatch %s > %s" % (run_script_path,log_path), shell=True).wait() - - # -- All jobs have been started -- - - return paths diff --git a/batch/storm.py b/batch/storm.py deleted file mode 100644 index b613113..0000000 --- a/batch/storm.py +++ /dev/null @@ -1,57 +0,0 @@ - -from __future__ import print_function -from __future__ import absolute_import - -import os -import numpy -import datetime - -import storm - -import batch.batch - -days2seconds = lambda days: days * 60.0**2 * 24.0 - -class StormJob(batch.batch.job): - r"""""" - - def __init__(self, storm_num, base_path='./', storms_path='./'): - - super(StormJob, self).__init__() - - self.type = "" - self.name = "" - self.prefix = str(storm_num).zfill(5) - self.storm_num = storm_num - self.executable = "xgeoclaw" - - # Create base data object - import setrun - self.rundata = setrun.setrun() - - # Storm specific data - self.storm_file_path = os.path.abspath(os.path.join(storms_path, - "%s.storm" % storm_num)) - - # Set storm file - self.rundata.storm_data.storm_file = self.storm_file_path - - # Change time frame of 
simulation... - # self.rundata.clawdata.t0 = days2seconds() - # self.rundata.clawdata.tfinal = days2seconds() - - - def __str__(self): - output = super(StormJob, self).__str__() - output += "\n Storm Number: %s" % self.storm_num - return output - - - def write_data_objects(self): - r"""""" - - # Write out all data files - super(StormJob, self).write_data_objects() - - # If any additional information per storm is needed do it here - # ... diff --git a/batch/sweep.py b/batch/sweep.py new file mode 100644 index 0000000..1b1f71d --- /dev/null +++ b/batch/sweep.py @@ -0,0 +1,110 @@ +"""Parameter sweep helpers for building job lists from parameter grids.""" + +from __future__ import annotations + +import itertools +from typing import Any, Callable + +from batch.job import Job + + +def product_sweep( + factory: Callable[..., Job], + namer: Callable[[dict[str, Any]], str], + **param_grid: list[Any], +) -> list[Job]: + """Build jobs from the Cartesian product of parameter lists. + + Parameters + ---------- + factory: + Callable that accepts keyword arguments drawn from *param_grid* and + returns a configured :class:`~batch.job.Job`. The ``prefix`` is set + by *namer* after construction, so the factory does not need to set it. + namer: + Callable mapping a parameter dict to a prefix string. + **param_grid: + Keyword arguments where each value is a list of options. All + combinations are enumerated. + + Returns + ------- + list[Job] + One job per combination in the Cartesian product, in row-major order. + + Examples + -------- + >>> jobs = product_sweep( + ... factory=lambda manning, level: MyJob(manning=manning, max_level=level), + ... namer=lambda p: f"n{p['manning']:.3f}_l{p['level']}", + ... manning=[0.020, 0.025, 0.030], + ... level=[4, 5], + ... 
) + >>> len(jobs) # 3 × 2 + 6 + """ + keys = list(param_grid.keys()) + jobs: list[Job] = [] + for combo in itertools.product(*param_grid.values()): + params = dict(zip(keys, combo)) + job = factory(**params) + job.prefix = namer(params) + jobs.append(job) + return jobs + + +def zip_sweep( + factory: Callable[..., Job], + namer: Callable[[dict[str, Any]], str], + **param_grid: list[Any], +) -> list[Job]: + """Build jobs by pairing parameter lists element-wise (like ``zip``). + + All parameter lists must have the same length. This is useful when + parameters are not independent — for example, paired storm tracks and + intensities. + + Parameters + ---------- + factory: + Callable that accepts keyword arguments drawn from *param_grid*. + namer: + Callable mapping a parameter dict to a prefix string. + **param_grid: + Keyword arguments where each value is a list of options. Lists must + all have the same length. + + Returns + ------- + list[Job] + One job per index position. + + Raises + ------ + ValueError + If the parameter lists have different lengths. + + Examples + -------- + >>> jobs = zip_sweep( + ... factory=lambda storm_id, intensity: StormJob(storm_id, intensity), + ... namer=lambda p: f"{p['storm_id']}_{p['intensity']}", + ... storm_id=["katrina", "ike", "harvey"], + ... intensity=["low", "mid", "high"], + ... ) + >>> len(jobs) + 3 + """ + lengths = {k: len(v) for k, v in param_grid.items()} + if len(set(lengths.values())) > 1: + raise ValueError( + f"All parameter lists must have the same length. 
Got: {lengths}" + ) + keys = list(param_grid.keys()) + jobs: list[Job] = [] + for combo in zip(*param_grid.values()): + params = dict(zip(keys, combo)) + job = factory(**params) + job.prefix = namer(params) + jobs.append(job) + return jobs diff --git a/batch/tests.py b/batch/tests.py deleted file mode 100644 index 95078f7..0000000 --- a/batch/tests.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python - -from __future__ import print_function -from __future__ import absolute_import - -import unittest - -import batch - -class TestJob(batch.job): - - def __init__(self, test_param): - - super(TestJob, self).__init__() - - self.rundata = None - self.type = None - self.name = None - self.prefix = None - self.executable = None - - def write_data_objects(self): - pass - -def test_failed_job(): - job = TestJob(1) - - job_controller = batch.BatchController([job]) - - job_controller.run() - - -if __name__ == "__main__": - unittest.main() diff --git a/batch/yeti.py b/batch/yeti.py deleted file mode 100644 index 17187a6..0000000 --- a/batch/yeti.py +++ /dev/null @@ -1,186 +0,0 @@ -r"""Batch sub-classes for runs on Columbia's Yeti cluster""" - -# ============================================================================ -# Copyright (C) 2013 Kyle Mandli -# -# Distributed under the terms of the MIT license -# http://www.opensource.org/licenses/ -# -# (adapted for yeti by Andrew Kaluzny , 2015) -# ============================================================================ - -from __future__ import print_function -from __future__ import absolute_import - -import os -import glob -import subprocess - -import batch - -class Job(batch.Job): - r""" - Modifications to the basic :class:`batch.Job` class for Yeti runs - - """ - - def __init__(self, use_v2=True, memory=4096, time=60, nodes=1, - omp_num_threads=1): - r""" - Initialize Yeti job - - See :class:`YetiJob` for full documentation - """ - - super(Job, self).__init__() - - # Add extra job parameters - self.omp_num_threads = 
omp_num_threads - ##self.mic_omp_num_threads = 1 - ##self.mic_affinity = "none" - self.memory = memory - self.time = time - self.nodes = nodes - self.use_v2 = use_v2 - - self.group = None # needs to be set at some point - - -class BatchController(batch.BatchController): - r""" - Modifications to the basic batch controller for Yeti runs - - - :Ignored Attributes: - - Due to the system setup, the following controller attributes are ignored: - - *plot*, *terminal_output*, *wait*, *poll_interval*, *plotclaw_cmd* - """ - - def __init__(self, jobs=[], email=None, email_behavior="abe", output=None): - r""" - Initialize Yeti batch controller - - See :class:`YetiBatchController` for full documentation - """ - - super(BatchController, self).__init__(jobs) - - # Yeti specific execution controls - self.email = email - self.email_behavior = email_behavior - self.output = output - self.queue = None - - def run(self): - r"""Run Yeti jobs from controller's *jobs* list. - - This run function is modified to run jobs through the slurm queue system - and provides controls for running serial jobs (OpenMP only). - - Unless otherwise noted, the behavior of this function is identical to - the base class :class:`BatchController`'s function. 
- """ - - # Run jobs - paths = [] - for (i,job) in enumerate(self.jobs): - # Create output directory - data_dirname = ''.join((job.prefix,'_data')) - output_dirname = ''.join((job.prefix,"_output")) - plots_dirname = ''.join((job.prefix,"_plots")) - run_script_name = ''.join((job.prefix,"_run.sh")) - log_name = ''.join((job.prefix,"_log.txt")) - - - if len(job.type) > 0: - job_path = os.path.join(self.base_path,job.type,job.name) - else: - job_path = os.path.join(self.base_path,job.name) - job_path = os.path.abspath(job_path) - data_path = os.path.join(job_path,data_dirname) - output_path = os.path.join(job_path,output_dirname) - plots_path = os.path.join(job_path,plots_dirname) - log_path = os.path.join(job_path,log_name) - run_script_path = os.path.join(job_path,run_script_name) - paths.append({'job':job_path, 'data':data_path, - 'output':output_path, 'plots':plots_path, - 'log':log_path}) - - # Create job directory if not present - if not os.path.exists(job_path): - os.makedirs(job_path) - - # Clobber old data directory - if os.path.exists(data_path): - if not job.rundata.clawdata.restart: - data_files = glob.glob(os.path.join(data_path,'*.data')) - for data_file in data_files: - os.remove(data_file) - else: - os.mkdir(data_path) - - # Write out data - temp_path = os.getcwd() - os.chdir(data_path) - job.write_data_objects() - os.chdir(temp_path) - - # Handle restart requests - if job.rundata.clawdata.restart: - restart = "T" - overwrite = "F" - else: - restart = "F" - overwrite = "T" - - # Construct string commands - run_cmd = "%s %s %s %s %s %s True\n" % (self.runclaw_cmd, job.executable, output_path, - overwrite, restart, data_path) - - # Write slurm run script - run_script = open(run_script_path, 'w') - - hours = int(job.time / 60) - minutes = job.time % 60 - - run_script.write("#!/bin/sh\n") - run_script.write("#PBS -N %s # Job name\n" % job.prefix) - run_script.write("#PBS -W group_list=%s # Group\n" % job.group) - run_script.write("#PBS -l mem=%smb # 
Memory\n" % job.memory) - run_script.write("#PBS -l walltime=00:%s:%s:00 # Walltime\n" % (hours, minutes)) - if job.use_v2: - run_script.write("#PBS -l nodes=%s:ppn=%s:v2 # Nodes and processers per node\n" % (job.nodes, job.omp_num_threads)) - else: - run_script.write("#PBS -l nodes=%s:ppn=%s # Nodes and processers per node\n" % (job.nodes, job.omp_num_threads)) - run_script.write("#PBS -V # export env. variables to the job\n") - if self.queue is not None: - run_script.write("#PBS -q %s # Requested queue\n" % self.queue) - if self.email is not None: - run_script.write("#PBS -M %s\n" % self.email) - run_script.write("#PBS -m %s # email for abort, begin, end\n" % self.email_behavior) - if self.output is not None: - run_script.write("#PBS -o localhost:%s # stdout\n" % self.output) - run_script.write("#PBS -e localhost:%s # stderr\n" % self.output) - run_script.write("\n") - run_script.write("# OpenMP controls\n") - run_script.write("export OMP_NUM_THREADS=%s\n" % job.omp_num_threads) - ## run_script.write("export MIC_ENV_PREFIX=MIC") - ## run_script.write("export MIC_OMP_NUM_THREADS=%s\n" % job.mic_omp_num_threads) - ## run_script.write("export MIC_KMP_AFFINITY=%s\n" % job.mic_affinity) - run_script.write("\n") - run_script.write("# Location of output\n") - run_script.write(self.output) - run_script.write("\n") - run_script.write("# Run command\n") - run_script.write(run_cmd) - - run_script.close() - - # Submit job to queue - subprocess.Popen("qsub %s > %s" % (run_script_path,log_path), shell=True).wait() - - # -- All jobs have been started -- - - return paths diff --git a/examples/local_ensemble/manning_job.py b/examples/local_ensemble/manning_job.py new file mode 100644 index 0000000..44d78c1 --- /dev/null +++ b/examples/local_ensemble/manning_job.py @@ -0,0 +1,73 @@ +"""Example Job subclass for a Manning's n sensitivity study. 
+ +This module demonstrates the minimal pattern for defining a batch job: +subclass Job, populate rundata in __init__, override write_data_objects +if additional files are needed. + +This example is designed to be self-contained: it does not require an actual +Clawpack installation to import, though running the batch will. +""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path + +from batch import Job +from batch.plot import plot_job + + +class ManningJob(Job): + """One GeoClaw run with a specific uniform Manning's n coefficient. + + Parameters + ---------- + manning: + Manning's roughness coefficient to apply uniformly over the domain. + max_level: + Maximum AMR refinement level. Allows coarsening for quick sweeps or + full-resolution runs without separate setrun files. + setrun_path: + Path to the ``setrun.py`` file that defines the base configuration. + Defaults to the ``setrun.py`` in the same directory as this file. + """ + + def __init__( + self, + manning: float, + max_level: int = 5, + setrun_path: Path | None = None, + ) -> None: + super().__init__() + + self.manning = manning + self.max_level = max_level + + # Prefix encodes the swept parameters for easy identification + self.prefix = f"n{manning:.3f}_l{max_level}" + self.executable = "xgeoclaw" + + # Load base configuration and apply parameter overrides + if setrun_path is None: + setrun_path = Path(__file__).parent / "setrun.py" + + spec = importlib.util.spec_from_file_location("setrun", setrun_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + self.rundata = mod.setrun() + + # Override the Manning coefficient + self.rundata.geo_data.manning_coefficient = manning + + # Override the maximum refinement level + self.rundata.amrdata.amr_levels_max = max_level + + def post_run(self, result) -> None: + plot_job(result, setplot=Path(__file__).parent / "setplot.py") + + def __repr__(self) -> str: + return ( + 
f"ManningJob(prefix={self.prefix!r}, " + f"manning={self.manning}, max_level={self.max_level})" + ) diff --git a/examples/local_ensemble/run_batch.py b/examples/local_ensemble/run_batch.py new file mode 100644 index 0000000..746d3ab --- /dev/null +++ b/examples/local_ensemble/run_batch.py @@ -0,0 +1,155 @@ +"""Manning's n sensitivity ensemble — local parallel run. + +Demonstrates: +- product_sweep to generate a Cartesian parameter grid +- ParallelExecutor for local multi-process execution +- ClobberPolicy.SKIP for free resumability + +Usage +----- +From the example directory:: + + python run_batch.py + +or with a custom output path:: + + OUTPUT_PATH=/scratch/myproject python run_batch.py + +To do a dry run that only writes .data files:: + + python run_batch.py --setup-only + +To resume a partially-completed batch:: + + python run_batch.py --resume +""" + +from __future__ import annotations + +import argparse +import logging +import os +from pathlib import Path + +# Use non-interactive backend so plotting works in batch without a display +import matplotlib + +matplotlib.use("Agg") + +from manning_job import ManningJob + +from batch import BatchController, ClobberPolicy, ParallelExecutor +from batch.sweep import product_sweep + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(name)s %(message)s", + datefmt="%H:%M:%S", +) + + +def make_jobs() -> list[ManningJob]: + """Define the parameter grid and return the job list.""" + return product_sweep( + factory=lambda manning, max_level: ManningJob( + manning=manning, + max_level=max_level, + setrun_path=Path(__file__).parent / "setrun.py", + ), + namer=lambda p: f"n{p['manning']:.3f}_l{p['max_level']}", + manning=[0.020, 0.025, 0.030, 0.035], + max_level=[4, 5], + ) + + +def plot_ensemble(results: list) -> None: + """Plot surface elevation vs time for all successful jobs on one figure. + + Reads ``fort.gauge`` from each job's output directory. 
Jobs without that + file are skipped with a warning. The figure is written next to the job + directories as ``ensemble_comparison.png``. + """ + import matplotlib.pyplot as plt + import numpy as np + + logger = logging.getLogger(__name__) + successful = [r for r in results if r.success] + if not successful: + logger.warning("No successful jobs to plot in plot_ensemble.") + return + + fig, ax = plt.subplots() + for r in successful: + gauge_file = r.paths.job / "fort.gauge" + if not gauge_file.exists(): + logger.warning("fort.gauge not found for %s, skipping.", r.job.prefix) + continue + try: + data = np.loadtxt(gauge_file) + # fort.gauge columns: gauge_num, level, time, q[0], q[1], q[2], eta + ax.plot(data[:, 2], data[:, 6], label=r.job.prefix) + except Exception as exc: + logger.warning("Failed to load gauge data for %s: %s", r.job.prefix, exc) + + ax.set_xlabel("Time (s)") + ax.set_ylabel("Surface elevation (m)") + ax.legend() + + out_path = successful[0].paths.job.parent / "ensemble_comparison.png" + fig.savefig(out_path) + plt.close(fig) + logger.info("Ensemble comparison written to %s", out_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--setup-only", + action="store_true", + help="Write .data files only; do not run the solver.", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Skip jobs whose output directory already exists.", + ) + parser.add_argument( + "--max-workers", + type=int, + default=int(os.environ.get("BATCH_MAX_JOBS", 4)), + help="Maximum concurrent jobs (default: $BATCH_MAX_JOBS or 4).", + ) + args = parser.parse_args() + + jobs = make_jobs() + + clobber = ClobberPolicy.SKIP if args.resume else ClobberPolicy.OVERWRITE + + ctrl = BatchController( + jobs=jobs, + executor=ParallelExecutor(max_workers=args.max_workers), + experiment="manning_sensitivity", + clobber=clobber, + ) + + if args.setup_only: + paths = ctrl.setup() + print(f"Setup complete for 
{len(paths)} job(s).") + return + + results = ctrl.run(wait=True) + + n_ok = sum(1 for r in results if r.success) + n_fail = sum(1 for r in results if not r.success and r.returncode is not None) + print(f"\nCompleted: {n_ok}/{len(results)} successful, {n_fail} failed.") + + if n_fail: + for r in results: + if r.returncode is not None and r.returncode != 0: + print(f" FAILED: {r.job.prefix} (see {r.paths.log})") + + plot_ensemble(results) + + +if __name__ == "__main__": + main() diff --git a/examples/storm_surge/storm_batch.py b/examples/storm_surge/storm_batch.py new file mode 100644 index 0000000..a02f609 --- /dev/null +++ b/examples/storm_surge/storm_batch.py @@ -0,0 +1,198 @@ +"""Storm surge ensemble — SLURM submission example. + +Demonstrates: +- Subclassing Job for storm-file-driven GeoClaw runs +- Per-job SLURMResources override +- SLURMExecutor with dry_run for script inspection +- zip_sweep for paired (storm_id, intensity) runs + +Directory layout produced:: + + OUTPUT_PATH/ + storm_ensemble/ + 00001/ ← one directory per storm + 00001_log.txt + 00001_run.sh + *.data + fort.* + plots/ + 00002/ + ... + +Usage +----- +Dry run (inspect generated scripts without submitting):: + + python storm_batch.py --dry-run + +Submit to SLURM:: + + python storm_batch.py + +Resume after partial completion:: + + python storm_batch.py --resume +""" + +from __future__ import annotations + +import argparse +import importlib.util +import logging +from pathlib import Path + +from batch import ( + BatchController, + ClobberPolicy, + Job, + SLURMExecutor, + SLURMResources, +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(name)s %(message)s", + datefmt="%H:%M:%S", +) + + +# --------------------------------------------------------------------------- +# Job definition +# --------------------------------------------------------------------------- + + +class StormJob(Job): + """One GeoClaw storm-surge simulation driven by a single ``.storm`` file. 
+ + Parameters + ---------- + storm_num: + Integer storm identifier. Becomes the zero-padded prefix and is used + to locate the storm file. + storms_path: + Directory containing storm files named ``{storm_num}.storm``. + setrun_path: + Path to the base ``setrun.py``. + cpus: + Number of OpenMP threads for this job; controls both the SLURM + ``--cpus-per-task`` request and the ``OMP_NUM_THREADS`` export. + """ + + def __init__( + self, + storm_num: int, + storms_path: Path = Path("."), + setrun_path: Path | None = None, + cpus: int = 8, + ) -> None: + super().__init__() + + self.storm_num = storm_num + self.prefix = str(storm_num).zfill(5) + self.executable = "xgeoclaw" + + if setrun_path is None: + setrun_path = Path(__file__).parent / "setrun.py" + + # Use clawutil's fullpath_import if available: + # mod = clawutil.fullpath_import(setrun_path) + spec = importlib.util.spec_from_file_location("setrun", setrun_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.rundata = mod.setrun() + + storm_file = (Path(storms_path) / f"{storm_num}.storm").resolve() + self.rundata.surge_data.storm_file = str(storm_file) + + # SLURM resource request — attached directly to the job so the + # executor applies it without needing a subclass. 
+ self.slurm_resources = SLURMResources( + partition="main", + nodes=1, + ntasks_per_node=1, + cpus_per_task=cpus, + time="06:00:00", + account="", # fill in your allocation + env_vars={"OMP_NUM_THREADS": str(cpus)}, + modules=["ncarenv/23.09", "python/3.11.4"], + ) + + def __repr__(self) -> str: + return f"StormJob(storm_num={self.storm_num}, prefix={self.prefix!r})" + + +# --------------------------------------------------------------------------- +# Run script +# --------------------------------------------------------------------------- + + +def make_jobs(storms_path: Path, setrun_path: Path) -> list[StormJob]: + """Build jobs for storms 1–100.""" + return [ + StormJob( + storm_num=n, + storms_path=storms_path, + setrun_path=setrun_path, + ) + for n in range(1, 101) + ] + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dry-run", + action="store_true", + help="Write submission scripts but do not call sbatch.", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Skip jobs whose output directory already exists.", + ) + parser.add_argument( + "--storms-path", + type=Path, + default=Path(__file__).parent / "storms", + help="Directory containing .storm files.", + ) + parser.add_argument( + "--setrun", + type=Path, + default=Path(__file__).parent / "setrun.py", + help="Path to setrun.py.", + ) + args = parser.parse_args() + + jobs = make_jobs(storms_path=args.storms_path, setrun_path=args.setrun) + + executor = SLURMExecutor( + default_resources=SLURMResources( + partition="main", + nodes=1, + cpus_per_task=8, + time="06:00:00", + ), + dry_run=args.dry_run, + ) + + ctrl = BatchController( + jobs=jobs, + executor=executor, + experiment="storm_ensemble", + clobber=ClobberPolicy.SKIP if args.resume else ClobberPolicy.OVERWRITE, + ) + + # For SLURM we do not wait — sbatch returns immediately + results = ctrl.run(wait=False) + + if args.dry_run: + print(f"Dry run: {len(results)} script(s) 
written, none submitted.") + else: + print(f"Submitted {len(results)} job(s) to SLURM.") + for r in results: + print(f" {r.job.prefix} → SLURM job {r.job_id}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8bb4f50 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "clawpack-batch" +version = "2.0.0" +description = "Utilities for running Clawpack/GeoClaw batch jobs" +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.10" +# No hard dependencies — clawpack is assumed to be present in the environment. +dependencies = [] + +[project.optional-dependencies] +dev = [ + "pytest>=7", + "pytest-cov", + "ruff", +] +lint = [ + "ruff", +] + +[tool.ruff] +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +# from __future__ import annotations is intentional for forward reference support. +# NOTE: UP015 is ruff's redundant-open-modes rule, unrelated to __future__ imports; no UP rules are selected above, so this ignore currently has no effect. +ignore = ["UP015"] + +[tool.hatch.build.targets.wheel] +packages = ["batch"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +# Tests that actually run the solver require a Clawpack install; mark them +# so they can be skipped in CI that only has a Python environment. +markers = [ + "integration: requires a compiled Clawpack executable", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c6f56ce --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,56 @@ +"""Shared pytest fixtures for the batch test suite. + +All fixtures avoid any dependency on an installed Clawpack or a running +scheduler. A ``MockJob`` with a mock ``rundata`` is the primary test double. 
+""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from batch.job import Job, JobPaths + + +class MockJob(Job): + """Minimal concrete Job for testing. + + ``write_data_objects`` writes a dummy ``.data`` file so directory-setup + logic in the controller has real filesystem state to work with. + ``rundata`` is a MagicMock so attribute access never raises. + """ + + def __init__(self, prefix: str = "job_001") -> None: + super().__init__() + self.prefix = prefix + self.rundata = MagicMock() + # Track calls for assertion in tests + self._write_calls: list[Path] = [] + + def write_data_objects(self, path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + (path / "claw.data").write_text("mock claw data\n") + self._write_calls.append(path) + + +@pytest.fixture +def mock_job() -> MockJob: + return MockJob(prefix="job_001") + + +@pytest.fixture +def three_jobs() -> list[MockJob]: + return [MockJob(prefix=f"job_{i:03d}") for i in range(3)] + + +@pytest.fixture +def job_paths(tmp_path: Path) -> JobPaths: + job_dir = tmp_path / "job_001" + job_dir.mkdir() + return JobPaths( + job=job_dir, + plots=job_dir / "plots", + log=job_dir / "job_001_log.txt", + ) diff --git a/tests/test_controller.py b/tests/test_controller.py new file mode 100644 index 0000000..da9e306 --- /dev/null +++ b/tests/test_controller.py @@ -0,0 +1,280 @@ +"""Tests for BatchController: path layout, clobber policies, setup, run.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from batch.controller import BatchController +from batch.job import ClobberPolicy, JobResult +from tests.conftest import MockJob + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_stub_executor(returncode: int = 0): + """Return a mock Executor whose submit() returns a 
successful JobResult.""" + executor = MagicMock() + + def _submit(job, paths): + result = JobResult(job=job, paths=paths, returncode=returncode) + job.paths = paths + return result + + executor.submit.side_effect = _submit + executor.wait_all.side_effect = lambda results: results + return executor + + +# --------------------------------------------------------------------------- +# _make_paths +# --------------------------------------------------------------------------- + + +class TestMakePaths: + def test_paths_under_base(self, tmp_path): + ctrl = BatchController(base_path=tmp_path) + job = MockJob(prefix="run_001") + paths = ctrl._make_paths(job) + assert paths.job == tmp_path / "run_001" + assert paths.plots == tmp_path / "run_001" / "plots" + assert paths.log == tmp_path / "run_001" / "run_001_log.txt" + + def test_paths_with_experiment(self, tmp_path): + ctrl = BatchController(base_path=tmp_path, experiment="hurricane_ike") + job = MockJob(prefix="run_001") + paths = ctrl._make_paths(job) + assert paths.job == tmp_path / "hurricane_ike" / "run_001" + + def test_raises_without_prefix(self, tmp_path): + ctrl = BatchController(base_path=tmp_path) + job = MockJob() + job.prefix = None + with pytest.raises(ValueError, match="no prefix"): + ctrl._make_paths(job) + + +# --------------------------------------------------------------------------- +# _setup_job_dir — clobber policies +# --------------------------------------------------------------------------- + + +class TestSetupJobDir: + def test_overwrite_creates_directory(self, tmp_path): + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.OVERWRITE) + job = MockJob(prefix="job_001") + paths = ctrl._make_paths(job) + result = ctrl._setup_job_dir(job, paths) + assert result is True + assert paths.job.is_dir() + + def test_overwrite_removes_data_files(self, tmp_path): + job = MockJob(prefix="job_001") + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.OVERWRITE) + paths = 
ctrl._make_paths(job) + paths.job.mkdir(parents=True) + stale = paths.job / "claw.data" + stale.write_text("old data\n") + keeper = paths.job / "fort.q0001" + keeper.write_text("output\n") + + ctrl._setup_job_dir(job, paths) + + assert not stale.exists(), "stale .data file should have been removed" + assert keeper.exists(), "fort.* output should be untouched" + + def test_overwrite_keeps_data_files_on_restart(self, tmp_path): + job = MockJob(prefix="job_001") + job.restart = True + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.OVERWRITE) + paths = ctrl._make_paths(job) + paths.job.mkdir(parents=True) + data_file = paths.job / "claw.data" + data_file.write_text("restart data\n") + + ctrl._setup_job_dir(job, paths) + + assert data_file.exists(), ".data file must be preserved for restart" + + def test_error_policy_raises_on_existing_dir(self, tmp_path): + job = MockJob(prefix="job_001") + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.ERROR) + paths = ctrl._make_paths(job) + paths.job.mkdir(parents=True) + + with pytest.raises(FileExistsError, match="already exists"): + ctrl._setup_job_dir(job, paths) + + def test_error_policy_passes_for_new_dir(self, tmp_path): + job = MockJob(prefix="job_001") + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.ERROR) + paths = ctrl._make_paths(job) + result = ctrl._setup_job_dir(job, paths) + assert result is True + + def test_skip_policy_returns_false_on_existing_dir(self, tmp_path): + job = MockJob(prefix="job_001") + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.SKIP) + paths = ctrl._make_paths(job) + paths.job.mkdir(parents=True) + + result = ctrl._setup_job_dir(job, paths) + assert result is False + + def test_skip_policy_returns_true_for_new_dir(self, tmp_path): + job = MockJob(prefix="job_001") + ctrl = BatchController(base_path=tmp_path, clobber=ClobberPolicy.SKIP) + paths = ctrl._make_paths(job) + result = ctrl._setup_job_dir(job, paths) + assert result 
is True + + +# --------------------------------------------------------------------------- +# setup() +# --------------------------------------------------------------------------- + + +class TestSetup: + def test_setup_creates_directories_and_data_files(self, tmp_path, three_jobs): + ctrl = BatchController( + jobs=three_jobs, + base_path=tmp_path, + executor=make_stub_executor(), + ) + paths_list = ctrl.setup() + assert len(paths_list) == 3 + for paths in paths_list: + assert paths.job.is_dir() + assert (paths.job / "claw.data").exists() + + def test_setup_assigns_paths_to_job(self, tmp_path, mock_job): + ctrl = BatchController( + jobs=[mock_job], + base_path=tmp_path, + executor=make_stub_executor(), + ) + ctrl.setup() + assert mock_job.paths is not None + assert mock_job.paths.job == tmp_path / "job_001" + + def test_setup_writes_log_header(self, tmp_path, mock_job): + ctrl = BatchController( + jobs=[mock_job], + base_path=tmp_path, + executor=make_stub_executor(), + ) + ctrl.setup() + log = mock_job.paths.log.read_text() + assert "Started" in log + + def test_setup_skips_existing_dirs_under_skip_policy(self, tmp_path): + job_a = MockJob(prefix="job_000") + job_b = MockJob(prefix="job_001") + ctrl = BatchController( + jobs=[job_a, job_b], + base_path=tmp_path, + clobber=ClobberPolicy.SKIP, + ) + # Pre-create job_a directory + (tmp_path / "job_000").mkdir(parents=True) + paths_list = ctrl.setup() + # Only job_b should have been set up + assert len(paths_list) == 1 + assert paths_list[0].job == tmp_path / "job_001" + + +# --------------------------------------------------------------------------- +# run() +# --------------------------------------------------------------------------- + + +class TestRun: + def test_run_calls_submit_for_each_job(self, tmp_path, three_jobs): + executor = make_stub_executor() + ctrl = BatchController( + jobs=three_jobs, + base_path=tmp_path, + executor=executor, + ) + results = ctrl.run() + assert executor.submit.call_count == 3 + 
assert len(results) == 3 + + def test_run_calls_wait_all_when_wait_true(self, tmp_path, mock_job): + executor = make_stub_executor() + ctrl = BatchController( + jobs=[mock_job], + base_path=tmp_path, + executor=executor, + ) + ctrl.run(wait=True) + executor.wait_all.assert_called_once() + + def test_run_skips_wait_all_when_wait_false(self, tmp_path, mock_job): + executor = make_stub_executor() + ctrl = BatchController( + jobs=[mock_job], + base_path=tmp_path, + executor=executor, + ) + ctrl.run(wait=False) + executor.wait_all.assert_not_called() + + def test_run_calls_build_on_each_job(self, tmp_path): + job = MockJob(prefix="job_001") + job.build = MagicMock() + ctrl = BatchController( + jobs=[job], + base_path=tmp_path, + executor=make_stub_executor(), + ) + ctrl.run() + job.build.assert_called_once() + + def test_run_returns_results_with_correct_prefix(self, tmp_path, three_jobs): + executor = make_stub_executor() + ctrl = BatchController( + jobs=three_jobs, + base_path=tmp_path, + executor=executor, + ) + results = ctrl.run() + prefixes = {r.job.prefix for r in results} + assert prefixes == {j.prefix for j in three_jobs} + + def test_run_skips_jobs_under_skip_policy(self, tmp_path): + job_a = MockJob(prefix="job_000") + job_b = MockJob(prefix="job_001") + executor = make_stub_executor() + ctrl = BatchController( + jobs=[job_a, job_b], + base_path=tmp_path, + clobber=ClobberPolicy.SKIP, + executor=executor, + ) + (tmp_path / "job_000").mkdir(parents=True) + results = ctrl.run() + # Only job_b submitted + assert len(results) == 1 + assert results[0].job.prefix == "job_001" + + def test_run_logs_warning_on_failures(self, tmp_path, mock_job, caplog): + import logging + + executor = make_stub_executor(returncode=1) + ctrl = BatchController( + jobs=[mock_job], + base_path=tmp_path, + executor=executor, + ) + with caplog.at_level(logging.WARNING, logger="batch.controller"): + ctrl.run() + assert any("failed" in rec.message for rec in caplog.records) + + def 
test_base_path_falls_back_to_env_var(self, tmp_path, mock_job, monkeypatch): + monkeypatch.setenv("OUTPUT_PATH", str(tmp_path)) + ctrl = BatchController(jobs=[mock_job], executor=make_stub_executor()) + assert ctrl.base_path == tmp_path diff --git a/tests/test_executors_local.py b/tests/test_executors_local.py new file mode 100644 index 0000000..7ea4d12 --- /dev/null +++ b/tests/test_executors_local.py @@ -0,0 +1,211 @@ +"""Tests for SerialExecutor and ParallelExecutor. + +No Clawpack installation is required. subprocess.run / subprocess.Popen are +patched to avoid actually launching xgeoclaw. +""" + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from batch.executors.local import ( + ParallelExecutor, + SerialExecutor, + _build_run_args, +) +from batch.job import JobPaths, JobResult +from tests.conftest import MockJob + + +@pytest.fixture +def paths(tmp_path: Path) -> JobPaths: + job_dir = tmp_path / "job_001" + job_dir.mkdir() + return JobPaths( + job=job_dir, + plots=job_dir / "plots", + log=job_dir / "job_001_log.txt", + ) + + +# --------------------------------------------------------------------------- +# _build_run_args +# --------------------------------------------------------------------------- + + +class TestBuildRunArgs: + def test_uses_sys_executable(self, paths): + job = MockJob(prefix="job_001") + args = _build_run_args(job, paths) + assert args[0] == sys.executable + + def test_invokes_runclaw_module(self, paths): + job = MockJob(prefix="job_001") + args = _build_run_args(job, paths) + assert args[1:3] == ["-m", "clawpack.clawutil.runclaw"] + + def test_overwrite_flag_when_not_restarting(self, paths): + # args: [python, -m, runclaw, exe, outdir, overwrite, restart, rundir, verbose] + # 0 1 2 3 4 5 6 7 8 + job = MockJob(prefix="job_001") + job.restart = False + args = _build_run_args(job, paths) + assert args[5] == "T" # overwrite + assert 
args[6] == "F" # restart + + def test_restart_flags_when_restarting(self, paths): + job = MockJob(prefix="job_001") + job.restart = True + args = _build_run_args(job, paths) + assert args[5] == "F" # overwrite + assert args[6] == "T" # restart + + def test_outdir_and_rundir_are_same(self, paths): + job = MockJob(prefix="job_001") + args = _build_run_args(job, paths) + outdir = args[4] + rundir = args[7] + assert outdir == rundir == str(paths.job) + + +# --------------------------------------------------------------------------- +# SerialExecutor +# --------------------------------------------------------------------------- + + +class TestSerialExecutor: + def test_submit_returns_success_result(self, paths): + job = MockJob(prefix="job_001") + executor = SerialExecutor() + mock_result = MagicMock() + mock_result.returncode = 0 + with patch("batch.executors.local.subprocess.run", return_value=mock_result): + result = executor.submit(job, paths) + assert result.returncode == 0 + assert result.job is job + + def test_submit_returns_failure_returncode(self, paths): + job = MockJob(prefix="job_001") + executor = SerialExecutor() + mock_result = MagicMock() + mock_result.returncode = 1 + with patch("batch.executors.local.subprocess.run", return_value=mock_result): + result = executor.submit(job, paths) + assert result.returncode == 1 + + def test_submit_writes_to_log(self, paths): + job = MockJob(prefix="job_001") + executor = SerialExecutor() + mock_result = MagicMock() + mock_result.returncode = 0 + with patch("batch.executors.local.subprocess.run", return_value=mock_result): + executor.submit(job, paths) + assert paths.log.exists() + + def test_wait_all_is_identity(self, paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=paths, returncode=0) + executor = SerialExecutor() + returned = executor.wait_all([result]) + assert returned == [result] or returned == [result] + + def test_extra_args_appended(self, paths): + job = 
MockJob(prefix="job_001") + executor = SerialExecutor(extra_args=["--extra", "flag"]) + captured_args = [] + mock_result = MagicMock() + mock_result.returncode = 0 + + def fake_run(args, **kwargs): + captured_args.extend(args) + return mock_result + + with patch("batch.executors.local.subprocess.run", side_effect=fake_run): + executor.submit(job, paths) + assert "--extra" in captured_args + assert "flag" in captured_args + + +# --------------------------------------------------------------------------- +# ParallelExecutor._drain — the core correctness test +# --------------------------------------------------------------------------- + + +class TestParallelExecutorDrain: + """Test _drain without actually spawning processes.""" + + def _make_mock_proc(self, poll_return): + proc = MagicMock(spec=subprocess.Popen) + proc.poll.return_value = poll_return + return proc + + def test_drain_removes_completed_processes(self, paths): + executor = ParallelExecutor(max_workers=4) + job = MockJob(prefix="job_001") + + done_proc = self._make_mock_proc(poll_return=0) + running_proc = self._make_mock_proc(poll_return=None) + done_result = JobResult(job=job, paths=paths, returncode=None) + running_result = JobResult(job=job, paths=paths, returncode=None) + log_fh = MagicMock() + + executor._active = [ + (done_proc, done_result, log_fh), + (running_proc, running_result, log_fh), + ] + executor._drain() + + assert len(executor._active) == 1 + assert executor._active[0][1] is running_result + + def test_drain_sets_returncode_on_completed(self, paths): + executor = ParallelExecutor(max_workers=4) + job = MockJob(prefix="job_001") + + proc = self._make_mock_proc(poll_return=0) + result = JobResult(job=job, paths=paths, returncode=None) + log_fh = MagicMock() + executor._active = [(proc, result, log_fh)] + + executor._drain() + + assert result.returncode == 0 + + def test_drain_closes_log_handle_on_completion(self, paths): + executor = ParallelExecutor(max_workers=4) + job = 
MockJob(prefix="job_001") + + proc = self._make_mock_proc(poll_return=0) + result = JobResult(job=job, paths=paths, returncode=None) + log_fh = MagicMock() + executor._active = [(proc, result, log_fh)] + + executor._drain() + + log_fh.close.assert_called_once() + + def test_drain_does_not_skip_consecutive_completed_processes(self, paths): + """Regression test: the original list-modify-while-iterating bug + caused every other completed process to be silently skipped.""" + executor = ParallelExecutor(max_workers=8) + job = MockJob(prefix="job_001") + results = [] + log_fh = MagicMock() + + for _ in range(6): + proc = self._make_mock_proc(poll_return=0) + result = JobResult(job=job, paths=paths, returncode=None) + results.append(result) + executor._active.append((proc, result, log_fh)) + + executor._drain() + + assert executor._active == [], "all completed processes should be drained" + assert all(r.returncode == 0 for r in results), ( + "all results should have their returncode set" + ) diff --git a/tests/test_job.py b/tests/test_job.py new file mode 100644 index 0000000..28521a9 --- /dev/null +++ b/tests/test_job.py @@ -0,0 +1,165 @@ +"""Tests for batch.job — Job, JobPaths, JobResult, ClobberPolicy.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from batch.executors.local import ParallelExecutor, SerialExecutor +from batch.job import ClobberPolicy, Job, JobPaths, JobResult +from tests.conftest import MockJob + + +class TestJob: + def test_default_attributes(self): + job = Job() + assert job.prefix is None + assert job.executable == "xgeoclaw" + assert job.setplot == "setplot" + assert job.restart is False + assert job.paths is None + assert job.rundata is None + + def test_repr(self): + job = MockJob(prefix="test_001") + assert "MockJob" in repr(job) + assert "test_001" in repr(job) + + def test_write_data_objects_calls_rundata_write(self, tmp_path): + job = 
Job() + job.prefix = "job_001" + job.rundata = MagicMock() + job.write_data_objects(tmp_path) + job.rundata.write.assert_called_once_with(out_dir=tmp_path) + + def test_write_data_objects_raises_without_rundata(self, tmp_path): + job = Job() + job.prefix = "job_001" + with pytest.raises(ValueError, match="rundata is not set"): + job.write_data_objects(tmp_path) + + def test_build_is_noop_by_default(self, job_paths): + """build() must not raise or produce side effects by default.""" + job = MockJob() + job.build(job_paths) # should not raise + + def test_mock_job_write_creates_data_file(self, tmp_path): + job = MockJob(prefix="abc") + job.write_data_objects(tmp_path) + assert (tmp_path / "claw.data").exists() + assert len(job._write_calls) == 1 + assert job._write_calls[0] == tmp_path + + +class TestJobPaths: + def test_fields_are_paths(self, tmp_path): + paths = JobPaths( + job=tmp_path / "job", + plots=tmp_path / "job" / "plots", + log=tmp_path / "job" / "job_log.txt", + ) + assert isinstance(paths.job, Path) + assert isinstance(paths.plots, Path) + assert isinstance(paths.log, Path) + + +class TestJobResult: + def test_success_true_when_returncode_zero(self, job_paths): + job = MockJob() + result = JobResult(job=job, paths=job_paths, returncode=0) + assert result.success is True + + def test_success_false_when_nonzero(self, job_paths): + job = MockJob() + result = JobResult(job=job, paths=job_paths, returncode=1) + assert result.success is False + + def test_success_false_when_none(self, job_paths): + job = MockJob() + result = JobResult(job=job, paths=job_paths, returncode=None) + assert result.success is False + + def test_pending_true_when_returncode_none(self, job_paths): + job = MockJob() + result = JobResult(job=job, paths=job_paths, returncode=None) + assert result.pending is True + + def test_pending_false_when_returncode_set(self, job_paths): + job = MockJob() + result = JobResult(job=job, paths=job_paths, returncode=0) + assert result.pending is False + 
+ +class TestClobberPolicy: + def test_all_values_present(self): + assert ClobberPolicy.OVERWRITE.value == "overwrite" + assert ClobberPolicy.ERROR.value == "error" + assert ClobberPolicy.SKIP.value == "skip" + + +class TestPostRun: + def test_post_run_is_noop_by_default(self, job_paths): + job = Job() + job.prefix = "job_001" + result = JobResult(job=job, paths=job_paths, returncode=0) + job.post_run(result) # must not raise + + def test_post_run_called_on_success_serial(self, job_paths): + job = MockJob(prefix="job_001") + executor = SerialExecutor() + mock_proc = MagicMock() + mock_proc.returncode = 0 + with patch("batch.executors.local.subprocess.run", return_value=mock_proc): + with patch.object(job, "post_run") as mock_post_run: + result = executor.submit(job, job_paths) + mock_post_run.assert_called_once_with(result) + + def test_post_run_not_called_on_failure_serial(self, job_paths): + job = MockJob(prefix="job_001") + executor = SerialExecutor() + mock_proc = MagicMock() + mock_proc.returncode = 1 + with patch("batch.executors.local.subprocess.run", return_value=mock_proc): + with patch.object(job, "post_run") as mock_post_run: + executor.submit(job, job_paths) + mock_post_run.assert_not_called() + + def test_post_run_called_on_success_parallel(self, job_paths): + executor = ParallelExecutor(max_workers=4) + job = MockJob(prefix="job_001") + proc = MagicMock(spec=subprocess.Popen) + proc.poll.return_value = 0 + result = JobResult(job=job, paths=job_paths, returncode=None) + log_fh = MagicMock() + executor._active = [(proc, result, log_fh)] + with patch.object(job, "post_run") as mock_post_run: + executor._drain() + mock_post_run.assert_called_once_with(result) + + def test_post_run_not_called_on_failure_parallel(self, job_paths): + executor = ParallelExecutor(max_workers=4) + job = MockJob(prefix="job_001") + proc = MagicMock(spec=subprocess.Popen) + proc.poll.return_value = 1 + result = JobResult(job=job, paths=job_paths, returncode=None) + log_fh = 
MagicMock() + executor._active = [(proc, result, log_fh)] + with patch.object(job, "post_run") as mock_post_run: + executor._drain() + mock_post_run.assert_not_called() + + def test_post_run_exception_does_not_propagate(self, job_paths): + class ExplodingJob(MockJob): + def post_run(self, result): + raise RuntimeError("boom") + + job = ExplodingJob(prefix="job_001") + executor = SerialExecutor() + mock_proc = MagicMock() + mock_proc.returncode = 0 + with patch("batch.executors.local.subprocess.run", return_value=mock_proc): + result = executor.submit(job, job_paths) # must not raise + assert result.returncode == 0 diff --git a/tests/test_plot.py b/tests/test_plot.py new file mode 100644 index 0000000..30de0d0 --- /dev/null +++ b/tests/test_plot.py @@ -0,0 +1,123 @@ +"""Tests for batch.plot.plot_job and batch.plot._plot_inprocess.""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +from batch.job import JobResult +from batch.plot import _plot_inprocess, plot_job +from tests.conftest import MockJob + + +class TestPlotJob: + def test_plot_job_returns_true_on_success(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + mock_proc = MagicMock(returncode=0) + with patch("batch.plot.subprocess.run", return_value=mock_proc): + assert plot_job(result) is True + assert "--- plotclaw ---" in job_paths.log.read_text() + + def test_plot_job_returns_false_on_nonzero_returncode(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + mock_proc = MagicMock(returncode=1) + with patch("batch.plot.subprocess.run", return_value=mock_proc): + assert plot_job(result) is False + + def test_plot_job_resolves_relative_setplot_against_job_dir(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + (job_paths.job / "setplot.py").write_text("# dummy\n") + captured = {} 
+ + def capture_run(args, **kwargs): + captured["args"] = args + return MagicMock(returncode=0) + + with patch("batch.plot.subprocess.run", side_effect=capture_run): + plot_job(result, setplot="setplot.py") + + assert captured["args"][-1] == str(job_paths.job / "setplot.py") + + def test_plot_job_passes_absolute_setplot_unchanged(self, job_paths, tmp_path): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + abs_path = tmp_path / "custom_setplot.py" + abs_path.touch() + captured = {} + + def capture_run(args, **kwargs): + captured["args"] = args + return MagicMock(returncode=0) + + with patch("batch.plot.subprocess.run", side_effect=capture_run): + plot_job(result, setplot=abs_path) + + assert captured["args"][-1] == str(abs_path.resolve()) + + def test_plot_job_output_appended_to_log(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + with patch("batch.plot.subprocess.run", return_value=MagicMock(returncode=0)): + plot_job(result) + assert "--- plotclaw ---" in job_paths.log.read_text() + + def test_plot_job_callable_setplot_uses_inprocess_fallback(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + + def setplot_fn(): + pass + + with patch("batch.plot._plot_inprocess", return_value=True) as mock_inproc: + with patch("batch.plot.subprocess.run") as mock_run: + plot_job(result, setplot=setplot_fn) + mock_inproc.assert_called_once() + mock_run.assert_not_called() + + +class TestPlotInprocess: + def test_plot_inprocess_returns_false_when_visclaw_not_importable(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + with patch.dict( + sys.modules, + { + "clawpack": MagicMock(), + "clawpack.visclaw": MagicMock(), + "clawpack.visclaw.plotclaw": None, + }, + ): + assert _plot_inprocess(result, "setplot.py", "ascii") is False + + def 
test_plot_inprocess_returns_false_on_exception(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + mock_module = MagicMock() + mock_module.plotclaw = MagicMock(side_effect=RuntimeError("boom")) + with patch.dict( + sys.modules, + { + "clawpack": MagicMock(), + "clawpack.visclaw": MagicMock(), + "clawpack.visclaw.plotclaw": mock_module, + }, + ): + assert _plot_inprocess(result, "setplot.py", "ascii") is False + + def test_plot_inprocess_returns_true_on_success(self, job_paths): + job = MockJob(prefix="job_001") + result = JobResult(job=job, paths=job_paths, returncode=0) + mock_module = MagicMock() + mock_module.plotclaw = MagicMock() + with patch.dict( + sys.modules, + { + "clawpack": MagicMock(), + "clawpack.visclaw": MagicMock(), + "clawpack.visclaw.plotclaw": mock_module, + }, + ): + assert _plot_inprocess(result, "setplot.py", "ascii") is True diff --git a/tests/test_slurm.py b/tests/test_slurm.py new file mode 100644 index 0000000..34a3bbd --- /dev/null +++ b/tests/test_slurm.py @@ -0,0 +1,201 @@ +"""Tests for SLURMResources and render_slurm_script. + +render_slurm_script is a pure function so all tests run without a cluster. +SLURMExecutor submission is tested via dry_run=True. 
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from batch.executors.slurm import SLURMExecutor, SLURMResources, render_slurm_script
+from batch.job import JobPaths
+from tests.conftest import MockJob
+
+
+@pytest.fixture
+def paths(tmp_path: Path) -> JobPaths:
+    """Return a JobPaths laid out inside a newly created ``job_001`` directory.
+
+    Only the job directory itself is created; the plots directory and the
+    log file are referenced but not touched.
+    """
+    root = tmp_path / "job_001"
+    root.mkdir()
+    return JobPaths(job=root, plots=root / "plots", log=root / "job_001_log.txt")
+
+
+@pytest.fixture
+def minimal_resources() -> SLURMResources:
+    """Return a SLURMResources built from just partition, nodes and time."""
+    request = SLURMResources(partition="main", nodes=1, time="02:00:00")
+    return request
+
+
+# ---------------------------------------------------------------------------
+# render_slurm_script — directive correctness
+# ---------------------------------------------------------------------------
+
+
+class TestRenderSlurmScript:
+    """Checks on the script text returned by ``render_slurm_script``."""
+
+    def test_starts_with_shebang(self, paths, minimal_resources):
+        """The script text begins with ``#!/bin/bash``."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert rendered.startswith("#!/bin/bash")
+
+    def test_contains_job_name_directive(self, paths, minimal_resources):
+        """The job prefix is emitted as a ``-J`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert "#SBATCH -J job_001" in rendered
+
+    def test_log_path_in_directives(self, paths, minimal_resources):
+        """The log path taken from ``paths`` appears in the script."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert str(paths.log) in rendered
+
+    def test_partition_in_directives(self, paths):
+        """``partition`` is emitted as a ``-p`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(partition="preempt")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH -p preempt" in rendered
+
+    def test_walltime_in_directives(self, paths):
+        """``time`` is emitted as a ``-t`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(time="12:30:00")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH -t 12:30:00" in rendered
+
+    def test_cpus_per_task_in_directives(self, paths):
+        """``cpus_per_task`` is emitted as a ``--cpus-per-task=`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(cpus_per_task=16)
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH --cpus-per-task=16" in rendered
+
+    def test_memory_absent_when_empty(self, paths, minimal_resources):
+        """An empty ``memory`` string produces no ``--mem=`` text."""
+        mock_job = MockJob(prefix="job_001")
+        minimal_resources.memory = ""
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert "--mem=" not in rendered
+
+    def test_memory_present_when_set(self, paths):
+        """``memory`` is emitted as a ``--mem=`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(memory="8G")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH --mem=8G" in rendered
+
+    def test_account_present_when_set(self, paths):
+        """``account`` is emitted as a ``-A`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(account="NCAR0001")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH -A NCAR0001" in rendered
+
+    def test_account_absent_when_empty(self, paths, minimal_resources):
+        """The minimal request (no account given) yields no ``#SBATCH -A`` text."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert "#SBATCH -A" not in rendered
+
+    def test_constraint_present_when_set(self, paths):
+        """``constraint`` is emitted as a ``--constraint=`` directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(constraint="cpu")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "#SBATCH --constraint=cpu" in rendered
+
+    def test_email_directives_when_set(self, paths):
+        """``email`` and ``mail_type`` are each emitted as a mail directive."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(email="user@example.com", mail_type="END,FAIL")
+        rendered = render_slurm_script(mock_job, paths, requested)
+        expected_directives = (
+            "#SBATCH --mail-user=user@example.com",
+            "#SBATCH --mail-type=END,FAIL",
+        )
+        for directive in expected_directives:
+            assert directive in rendered
+
+    def test_email_directives_absent_when_empty(self, paths, minimal_resources):
+        """The minimal request (no email given) yields no ``--mail-user`` text."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert "--mail-user" not in rendered
+
+    def test_module_load_lines_present(self, paths):
+        """Every entry in ``modules`` gets a ``module load`` line."""
+        module_names = ("ncarenv/23.09", "python/3.11.4")
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(modules=list(module_names))
+        rendered = render_slurm_script(mock_job, paths, requested)
+        for module_name in module_names:
+            assert f"module load {module_name}" in rendered
+
+    def test_env_vars_exported(self, paths):
+        """An ``env_vars`` entry is emitted as an ``export NAME=value`` line."""
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(env_vars={"OMP_NUM_THREADS": "8"})
+        rendered = render_slurm_script(mock_job, paths, requested)
+        assert "export OMP_NUM_THREADS=8" in rendered
+
+    def test_extra_directives_appended(self, paths):
+        """Every entry in ``extra_directives`` appears verbatim in the script."""
+        extras = ("#SBATCH --gres=gpu:1", "#SBATCH --licenses=scratch:1")
+        mock_job = MockJob(prefix="job_001")
+        requested = SLURMResources(extra_directives=list(extras))
+        rendered = render_slurm_script(mock_job, paths, requested)
+        for directive in extras:
+            assert directive in rendered
+
+    def test_script_ends_with_newline(self, paths, minimal_resources):
+        """The script text is newline-terminated."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        assert rendered.endswith("\n")
+
+    def test_run_command_is_last_non_empty_line(self, paths, minimal_resources):
+        """The final non-blank line of the script references runclaw."""
+        mock_job = MockJob(prefix="job_001")
+        rendered = render_slurm_script(mock_job, paths, minimal_resources)
+        # Drop lines that are empty or whitespace-only before taking the last one.
+        meaningful = list(filter(str.strip, rendered.splitlines()))
+        # Should invoke runclaw
+        assert "clawpack.clawutil.runclaw" in meaningful[-1]
+
+    def test_per_job_resource_override(self, paths):
+        """A job's own ``slurm_resources`` should win over the executor defaults."""
+        mock_job = MockJob(prefix="job_001")
+        mock_job.slurm_resources = SLURMResources(partition="gpu", time="04:00:00")
+        defaults = SLURMResources(partition="main", time="01:00:00")
+
+        runner = SLURMExecutor(default_resources=defaults, dry_run=True)
+        runner.submit(mock_job, paths)
+
+        # The dry-run submit leaves the rendered script in the job directory.
+        rendered = paths.job.joinpath("job_001_run.sh").read_text()
+        for directive in ("#SBATCH -p gpu", "#SBATCH -t 04:00:00"):
+            assert directive in rendered
+
+
+# ---------------------------------------------------------------------------
+# SLURMExecutor dry_run
+# ---------------------------------------------------------------------------
+
+
+class TestSLURMExecutorDryRun:
+    """Checks on ``SLURMExecutor`` when constructed with ``dry_run=True``."""
+
+    def test_dry_run_writes_script_file(self, paths, minimal_resources):
+        """``submit`` leaves ``job_001_run.sh`` in the job directory."""
+        mock_job = MockJob(prefix="job_001")
+        runner = SLURMExecutor(default_resources=minimal_resources, dry_run=True)
+        runner.submit(mock_job, paths)
+        script_file = paths.job / "job_001_run.sh"
+        assert script_file.exists()
+
+    def test_dry_run_returns_dry_run_job_id(self, paths, minimal_resources):
+        """The result carries job ID ``"dry-run"`` and a ``None`` return code."""
+        mock_job = MockJob(prefix="job_001")
+        runner = SLURMExecutor(default_resources=minimal_resources, dry_run=True)
+        outcome = runner.submit(mock_job, paths)
+        assert outcome.job_id == "dry-run"
+        assert outcome.returncode is None
+
+    def test_dry_run_does_not_call_sbatch(self, paths, minimal_resources):
+        """``submit`` makes no ``subprocess.run`` call."""
+        mock_job = MockJob(prefix="job_001")
+        runner = SLURMExecutor(default_resources=minimal_resources, dry_run=True)
+        with patch("batch.executors.slurm.subprocess.run") as fake_run:
+            runner.submit(mock_job, paths)
+            fake_run.assert_not_called()
+
+    def test_wait_all_skips_dry_run_jobs(self, paths, minimal_resources):
+        """``wait_all`` should not poll squeue for dry-run job IDs."""
+        mock_job = MockJob(prefix="job_001")
+        runner = SLURMExecutor(default_resources=minimal_resources, dry_run=True)
+        # Submit outside the patch so only wait_all's calls are observed.
+        outcome = runner.submit(mock_job, paths)
+        with patch("batch.executors.slurm.subprocess.run") as fake_run:
+            runner.wait_all([outcome])
+            fake_run.assert_not_called()
diff --git a/tests/test_sweep.py b/tests/test_sweep.py
new file mode 100644
index 0000000..28561ed
--- /dev/null
+++ b/tests/test_sweep.py
@@ -0,0 +1,161 @@
+"""Tests for batch.sweep — product_sweep and zip_sweep."""
+
+from __future__ import annotations
+
+import pytest
+
+from batch.job import Job
+from batch.sweep import product_sweep, zip_sweep
+from tests.conftest import MockJob
+
+
+def simple_factory(**params) -> MockJob:
+    """Build a MockJob and record the keyword arguments on ``_params``."""
+    made = MockJob(prefix="to_be_set")
+    made._params = params
+    return made
+
+
+def simple_namer(params: dict) -> str:
+    """Glue each key to its value and join the sorted pairs with ``_``."""
+    pieces = [f"{key}{value}" for key, value in sorted(params.items())]
+    return "_".join(pieces)
+
+
+# ---------------------------------------------------------------------------
+# product_sweep
+# ---------------------------------------------------------------------------
+
+
+class TestProductSweep:
+    """Checks on ``product_sweep`` called with keyword parameter lists."""
+
+    def test_cartesian_product_count(self):
+        """Three values crossed with two values yield six jobs."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            manning=[0.020, 0.025, 0.030],
+            level=[4, 5],
+        )
+        assert len(swept) == 6  # 3 × 2
+
+    def test_single_parameter_list(self):
+        """A single list of two values yields two jobs."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            manning=[0.020, 0.025],
+        )
+        assert len(swept) == 2
+
+    def test_prefix_set_by_namer(self):
+        """Each job's prefix equals the namer's output for its parameters."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=lambda p: f"n{p['manning']:.3f}",
+            manning=[0.020, 0.025],
+        )
+        assert swept[0].prefix == "n0.020"
+        assert swept[1].prefix == "n0.025"
+
+    def test_all_combinations_present(self):
+        """Every pairing of the ``a`` and ``b`` values shows up as a prefix."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            a=[1, 2],
+            b=["x", "y"],
+        )
+        seen = {made.prefix for made in swept}
+        assert seen == {"a1_bx", "a1_by", "a2_bx", "a2_by"}
+
+    def test_all_returned_objects_are_jobs(self):
+        """Every returned element is a ``Job`` instance."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            x=[1, 2, 3],
+        )
+        assert all(isinstance(made, Job) for made in swept)
+
+    def test_empty_param_grid_returns_one_job(self):
+        """With no parameter lists at all, exactly one job comes back."""
+        # product of zero iterators is one empty combination
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=lambda p: "only",
+        )
+        assert len(swept) == 1
+        assert swept[0].prefix == "only"
+
+    def test_factory_receives_correct_params(self):
+        """The factory is called with the swept values as keyword arguments."""
+        swept = product_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            manning=[0.020],
+            level=[4],
+        )
+        assert swept[0]._params == {"manning": 0.020, "level": 4}
+
+
+# ---------------------------------------------------------------------------
+# zip_sweep
+# ---------------------------------------------------------------------------
+
+
+class TestZipSweep:
+    """Checks on ``zip_sweep`` called with keyword parameter lists."""
+
+    def test_paired_count(self):
+        """Two lists of three values yield three jobs."""
+        swept = zip_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            storm_id=[1, 2, 3],
+            intensity=["low", "mid", "high"],
+        )
+        assert len(swept) == 3
+
+    def test_prefix_set_by_namer(self):
+        """Prefixes come from the namer applied to each index-wise pairing."""
+        swept = zip_sweep(
+            factory=simple_factory,
+            namer=lambda p: f"{p['storm_id']}_{p['intensity']}",
+            storm_id=[1, 2],
+            intensity=["low", "high"],
+        )
+        assert swept[0].prefix == "1_low"
+        assert swept[1].prefix == "2_high"
+
+    def test_raises_on_mismatched_lengths(self):
+        """Lists of unequal length raise a ValueError matching ``same length``."""
+        with pytest.raises(ValueError, match="same length"):
+            zip_sweep(
+                factory=simple_factory,
+                namer=simple_namer,
+                a=[1, 2, 3],
+                b=["x", "y"],  # length mismatch
+            )
+
+    def test_single_parameter_list(self):
+        """A single list yields one job per value, in list order."""
+        swept = zip_sweep(
+            factory=simple_factory,
+            namer=lambda p: str(p["x"]),
+            x=[10, 20, 30],
+        )
+        assert len(swept) == 3
+        assert [made.prefix for made in swept] == ["10", "20", "30"]
+
+    def test_factory_receives_paired_params(self):
+        """Each job receives the values found at its index in every list."""
+        swept = zip_sweep(
+            factory=simple_factory,
+            namer=simple_namer,
+            a=[1, 2],
+            b=["x", "y"],
+        )
+        assert swept[0]._params == {"a": 1, "b": "x"}
+        assert swept[1]._params == {"a": 2, "b": "y"}