diff --git a/CrocoDash/case.py b/CrocoDash/case.py index a92c551d..a757ba7e 100644 --- a/CrocoDash/case.py +++ b/CrocoDash/case.py @@ -93,6 +93,27 @@ def __init__( Must be in the form hh:mm:ss. If None, defaults to the CESM defaults """ + # Capture scalar init args for state serialization before any local vars are added. + # Excludes: objects (stored as paths), args resolved to derived values, and ephemeral flags. + _locals = locals() + _SERIALIZABLE_EXCLUDE = frozenset( + { + "self", + "ocn_grid", + "ocn_topo", + "ocn_vgrid", + "compset", + "machine", + "cesmroot", + "caseroot", + "inputdir", + "override", + } + ) + self._init_args = { + k: v for k, v in _locals.items() if k not in _SERIALIZABLE_EXCLUDE + } + # Initialize visualCaseGen system and get the CIME interface self.cime = initialize_visualCaseGen(cesmroot) @@ -125,11 +146,13 @@ def __init__( ) # Set instance attributes + self.cesmroot = Path(cesmroot) self.caseroot = Path(caseroot) self.inputdir = Path(inputdir) self.ocn_grid = ocn_grid self.ocn_topo = ocn_topo self.ocn_vgrid = ocn_vgrid + self.atm_grid_name = atm_grid_name self.ninst = ninst self.override = override self.ProductRegistry = ProductRegistry @@ -139,6 +162,10 @@ def __init__( self.compset_lname = compset_lname self.machine = machine or self.cime.machine self.project = project + self.rof_grid_name = rof_grid_name + self.ntasks_ocn = ntasks_ocn + self.job_queue = job_queue + self.job_wallclock_time = job_wallclock_time # Using visualCaseGen's configuration system, set the configuration variables for the case # based on the provided arguments. This includes setting the compset, grid, and launch variables. 
# --- CrocoDash/case.py: method of Case -------------------------------------
def _write_state(self):
    """Write case creation parameters to crocodash_state.json in caseroot.

    The file combines two groups of keys:

    * derived / resolved values read from instance attributes (paths, grid
      name, session id, compset long name, machine), and
    * the scalar ``__init__`` arguments captured in ``self._init_args``.

    ``self._init_args`` is unpacked last, so on a key collision its value
    overrides the derived one.
    """

    def _path_str(value):
        # NOTE(review): assumes these attributes may be pathlib.Path objects
        # (confirm where they are assigned). json cannot serialize Path, so
        # they are stringified the same way inputdir/cesmroot already are.
        # None is preserved so it is still written as JSON null.
        return None if value is None else str(value)

    state = {
        # Derived / resolved fields that can't come from init args directly
        "inputdir": str(self.inputdir),
        "cesmroot": str(self.cesmroot),
        "supergrid_path": _path_str(self.supergrid_path),
        "topo_path": _path_str(self.topo_path),
        "vgrid_path": _path_str(self.vgrid_path),
        "grid_name": self.ocn_grid.name,
        "session_id": cvars["MB_ATTEMPT_ID"].value,
        "compset_lname": self.compset_lname,
        "machine": self.machine,
        # Scalar init args captured at construction time
        **self._init_args,
    }
    with open(
        self.caseroot / "crocodash_state.json", "w", encoding="utf-8"
    ) as f:
        json.dump(state, f, indent=2)


# --- CrocoDash/cli.py: sub-command handlers --------------------------------
def _create(args):
    """Handle ``crocodash create``: load the YAML config and create the case.

    Reads ``args.config`` (path to the YAML file) and ``args.override``.
    """
    # Imported lazily, like the other handlers in this module.
    from CrocoDash.recipe import load_config, create_case_from_yaml

    config = load_config(args.config)
    create_case_from_yaml(config, override=args.override)


def _dump(args):
    """Handle ``crocodash dump``: print the case at ``args.caseroot`` as YAML on stdout."""
    from CrocoDash.recipe import case_to_yaml
    import yaml

    config = case_to_yaml(args.caseroot)
    # sort_keys=False keeps the section order produced by case_to_yaml.
    yaml.dump(config, sys.stdout, default_flow_style=False, sort_keys=False)
else None - ) - remove_configs = ( - [x.strip() for x in args.remove_configs.split(",") if x.strip()] - if args.remove_configs - else None - ) forker = ForkCrocoDashBundle(args.bundle) forker.fork( @@ -53,10 +58,6 @@ def _fork(args): new_caseroot=args.caseroot, new_inputdir=args.inputdir, plan=plan, - compset=args.compset, - extra_configs=extra_configs, - remove_configs=remove_configs, - extra_forcing_args_path=args.extra_forcing_args, ) @@ -64,6 +65,32 @@ def main(): parser = argparse.ArgumentParser(prog="crocodash") subparsers = parser.add_subparsers(dest="command", required=True) + # --- create --- + create_parser = subparsers.add_parser( + "create", + help="Create a new CrocoDash case from a YAML config file.", + ) + create_parser.add_argument( + "--config", required=True, help="Path to the YAML case config file." + ) + create_parser.add_argument( + "--override", + action="store_true", + default=False, + help="Overwrite existing caseroot and inputdir if they exist.", + ) + create_parser.set_defaults(func=_create) + + # --- dump --- + dump_parser = subparsers.add_parser( + "dump", + help="Print a YAML representation of an existing CrocoDash case to stdout.", + ) + dump_parser.add_argument( + "--caseroot", required=True, help="Path to the existing CESM caseroot." + ) + dump_parser.set_defaults(func=_dump) + # --- bundle --- bundle_parser = subparsers.add_parser( "bundle", @@ -136,32 +163,10 @@ def main(): "--machine", required=True, help="Machine name (e.g. derecho)." ) fork_parser.add_argument("--project", required=True, help="Project/account number.") - # optional bypass flags - fork_parser.add_argument( - "--compset", default=None, help="Override the compset from the bundle." - ) fork_parser.add_argument( "--plan", default=None, - help='JSON object controlling what to copy, e.g. 
\'{"xml_files": true, "user_nl": true, "source_mods": false, "xmlchanges": true}\'.', - ) - fork_parser.add_argument( - "--extra-configs", - default=None, - dest="extra_configs", - help="Comma-separated forcing configs to add.", - ) - fork_parser.add_argument( - "--remove-configs", - default=None, - dest="remove_configs", - help="Comma-separated forcing configs to drop.", - ) - fork_parser.add_argument( - "--extra-forcing-args", - default=None, - dest="extra_forcing_args", - help="Path to JSON file with extra forcing arguments.", + help='JSON object controlling what non-standard CESM state to copy, e.g. \'{"xml_files": true, "user_nl": true, "source_mods": false, "xmlchanges": true}\'.', ) fork_parser.set_defaults(func=_fork) diff --git a/CrocoDash/recipe.py b/CrocoDash/recipe.py new file mode 100644 index 00000000..700d535f --- /dev/null +++ b/CrocoDash/recipe.py @@ -0,0 +1,257 @@ +import json +from datetime import datetime +from pathlib import Path + +import xarray as xr +import yaml + +from CrocoDash.case import Case +from CrocoDash.forcing_configurations.base import ForcingConfigRegistry +from CrocoDash.grid import Grid +from CrocoDash.topo import Topo +from CrocoDash.vgrid import VGrid +from CrocoDash.logging import setup_logger + +logger = setup_logger(__name__) + +_TOPO_SOURCE_TYPES = {"flat", "dataset", "from_file"} +_VGRID_TYPES = {"uniform", "hyperbolic", "from_file"} + +# State keys that are derived/resolved at init time and cannot be passed straight back +# to Case.__init__ — handled explicitly in case_to_yaml's "case" section. 
_STATE_DERIVED_KEYS = frozenset(
    {
        "inputdir",
        "cesmroot",
        "supergrid_path",
        "topo_path",
        "vgrid_path",
        "grid_name",
        "session_id",
        "compset_lname",
        "machine",
    }
)


def load_config(path):
    """Read a YAML case config file, validate its structure, and return the config dict.

    Raises ValueError (from validate_config_structure) when the file is
    empty or structurally invalid.
    """
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    validate_config_structure(config)
    return config


def validate_config_structure(config):
    """Fast pre-flight structural checks on a config dict before any expensive work.

    Checks that the five top-level sections exist and are mappings, that the
    required ``case`` keys and ``topo.min_depth`` are present, that
    ``topo.source.type`` / ``vgrid.type`` are known values when given, and
    that the ``forcings`` section has a two-element ``date_range``, valid
    ``boundaries`` and the TPXO file paths when ``tidal_constituents`` is set.

    Returns None. Raises ValueError on the first problem found.
    """
    # yaml.safe_load returns None for an empty file; report that as a
    # structural error instead of failing with AttributeError below.
    if not isinstance(config, dict):
        raise ValueError(
            "Config must be a mapping of sections, got "
            f"{type(config).__name__} (is the YAML file empty?)"
        )

    required_top = {"grid", "topo", "vgrid", "case", "forcings"}
    missing = required_top - set(config.keys())
    if missing:
        raise ValueError(f"Config missing required top-level sections: {missing}")

    # Every section is later indexed or unpacked with ** by the builders,
    # so a null or scalar section can never work.
    for section in sorted(required_top):
        if not isinstance(config[section], dict):
            raise ValueError(f"Config section '{section}' must be a mapping")

    case_cfg = config["case"]
    for key in ("cesmroot", "caseroot", "inputdir", "compset", "machine"):
        if key not in case_cfg:
            raise ValueError(f"case.{key} is required")

    topo_cfg = config["topo"]
    # build_topo reads topo["min_depth"] unconditionally; catch it here,
    # before the grid is built.
    if "min_depth" not in topo_cfg:
        raise ValueError("topo.min_depth is required")
    if "source" in topo_cfg:
        source_cfg = topo_cfg["source"]
        if (
            not isinstance(source_cfg, dict)
            or source_cfg.get("type") not in _TOPO_SOURCE_TYPES
        ):
            raise ValueError(f"topo.source.type must be one of {_TOPO_SOURCE_TYPES}")

    vgrid_type = config["vgrid"].get("type")
    if vgrid_type is not None and vgrid_type not in _VGRID_TYPES:
        raise ValueError(f"vgrid.type must be one of {_VGRID_TYPES}")

    forcings_cfg = config["forcings"]
    if "date_range" not in forcings_cfg:
        raise ValueError("forcings.date_range is required")
    dr = forcings_cfg["date_range"]
    if not (isinstance(dr, list) and len(dr) == 2):
        raise ValueError("forcings.date_range must be a list of exactly 2 date strings")
    if "boundaries" in forcings_cfg:
        valid_boundaries = {"north", "south", "east", "west"}
        bad = set(forcings_cfg["boundaries"]) - valid_boundaries
        if bad:
            raise ValueError(f"Invalid boundary values: {bad}")
    if "tidal_constituents" in forcings_cfg:
        for tide_key in ("tpxo_elevation_filepath", "tpxo_velocity_filepath"):
            if tide_key not in forcings_cfg:
                raise ValueError(
                    f"forcings.{tide_key} is required when tidal_constituents is set"
                )


def build_grid(grid_cfg):
    """Build a Grid from a config dict. Uses supergrid_path for file-based grids.

    When ``supergrid_path`` is present the grid is loaded with
    ``Grid.from_supergrid`` and, if a non-empty ``name`` is given, renamed.
    Otherwise every key of ``grid_cfg`` is passed to ``Grid`` as a keyword.
    """
    if "supergrid_path" in grid_cfg:
        grid = Grid.from_supergrid(grid_cfg["supergrid_path"])
        if grid_cfg.get("name"):
            grid.name = grid_cfg["name"]
        return grid
    return Grid(**grid_cfg)


def build_topo(topo_cfg, grid):
    """Build a Topo from a config dict. Dispatches on topo.source.type.

    ``source.type`` defaults to ``"flat"``. Raises ValueError for an unknown
    type before any Topo object is constructed, and KeyError when
    ``min_depth`` or a key required by the chosen source type is missing.
    """
    min_depth = topo_cfg["min_depth"]
    # A "source:" key with no value parses as None; treat it as empty.
    source = topo_cfg.get("source") or {}
    source_type = source.get("type", "flat")

    if source_type == "from_file":
        return Topo.from_topo_file(grid, source["topo_file_path"], min_depth=min_depth)

    # Reject unknown types up front rather than after building a Topo.
    if source_type not in ("flat", "dataset"):
        raise ValueError(f"Unknown topo.source.type: '{source_type}'")

    topo = Topo(grid, min_depth)
    if source_type == "flat":
        topo.set_flat(source["depth"])
    else:
        # Everything except the dispatch key is forwarded as keyword args.
        topo.set_from_dataset(**{k: v for k, v in source.items() if k != "type"})
    return topo


def build_vgrid(vgrid_cfg, topo):
    """Build a VGrid from a config dict. If depth is omitted, uses topo.max_depth.

    ``type`` defaults to ``"uniform"``. Raises ValueError for an unknown type
    before ``topo`` is consulted.
    """
    vgrid_type = vgrid_cfg.get("type", "uniform")

    if vgrid_type == "from_file":
        return VGrid.from_file(**{k: v for k, v in vgrid_cfg.items() if k != "type"})

    builders = {"uniform": "uniform", "hyperbolic": "hyperbolic"}
    if vgrid_type not in builders:
        raise ValueError(f"Unknown vgrid.type: '{vgrid_type}'")

    # A missing (or falsy) depth falls back to the deepest topo point.
    depth = vgrid_cfg.get("depth") or topo.max_depth
    kwargs = {k: v for k, v in vgrid_cfg.items() if k not in ("type", "depth")}
    kwargs["depth"] = depth

    return getattr(VGrid, builders[vgrid_type])(**kwargs)


def create_case_from_yaml(config, override=False):
    """
    Run the full case creation workflow from a config dict.

    Builds Grid, Topo, and VGrid objects, creates the CESM case, then calls
    configure_forcings and process_forcings. A forcings section is required.
    Returns the Case.

    Raises KeyError if the ``forcings`` section is missing. The check runs
    before anything is built, so no case is created in that situation.
    """
    # Fail fast: the forcings are only consumed after the case exists, and
    # discovering their absence then would leave a half-configured case.
    if "forcings" not in config:
        raise KeyError("config is missing the required 'forcings' section")

    grid = build_grid(config["grid"])
    topo = build_topo(config["topo"], grid)
    vgrid = build_vgrid(config["vgrid"], topo)

    case = Case(
        ocn_grid=grid,
        ocn_topo=topo,
        ocn_vgrid=vgrid,
        override=override,
        **config["case"],
    )

    case.configure_forcings(**config["forcings"])
    case.process_forcings()

    return case


def generate_configure_forcing_args(forcing_config, remove_configs=None):
    """Convert a config.json forcing_config dict into configure_forcings kwargs.

    ``remove_configs`` is an optional collection of top-level
    ``forcing_config`` keys to leave out. The ``basic`` entry supplies the
    date range, boundaries, product name and function name; for every other
    retained entry, each user argument reported by ForcingConfigRegistry
    that does not start with ``case_`` is copied from that entry's
    ``inputs``.
    """
    if remove_configs is None:
        remove_configs = []
    logger.info("Setup configuration arguments...")

    basic = forcing_config["basic"]
    dates = basic["dates"]
    date_format = dates["format"]
    start_dt = datetime.strptime(dates["start"], date_format)
    end_dt = datetime.strptime(dates["end"], date_format)

    # Dates are re-emitted in one fixed format regardless of the stored one.
    date_range = [
        start_dt.strftime("%Y-%m-%d %H:%M:%S"),
        end_dt.strftime("%Y-%m-%d %H:%M:%S"),
    ]

    configure_forcing_args = {
        "date_range": date_range,
        "boundaries": list(basic["general"]["boundary_number_conversion"].keys()),
        "product_name": basic["forcing"]["product_name"],
        "function_name": basic["forcing"]["function_name"],
    }
    for key in forcing_config:
        if key == "basic" or key in remove_configs:
            continue
        user_args = ForcingConfigRegistry.get_user_args(
            ForcingConfigRegistry.get_configurator_from_name(key)
        )
        for arg in user_args:
            if not arg.startswith("case_"):
                configure_forcing_args[arg] = forcing_config[key]["inputs"][arg]
    return configure_forcing_args


def case_to_yaml(caseroot):
    """
    Reconstruct a YAML config dict from an existing case's state files.

    Reads crocodash_state.json (written by Case.__init__) and, if present,
    extract_forcings/config.json (written by Case.configure_forcings).
    Returns a dict suitable for writing to a YAML file with yaml.dump().

    The ``forcings`` section is only included when
    ``<inputdir>/extract_forcings/config.json`` exists. Without it the result
    cannot be passed to create_case_from_yaml, which requires that section.

    Raises FileNotFoundError when crocodash_state.json is missing.
    """
    caseroot = Path(caseroot)
    state_path = caseroot / "crocodash_state.json"
    if not state_path.exists():
        raise FileNotFoundError(
            f"No crocodash_state.json found in {caseroot}. "
            "This case may not have been created with a recent version of CrocoDash."
        )
    with open(state_path, encoding="utf-8") as f:
        state = json.load(f)

    # Context manager closes the dataset even if reading the attribute fails.
    with xr.open_dataset(state["topo_path"]) as topo_ds:
        min_depth = float(topo_ds.attrs.get("min_depth", 0.0))

    config = {
        "grid": {
            "supergrid_path": state["supergrid_path"],
            "name": state["grid_name"],
        },
        "topo": {
            "min_depth": min_depth,
            "source": {
                "type": "from_file",
                "topo_file_path": state["topo_path"],
            },
        },
        "vgrid": {
            "type": "from_file",
            "filename": state["vgrid_path"],
        },
        "case": {
            # Derived/resolved fields — require explicit mapping from state keys
            "cesmroot": state["cesmroot"],
            "caseroot": str(caseroot),
            "inputdir": state["inputdir"],
            "compset": state["compset_lname"],
            "machine": state["machine"],
            # Scalar init args stored verbatim by Case._init_args — pull dynamically
            # so new Case.__init__ params flow through without touching this function.
            **{k: v for k, v in state.items() if k not in _STATE_DERIVED_KEYS},
        },
    }

    forcing_config_path = Path(state["inputdir"]) / "extract_forcings" / "config.json"
    if forcing_config_path.exists():
        with open(forcing_config_path, encoding="utf-8") as f:
            forcing_config = json.load(f)
        config["forcings"] = generate_configure_forcing_args(forcing_config)

    return config
Robust testing is needed to ensure we are picking up the correct information -""" - -from pathlib import Path import dataclasses +import importlib import json +import logging import os +import shutil +import subprocess +import sys import tempfile +from contextlib import redirect_stdout, redirect_stderr +from pathlib import Path +from uuid import uuid4 + +import yaml +from CrocoDash.forcing_configurations.base import * from CrocoDash.grid import * -from CrocoDash.topo import * -from CrocoDash.vgrid import * +from CrocoDash.logging import setup_logger +from CrocoDash.shareable.apply import ( + INPUTDIR_FILE_PREFIXES, + apply_xmlchanges_to_case, + copy_source_mods_from_case, + copy_user_nl_params_from_case, + copy_xml_files_from_case, +) from CrocoDash.shareable.fork import ( + BundleDifferences, + BundleManifest, + ForkCrocoDashBundle, create_case, +) +from CrocoDash.topo import * +from CrocoDash.vgrid import * +from CrocoDash.recipe import ( + case_to_yaml, + create_case_from_yaml, generate_configure_forcing_args, - ForkCrocoDashBundle, - BundleManifest, - BundleDifferences, ) -from CrocoDash.shareable.apply import INPUTDIR_FILE_PREFIXES -from uuid import uuid4 -import subprocess -from CrocoDash.logging import setup_logger -from contextlib import redirect_stdout, redirect_stderr -import logging -from CrocoDash.forcing_configurations.base import * -import importlib -import sys -import shutil logger = setup_logger(__name__) @@ -45,16 +53,14 @@ def __init__(self, caseroot): self._get_case_machine() self._get_case_project() self._read_user_nls() - self._identify_CrocoDashCase_init_args() - self._identify_CrocoDashCase_forcing_config_args() + self._load_state_from_crocodash() self._read_xmlchanges() self._read_xmlfiles() self._read_sourcemods() def reread(self): self._read_user_nls() - self._identify_CrocoDashCase_init_args() - self._identify_CrocoDashCase_forcing_config_args() + self._load_state_from_crocodash() self._read_xmlchanges() self._read_xmlfiles() 
self._read_sourcemods() @@ -65,19 +71,6 @@ def case(self): self._case = get_case_obj(self.caseroot) return self._case - def generate_manifest(self) -> BundleManifest: - return BundleManifest( - paths={ - "casefiles": str(self.caseroot), - "inputfiles": self.init_args["inputdir_ocnice"], - }, - user_nl_info=self.user_nl_objs, - init_args=self.init_args, - forcing_config=self.forcing_config, - sourcemods=[str(f) for f in self.sourcemods], - xmlchanges=self.xmlchanges, - ) - def _read_xmlchanges(self): replay_path = self.caseroot / "replay.sh" self.xmlchanges = {} @@ -126,43 +119,42 @@ def _read_xmlfiles(self): def _read_sourcemods(self): self.sourcemods = { - f.relative_to(self.caseroot / "SourceMods") + str(f.relative_to(self.caseroot / "SourceMods")) for f in (self.caseroot / "SourceMods").rglob("*") if f.is_file() } - def _identify_CrocoDashCase_init_args(self): + def _load_state_from_crocodash(self): + """Load case parameters from crocodash_state.json and extract_forcings/config.json.""" + logger.info(f"Loading CrocoDash state from {self.caseroot}") + self.case_yaml = case_to_yaml(self.caseroot) - logger.info(f"Finding initialization arguments from {self.caseroot}") - - inputdir_ocnice = self.get_user_nl_value("mom", "INPUTDIR") + # Populate init_args in the legacy format for identify_non_standard / fork compatibility + state_path = self.caseroot / "crocodash_state.json" + with open(state_path) as f: + state = json.load(f) + inputdir_ocnice = str(Path(state["inputdir"]) / "ocnice") esmf_file = next(Path(inputdir_ocnice).glob("ESMF_mesh_*.nc"), None) self.init_args = { "inputdir_ocnice": inputdir_ocnice, - "supergrid_path": self.get_user_nl_value("mom", "GRID_FILE"), - "vgrid_path": self.get_user_nl_value("mom", "ALE_COORDINATE_CONFIG"), - "topo_path": self.get_user_nl_value("mom", "TOPO_FILE"), + "supergrid_path": Path(state["supergrid_path"]).name, + "vgrid_path": Path(state["vgrid_path"]).name, + "topo_path": Path(state["topo_path"]).name, "esmf_mesh_path": 
esmf_file.name if esmf_file else None, - "compset": self.case.get_value("COMPSET"), - "atm_grid_name": self.case.get_value("ATM_GRID"), + "compset": state["compset_lname"], + "atm_grid_name": state.get("atm_grid_name", "TL319"), } - return self.init_args - - def _identify_CrocoDashCase_forcing_config_args(self): - - logger.info(f"Loading forcing configuration from {self.caseroot}") - # The input directory is where the forcing config is. - - # Find the input directory - inputdir = self.get_user_nl_value("mom", "INPUTDIR") - - # Read in forcing config file - forcing_config_path = Path(inputdir).parent / "extract_forcings" / "config.json" + forcing_config_path = ( + Path(state["inputdir"]) / "extract_forcings" / "config.json" + ) + if forcing_config_path.exists(): + with open(forcing_config_path) as f: + self.forcing_config = json.load(f) + else: + self.forcing_config = {} - with open(forcing_config_path, "r") as f: - self.forcing_config = json.load(f) - return self.forcing_config + return self.init_args def get_user_nl_value(self, component, param): return ( @@ -173,7 +165,7 @@ def _read_user_nl_lines_as_obj(self, user_nl_comp="mom"): if not hasattr(self, "user_nl_reader"): # Import the CESM MOM_interface user_nl_mom reader - mod_path = ( + mod_path = str( self.cesmroot / "components" / "mom" @@ -186,7 +178,7 @@ def _read_user_nl_lines_as_obj(self, user_nl_comp="mom"): spec.loader.exec_module(self.user_nl_reader) return self.user_nl_reader.FType_MOM_params.from_MOM_input( - self.caseroot / f"user_nl_{user_nl_comp}" + str(self.caseroot / f"user_nl_{user_nl_comp}") )._data def _get_cesmroot(self): @@ -223,9 +215,7 @@ def diff(self, other_case): return BundleDifferences( xml_files_missing_in_new=sorted(list(self.xmlfiles - other_case.xmlfiles)), - source_mods_missing_files=sorted( - [str(f) for f in self.sourcemods - other_case.sourcemods] - ), + source_mods_missing_files=sorted(self.sourcemods - other_case.sourcemods), xmlchanges_missing=sorted( k for k in 
self.xmlchanges if k not in other_case.xmlchanges ), @@ -327,12 +317,10 @@ def bundle(self, output_folder_location, machine=None, project=None): logger.info(f"Copying grid file: {src}") shutil.copy(src, ocnice_target / src.name) - # Write out manifest - logger.info(f"Writing out BundleCrocoDashCase manifest...") - with open(case_subfolder / "manifest.json", "w") as f: - json.dump( - dataclasses.asdict(self.generate_manifest()), f, indent=2, default=str - ) + # Write YAML (replaces manifest.json — init_args + forcing_config in human-readable form) + logger.info("Writing out crocodash_case.yaml...") + with open(case_subfolder / "crocodash_case.yaml", "w") as f: + yaml.dump(self.case_yaml, f, default_flow_style=False, sort_keys=False) # Write out differences logger.info(f"Writing out non standard CrocoDash information...") @@ -374,7 +362,7 @@ def duplicate_case(self, new_caseroot, new_inputdir, bundle_dir=None): def duplicate_case(caseroot, new_caseroot, new_inputdir, bundle_dir=None): """ Duplicate a CrocoDash case to a new location. Machine, project, and cesmroot - are read automatically from the original caseroot. + are read automatically from the original caseroot's crocodash_state.json. Parameters ---------- @@ -385,37 +373,48 @@ def duplicate_case(caseroot, new_caseroot, new_inputdir, bundle_dir=None): new_inputdir : str or Path Path for the new input directory. bundle_dir : str or Path, optional - Where to write the intermediate bundle. Defaults to inside - new_caseroot and is cleaned up automatically. + Where to copy the bundle for reference. If None, no bundle is saved. 
""" rcc = BundleCrocoDashCase(caseroot) + rcc.identify_non_standard_CrocoDash_case_information( + rcc.cesmroot, rcc.case_machine, rcc.case_project + ) + + # Patch paths in the YAML for the new location + config = rcc.case_yaml.copy() + config["case"] = config["case"].copy() + config["case"]["caseroot"] = str(new_caseroot) + config["case"]["inputdir"] = str(new_inputdir) - plan = { - "xml_files": True, - "user_nl": True, - "source_mods": True, - "xmlchanges": True, - } - - with tempfile.TemporaryDirectory() as tmp: - loc = rcc.bundle(tmp) - fcb = ForkCrocoDashBundle(loc) - result = fcb.fork( - rcc.cesmroot, - rcc.case_machine, - rcc.case_project, - new_caseroot, - new_inputdir, - plan=plan, - compset=rcc.init_args["compset"], - extra_configs=[], - remove_configs=[], + result = create_case_from_yaml(config, override=True) + + # Copy all non-standard CESM state (full plan) + if rcc.non_standard_case_info.xml_files_missing_in_new: + copy_xml_files_from_case( + rcc.caseroot, + result.caseroot, + rcc.non_standard_case_info.xml_files_missing_in_new, ) - dest = Path(new_caseroot) / loc.name - if bundle_dir is None: - shutil.copytree(loc, dest) - else: - shutil.copytree(loc, Path(bundle_dir) / loc.name) + if rcc.non_standard_case_info.user_nl_missing_params and any( + rcc.non_standard_case_info.user_nl_missing_params.values() + ): + copy_user_nl_params_from_case( + rcc.caseroot, rcc.non_standard_case_info.user_nl_missing_params + ) + if rcc.non_standard_case_info.source_mods_missing_files: + copy_source_mods_from_case( + rcc.caseroot, + result.caseroot, + rcc.non_standard_case_info.source_mods_missing_files, + ) + if rcc.non_standard_case_info.xmlchanges_missing: + apply_xmlchanges_to_case( + rcc.caseroot, rcc.non_standard_case_info.xmlchanges_missing + ) + + # Optionally save a bundle alongside the new case + if bundle_dir is not None: + rcc.bundle(bundle_dir) return result diff --git a/CrocoDash/shareable/fork.py b/CrocoDash/shareable/fork.py index 90062fe4..d390a39a 
100644 --- a/CrocoDash/shareable/fork.py +++ b/CrocoDash/shareable/fork.py @@ -1,16 +1,23 @@ -from pathlib import Path -from CrocoDash.forcing_configurations.base import * -from CrocoDash.shareable.apply import * +import copy import json +import os import shutil -from datetime import datetime +import subprocess +import tempfile from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +import xarray as xr +import yaml from CrocoDash.case import Case +from CrocoDash.forcing_configurations.base import * from CrocoDash.grid import Grid -from CrocoDash.vgrid import VGrid -from CrocoDash.topo import Topo -import xarray as xr from CrocoDash.logging import setup_logger +from CrocoDash.shareable.apply import * +from CrocoDash.topo import Topo +from CrocoDash.vgrid import VGrid +from CrocoDash.recipe import create_case_from_yaml, generate_configure_forcing_args logger = setup_logger(__name__) @@ -43,13 +50,30 @@ class ForkCrocoDashBundle: def __init__(self, bundle_location): self.bundle_location = Path(bundle_location) - json_file = self.bundle_location / "manifest.json" - assert json_file.exists() - with open(json_file) as f: - self.manifest = BundleManifest(**json.load(f)) + yaml_file = self.bundle_location / "crocodash_case.yaml" + assert yaml_file.exists(), f"Bundle is missing crocodash_case.yaml: {yaml_file}" + with open(yaml_file) as f: + self.bundle_yaml = yaml.safe_load(f) + + # Populate a minimal manifest for backwards-compatible apply_copy_plan usage + state = self.bundle_yaml + case_cfg = state.get("case", {}) + inputdir_ocnice = str( + Path(state.get("grid", {}).get("supergrid_path", "")).parent + ) + self.manifest = BundleManifest( + forcing_config={}, + init_args={ + "inputdir_ocnice": inputdir_ocnice, + "compset": case_cfg.get("compset", ""), + "atm_grid_name": case_cfg.get("atm_grid_name", "TL319"), + }, + ) json_file = self.bundle_location / "non_standard_case_info.json" - assert json_file.exists() + assert ( + 
json_file.exists() + ), f"Bundle is missing non_standard_case_info.json: {json_file}" with open(json_file) as f: self.differences = BundleDifferences(**json.load(f)) @@ -86,94 +110,146 @@ def fork( new_caseroot, new_inputdir, plan=None, - compset=None, - extra_configs=None, - remove_configs=None, - extra_forcing_args_path=None, ): """ - Share a CESM case by inspecting an existing bundle, optionally copying - non-standard components, resolving forcing configurations, and creating - a new case with equivalent forcings. + Create a new case from a bundle, guiding the user through YAML modifications. + + Prompts the user to update destination paths and machine settings, then + optionally opens $EDITOR for deeper edits. After confirmation, creates the + case and copies non-standard CESM state per the plan. Parameters ---------- cesmroot : str or Path - Path to the CESM root. + Path to the CESM root for the new case. machine : str - Machine name. + Machine name for the new case. project_number : str - Project/account number. + Project/account number for the new case. new_caseroot : str or Path Path for the new case root. new_inputdir : str or Path - Path for input data. + Path for the new input directory. plan : dict, optional Which non-standard items to copy, e.g. ``{"xml_files": True, "user_nl": False, "source_mods": True, "xmlchanges": True}``. When omitted the user is asked interactively. - compset : str, optional - Override the compset from the bundle. When omitted the user is asked interactively. - extra_configs : list, optional - Additional forcing configuration names to add beyond the bundle. - remove_configs : list, optional - Forcing configuration names from the bundle to drop. - extra_forcing_args_path : str or Path, optional - Path to a JSON file supplying arguments for any new forcing configs. 
""" - - # Phase 1: gather all decisions (prompting interactively where params are None) - self._gather_inputs( - plan, compset, extra_configs, remove_configs, extra_forcing_args_path + # Phase 1: build patched YAML with new destination values + config = self._configure_yaml_for_forked_case_args( + cesmroot, machine, project_number, new_caseroot, new_inputdir ) - # Phase 2: pure execution — no prompts below this point - logger.info("Creating new case...") - self.manifest.init_args["inputdir_ocnice"] = str( - self.bundle_location / "ocnice" - ) - self.case = create_case( - self.manifest.init_args, - new_caseroot, - new_inputdir, - compset=self.compset, - machine=machine, - project_number=project_number, - cesmroot=cesmroot, - ) + # Phase 2: guided YAML review — prompt for each key field, offer editor + config = self._guide_yaml_review(config) + + # Phase 3: resolve which non-standard CESM items to copy + self._resolve_copy_plan(plan) - logger.info("Copying exact grid files from bundle...") + # Phase 4: create the case + logger.info("Creating new case from YAML...") + self.case = create_case_from_yaml(config, override=True) + + # Phase 5: copy bundle ocnice files then apply non-standard CESM state + logger.info("Copying forcing files from bundle...") bundle_ocnice = self.bundle_location / "ocnice" - for key in ("supergrid_path", "topo_path", "vgrid_path", "esmf_mesh_path"): - src_name = self.manifest.init_args.get(key) - dst = getattr(self.case, key, None) - if src_name and dst is not None: - src = bundle_ocnice / src_name - if src.exists(): - shutil.copy(src, dst) - - logger.info("Building configuration args") - self.case.configure_forcings(**self.configure_forcing_args) - - logger.info("Copying items to new case based on user input") + for src in bundle_ocnice.iterdir(): + dst = Path(self.case.inputdir) / "ocnice" / src.name + if not dst.exists(): + shutil.copy(src, dst) + + logger.info("Applying non-standard CESM state per plan...") self.apply_copy_plan() 
def _configure_yaml_for_forked_case_args(
    self, cesmroot, machine, project_number, new_caseroot, new_inputdir
):
    """Return a copy of bundle_yaml with destination fields configured for the forked case.

    ``self.bundle_yaml`` is deep-copied and never modified. In the copy the
    ``case`` section receives the new cesmroot, machine, project, caseroot
    and inputdir. Any file-based grid / topo / vgrid path is then rewritten
    to the file of the same name inside ``<bundle_location>/ocnice``.

    Raises KeyError if the bundle YAML has no ``case`` section.
    """
    config = copy.deepcopy(self.bundle_yaml)

    case_cfg = config["case"]
    case_cfg["cesmroot"] = str(cesmroot)
    case_cfg["machine"] = machine
    case_cfg["project"] = project_number
    case_cfg["caseroot"] = str(new_caseroot)
    case_cfg["inputdir"] = str(new_inputdir)

    # Point grid/topo/vgrid at bundle ocnice copies. Only the file name of
    # each recorded path is kept; its original directory is discarded.
    bundle_ocnice = self.bundle_location / "ocnice"

    grid_cfg = config.get("grid", {})
    if "supergrid_path" in grid_cfg:
        grid_cfg["supergrid_path"] = str(
            bundle_ocnice / Path(grid_cfg["supergrid_path"]).name
        )

    source_cfg = config.get("topo", {}).get("source", {})
    if source_cfg.get("type") == "from_file":
        source_cfg["topo_file_path"] = str(
            bundle_ocnice / Path(source_cfg["topo_file_path"]).name
        )

    vgrid_cfg = config.get("vgrid", {})
    if vgrid_cfg.get("type") == "from_file":
        vgrid_cfg["filename"] = str(bundle_ocnice / Path(vgrid_cfg["filename"]).name)

    return config


def _guide_yaml_review(self, config):
    """Walk the user through key YAML fields and offer $EDITOR for deeper edits.

    Each listed field is shown with its current value; a different answer
    replaces it (``date_range`` and ``boundaries`` answers are parsed as
    YAML, all others are stored as strings). If ``$EDITOR`` is set, the user
    may then edit the whole config in a temporary YAML file. The final
    config is printed and must be confirmed.

    Returns the reviewed config. Raises RuntimeError if the user declines
    the final confirmation.
    """
    # Local import: shlex is only needed for the optional editor step.
    import shlex

    print("\n=== Fork: Review Case Configuration ===")
    print(
        "The following fields have been pre-filled. Press Enter to keep each value.\n"
    )

    fields = [
        ("case.caseroot", ["case", "caseroot"]),
        ("case.inputdir", ["case", "inputdir"]),
        ("case.cesmroot", ["case", "cesmroot"]),
        ("case.machine", ["case", "machine"]),
        ("case.project", ["case", "project"]),
        ("case.compset", ["case", "compset"]),
    ]
    if "forcings" in config:
        fields += [
            ("forcings.date_range", ["forcings", "date_range"]),
            ("forcings.boundaries", ["forcings", "boundaries"]),
        ]

    for label, keys in fields:
        # Walk down to the dict that owns the final key.
        obj = config
        for k in keys[:-1]:
            obj = obj[k]
        leaf = keys[-1]
        current = obj[leaf]
        response = ask_string(f" {label} [{current}]: ", default=str(current))
        if response != str(current):
            if leaf in ("date_range", "boundaries"):
                # These are lists; parse the typed text as YAML.
                obj[leaf] = yaml.safe_load(response)
            else:
                obj[leaf] = response

    editor = os.environ.get("EDITOR", "")
    if editor and ask_yes_no("\nOpen $EDITOR for full YAML review?", default=False):
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", delete=False
            ) as tmp:
                tmp_path = tmp.name
                yaml.dump(config, tmp, default_flow_style=False, sort_keys=False)
            # $EDITOR may carry arguments (e.g. "code --wait"), so split it
            # into an argv list instead of treating it as one program name.
            subprocess.call(shlex.split(editor) + [tmp_path])
            with open(tmp_path) as f:
                edited = yaml.safe_load(f)
        finally:
            # Remove the temp file even if the editor or the parse failed.
            if tmp_path is not None:
                Path(tmp_path).unlink(missing_ok=True)

        if isinstance(edited, dict):
            config = edited
        else:
            # An emptied or non-mapping file cannot describe a case.
            print("Edited file did not contain a YAML mapping; keeping previous values.")

    print("\nFinal configuration:")
    print(yaml.dump(config, default_flow_style=False, sort_keys=False))
    if not ask_yes_no("Proceed with this configuration?", default=True):
        raise RuntimeError("Fork cancelled by user.")

    return config
) - def _resolve_compset(self, compset): - self.compset = self.manifest.init_args["compset"] - if compset is not None and compset != self.compset: - self.compset = compset - print( - "Warning: Changing compset may have unintended consequences and " - "may require additional data." - ) - - def _resolve_forcing_configurations(self, extra_configs, remove_configs): - self.requested_configs = [] - - required = ForcingConfigRegistry.find_required_configurators(self.compset) - for cfg in required: - if cfg.name.lower() not in self.manifest.forcing_config: - print("Missing required configurator:", cfg) - self.requested_configs.append(cfg.name.lower()) - - valid = ForcingConfigRegistry.find_valid_configurators(self.compset) - already_ran = [] - - for cfg in self.manifest.forcing_config: - if cfg == "basic": - continue - config_class = ForcingConfigRegistry.get_configurator_from_name(cfg) - if config_class not in valid: - print(f"Forcing config '{cfg}' is no longer valid for this compset") - else: - already_ran.append(config_class) - valid.remove(config_class) - - if extra_configs is not None: - extra = set(extra_configs) - self.resolved_remove = ( - set(remove_configs) if remove_configs is not None else set() - ) - else: - extra_str = ask_string( - f"Enter any other configurations you want " - f"(comma-separated) from: {[obj.name for obj in valid]}", - default="[]", - ) - remove_str = ask_string( - f"Enter any configs you don't want " - f"(comma-separated) from: {[obj.name for obj in already_ran]}", - default="[]", - ) - extra = {x.strip() for x in extra_str.split(",") if x.strip()} - self.resolved_remove = { - x.strip() for x in remove_str.split(",") if x.strip() - } - - for thing in ForcingConfigRegistry.registered_types: - if thing.name in extra: - self.requested_configs.append(thing.name) - - def _resolve_forcing_args(self, extra_forcing_args_path): - self.configure_forcing_args = generate_configure_forcing_args( - self.manifest.forcing_config, self.resolved_remove - ) 
- if not self.requested_configs: - return - - print( - "\nYou requested or are required to add the following configurations:", - self.requested_configs, - ) - required_args = [ - user_arg - for config in self.requested_configs - for user_arg in ForcingConfigRegistry.get_user_args( - ForcingConfigRegistry.get_configurator_from_name(config) - ) - if not user_arg.startswith("case_") - and user_arg not in self.configure_forcing_args - ] - if extra_forcing_args_path is None: - print(f"Provide the following arguments in a JSON file: {required_args}") - extra_forcing_args_path = ask_string( - "Enter path to JSON file with the required arguments: " - ) - with open(extra_forcing_args_path) as f: - new_args = json.load(f) - - for config in self.requested_configs: - for user_arg in ForcingConfigRegistry.get_user_args( - ForcingConfigRegistry.get_configurator_from_name(config) - ): - if ( - not user_arg.startswith("case_") - and user_arg not in self.configure_forcing_args - and user_arg not in new_args - ): - raise ValueError(f"Missing arg: '{user_arg}' for {config}") - - self.configure_forcing_args.update(new_args) - def apply_copy_plan(self): if self.plan.get("xml_files"): copy_xml_files_from_case( @@ -331,9 +310,7 @@ def apply_copy_plan(self): self.differences.xmlchanges_missing, ) - copy_configurations_to_case( - self.manifest.forcing_config, self.case, self.bundle_location / "ocnice" - ) + # Forcing files are copied in fork() before apply_copy_plan is called. 
def ask_string(prompt: str, default="") -> str: @@ -420,39 +397,3 @@ def create_case( atm_grid_name=init_args["atm_grid_name"], ) return case - - -def generate_configure_forcing_args(forcing_config, remove_configs=None): - if remove_configs is None: - remove_configs = [] - logger.info("Setup configuration arguments...") - - start_str = forcing_config["basic"]["dates"]["start"] - end_str = forcing_config["basic"]["dates"]["end"] - date_format = forcing_config["basic"]["dates"]["format"] - start_dt = datetime.strptime(start_str, date_format) - end_dt = datetime.strptime(end_str, date_format) - - date_range = [ - start_dt.strftime("%Y-%m-%d %H:%M:%S"), - end_dt.strftime("%Y-%m-%d %H:%M:%S"), - ] - - configure_forcing_args = { - "date_range": date_range, - "boundaries": list( - forcing_config["basic"]["general"]["boundary_number_conversion"].keys() - ), - "product_name": forcing_config["basic"]["forcing"]["product_name"], - "function_name": forcing_config["basic"]["forcing"]["function_name"], - } - for key in forcing_config: - if key == "basic" or key in remove_configs: - continue - user_args = ForcingConfigRegistry.get_user_args( - ForcingConfigRegistry.get_configurator_from_name(key) - ) - for arg in user_args: - if not arg.startswith("case_"): - configure_forcing_args[arg] = forcing_config[key]["inputs"][arg] - return configure_forcing_args diff --git a/docs/source/for_users/case_information.md b/docs/source/for_users/case_information.md index 3c31e4d5..9672965c 100644 --- a/docs/source/for_users/case_information.md +++ b/docs/source/for_users/case_information.md @@ -1,5 +1,26 @@ # Extra CESM Case Information +## Case State File + +At the end of `Case.__init__`, CrocoDash writes a `crocodash_state.json` file into the caseroot. 
It records the construction parameters needed to reconstruct or inspect the case later: + +```json +{ + "inputdir": "/path/to/croc_input/mycase", + "cesmroot": "/path/to/CROCESM", + "supergrid_path": "/path/to/.../ocean_hgrid_mygrid_abc123.nc", + "topo_path": "/path/to/.../ocean_topog_mygrid_abc123.nc", + "vgrid_path": "/path/to/.../ocean_vgrid_mygrid_abc123.nc", + "grid_name": "mygrid", + "session_id": "abc123", + "compset_lname": "1850_DATM%JRA_SLND_SICE_MOM6%REGIONAL_SROF_SGLC_SWAV", + "machine": "derecho", + "project": "NCGD0011", + "atm_grid_name": "TL319" +} +``` + +This file is read by `crocodash dump` to reconstruct a YAML config, and by `bundle`/`fork`/`duplicate` to avoid re-parsing CIME. It is the paired counterpart to `inputdir/extract_forcings/config.json`, which records forcing configuration written by `configure_forcings`. ## Available Compset Aliases diff --git a/docs/source/for_users/cli.md b/docs/source/for_users/cli.md new file mode 100644 index 00000000..4f478966 --- /dev/null +++ b/docs/source/for_users/cli.md @@ -0,0 +1,118 @@ +# Command Line Interface + +CrocoDash ships a `crocodash` command (installed automatically with `pip install -e .`) that lets you run the full case setup workflow from a YAML config file and inspect or share existing cases — no Python scripting required. + +## Quick reference + +``` +crocodash create --config mycase.yaml [--override] +crocodash dump --caseroot /path/to/case +crocodash bundle --caseroot /path/to/case --output-dir /path/to/bundle_dir ... +crocodash fork --bundle /path/to/bundle --caseroot ... --inputdir ... --cesmroot ... --machine ... --project ... +crocodash duplicate --source /path/to/case --case /path/to/new_case --inputdir /path/to/new_inputdir +``` + +--- + +## `crocodash create` + +Creates a new CrocoDash case end-to-end from a YAML config file. Equivalent to calling `recipe.create_case_from_yaml()`. 
+ +```bash +crocodash create --config mycase.yaml +crocodash create --config mycase.yaml --override # overwrite existing caseroot/inputdir +``` + +### YAML config schema + +```yaml +# --- Horizontal grid --- +grid: + lenx: 10.0 # domain width in degrees + leny: 10.0 # domain height in degrees + xstart: -60.0 # western edge longitude + ystart: 30.0 # southern edge latitude + resolution: 1.0 # degrees per cell (or use nx/ny instead) + name: "mygrid" + +# --- Bathymetry --- +topo: + min_depth: 10.0 # columns shallower than this are masked + source: + type: "flat" # flat | dataset | from_file + depth: 1000.0 # for type: flat — constant depth in metres + + # type: dataset — interpolate from a real bathymetry file (e.g. GEBCO) + # bathymetry_path: "/path/to/gebco.nc" + # longitude_coordinate_name: "lon" + # latitude_coordinate_name: "lat" + # vertical_coordinate_name: "elevation" + # is_input_positive_below_msl: false + # fill_channels: false + + # type: from_file — reuse an existing topog.nc + # topo_file_path: "/path/to/ocean_topog.nc" + +# --- Vertical grid --- +vgrid: + type: "uniform" # uniform | hyperbolic | from_file + nk: 10 # number of layers + # depth omitted → uses topo.max_depth automatically + name: "myvgrid" + + # type: hyperbolic — surface-intensified levels + # nk: 75 + # depth: 5000.0 + # ratio: 20.0 + + # type: from_file — reuse an existing vgrid.nc + # filename: "/path/to/ocean_vgrid.nc" + +# --- CESM case --- +case: + cesmroot: "/path/to/CROCESM" + caseroot: "/path/to/cases/mycase" + inputdir: "/path/to/croc_input/mycase" + compset: "CR_JRA" # alias or full long name + machine: "derecho" + project: "NCGD0011" + atm_grid_name: "TL319" # optional, default TL319 + +# --- Forcings (required) --- +forcings: + date_range: ["2020-01-01 00:00:00", "2020-02-01 00:00:00"] + boundaries: ["south", "east", "west"] + product_name: "GLORYS" + function_name: "get_glorys_data_from_rda" + + # Any extra kwargs are forwarded directly to configure_forcings: + # 
tidal_constituents: ["M2", "S2"] + # tpxo_elevation_filepath: "/path/to/TPXO_elevation.nc" + # tpxo_velocity_filepath: "/path/to/TPXO_velocity.nc" +``` + +After `create` completes the caseroot contains a `crocodash_state.json` recording all construction parameters, and `inputdir/extract_forcings/config.json` recording the forcing setup. These files are the source of truth for `dump`, `bundle`, and `fork`. + +--- + +## `crocodash dump` + +Prints a YAML representation of an existing case to stdout. The output can be saved, edited, and passed back to `create` — making `dump` the exact inverse of `create`. + +```bash +# View the config for an existing case +crocodash dump --caseroot ~/croc_cases/mycase + +# Save it to a file and edit before re-creating +crocodash dump --caseroot ~/croc_cases/mycase > mycase_copy.yaml +# ... edit paths, dates, machine, etc. ... +crocodash create --config mycase_copy.yaml --override +``` + +The dumped YAML uses `supergrid_path`/`from_file` references pointing at the existing grid/topo/vgrid files. To create a fully independent copy, either update those paths or re-generate the grid from parameters. + +--- + +## `crocodash bundle`, `fork`, `duplicate` + +For sharing cases with others, see [Shareable Configuration](shareable.md). diff --git a/docs/source/for_users/index.md b/docs/source/for_users/index.md index c0441480..c44cc027 100644 --- a/docs/source/for_users/index.md +++ b/docs/source/for_users/index.md @@ -21,6 +21,7 @@ datasets forcing_configurations extract_forcings case_information +cli shareable additional_resources ``` diff --git a/docs/source/for_users/shareable.md b/docs/source/for_users/shareable.md index ce779b64..420fe7c1 100644 --- a/docs/source/for_users/shareable.md +++ b/docs/source/for_users/shareable.md @@ -5,9 +5,11 @@ Ever wanted to share your regional MOM6 setup? Get a summary of your unique chan Importable through `CrocoDash.shareable`, the module lets you: 1. 
**Bundle** - Inspect an existing CESM case, identify what makes it unique, and package it into a portable folder -2. **Fork** - Recreate a case from a bundle, with optional modifications +2. **Fork** - Recreate a case from a bundle, guided through any changes via an interactive YAML review 3. **Duplicate** - One-step shortcut to copy a case to a new location, reading machine/project/cesmroot automatically from the original +The shareable workflow is built on top of the [`create`/`dump` primitives](cli.md): `bundle` uses `dump` internally to write the case config as `crocodash_case.yaml`, and `fork` uses `create` internally to build the new case from (a modified copy of) that YAML. + --- ## Workflow @@ -23,15 +25,9 @@ case = BundleCrocoDashCase("/path/to/caseroot") bundle_path = case.bundle("/path/to/output_dir") ``` -If you need to override the machine or project used for the diff (e.g. generating a bundle on a different machine than the original), pass them explicitly: - -```python -bundle_path = case.bundle("/path/to/output_dir", machine="derecho", project="PROJ123") -``` - The bundle folder contains: -- `manifest.json` — grid paths, forcing config, all case metadata -- `non_standard_case_info.json` — diff against a standard case +- `crocodash_case.yaml` — complete case config (grid, topo, vgrid, case, forcings) +- `non_standard_case_info.json` — diff against a standard case (user_nl, xmlchanges, xml_files, SourceMods) - `ocnice/` — ocean/ice input files plus grid files - `user_nl_*` files, `replay.sh` - `xml_files/` and `SourceMods/` — any non-standard modifications @@ -52,11 +48,14 @@ case = forker.fork( ) ``` -By default `fork()` is interactive — it will ask you which non-standard items to copy over (XML files, user_nl params, SourceMods, xmlchanges) and whether you want to change the compset. +`fork()` guides you through the key fields interactively: -#### Non-interactive fork +1. 
**Path and machine review** — prompts for `caseroot`, `inputdir`, `cesmroot`, `machine`, `project`, `compset`, and (if forcings were configured) `date_range` and `boundaries`. Press Enter to keep the pre-filled value. +2. **EDITOR** — if `$EDITOR` is set, offers to open the full YAML for deeper modifications (changing forcing kwargs, compset modifiers, etc.). +3. **Confirmation** — shows the final config and asks to proceed. +4. **Plan** — asks interactively which non-standard CESM state to copy (XML files, user_nl params, SourceMods, xmlchanges). Pass `plan=` to skip the prompts. -All prompts can be bypassed by passing arguments directly: +#### Non-interactive fork ```python case = forker.fork( @@ -66,19 +65,13 @@ case = forker.fork( new_caseroot="/path/to/new_case", new_inputdir="/path/to/new_inputdir", plan={"xml_files": True, "user_nl": True, "source_mods": False, "xmlchanges": True}, - compset="GOMOM6", # omit to keep the bundle's compset - extra_configs=["tides"], # additional forcing configs to add - remove_configs=["bgc"], # forcing configs to drop - extra_forcing_args_path="/path/to/args.json", # only needed if adding new configs ) ``` -Any argument left as `None` (the default) will still prompt interactively, so you can pre-supply only some of them. +To change forcings, compset, or other parameters: run `crocodash dump` on the bundle's YAML, edit it, and pass it to `crocodash create` directly — no need to use fork for that. ### Duplicate (one-step shortcut) -If you just want an exact copy of an existing case without any modifications, use `duplicate_case`. It reads machine, project, and cesmroot directly from the original caseroot — no extra arguments needed. - ```python from CrocoDash.shareable.bundle import duplicate_case @@ -89,7 +82,7 @@ new_case = duplicate_case( ) ``` -The bundle is written into `new_caseroot` and kept there after duplicating. 
You can also specify a custom location: +Reads machine, project, and cesmroot from `crocodash_state.json` in the original case. Pass `bundle_dir=` to save the bundle for reference: ```python new_case = duplicate_case( @@ -104,8 +97,6 @@ new_case = duplicate_case( ## Command Line -After installing CrocoDash (`pip install -e .`), a `crocodash` command is available. - ### Bundle ```bash @@ -120,7 +111,7 @@ crocodash bundle \ ### Fork ```bash -# Interactive +# Interactive (guided YAML review) crocodash fork \ --bundle /path/to/bundle \ --caseroot /path/to/new_case \ @@ -129,7 +120,7 @@ crocodash fork \ --machine derecho \ --project PROJ123 -# Non-interactive +# Non-interactive (skip CESM-state copy prompts) crocodash fork \ --bundle /path/to/bundle \ --caseroot /path/to/new_case \ @@ -137,15 +128,9 @@ crocodash fork \ --cesmroot /path/to/cesm \ --machine derecho \ --project PROJ123 \ - --plan '{"xml_files": true, "user_nl": true, "source_mods": false, "xmlchanges": true}' \ - --compset GOMOM6 \ - --extra-configs tides,bgc \ - --remove-configs runoff \ - --extra-forcing-args /path/to/args.json + --plan '{"xml_files": true, "user_nl": true, "source_mods": false, "xmlchanges": true}' ``` -All `fork` flags beyond the six required ones are optional and only needed to bypass the interactive prompts. - ### Duplicate ```bash @@ -155,16 +140,6 @@ crocodash duplicate \ --inputdir /path/to/new_inputdir ``` -Machine, project, and cesmroot are read automatically from the original case. Optionally specify where to keep the bundle: - -```bash -crocodash duplicate \ - --source /path/to/existing_case \ - --case /path/to/new_case \ - --inputdir /path/to/new_inputdir \ - --bundle-dir /path/to/bundle -``` - --- ## What gets diffed? 
diff --git a/tests/extract_forcings/test_case_integration.py b/tests/extract_forcings/test_case_integration.py index 08bbe6e8..e5bd9f7b 100644 --- a/tests/extract_forcings/test_case_integration.py +++ b/tests/extract_forcings/test_case_integration.py @@ -2,8 +2,8 @@ import json -def test_case_integration_driver(get_CrocoDash_case, skip_if_not_glade): - case = get_CrocoDash_case +def test_case_integration_driver(CrocoDash_case_factory, tmp_path, skip_if_not_glade): + case = CrocoDash_case_factory(tmp_path) case.configure_forcings( date_range=["2020-01-01 00:00:00", "2020-01-02 00:00:00"], boundaries=["north", "south", "east"], @@ -22,8 +22,8 @@ def test_case_integration_driver(get_CrocoDash_case, skip_if_not_glade): return -def test_case_integration_config(get_CrocoDash_case): - case = get_CrocoDash_case +def test_case_integration_config(CrocoDash_case_factory, tmp_path): + case = CrocoDash_case_factory(tmp_path) case.configure_forcings( date_range=["2020-01-01 00:00:00", "2020-02-01 00:00:00"], boundaries=["north", "south", "east"], @@ -45,11 +45,11 @@ def test_case_integration_config(get_CrocoDash_case): } -def test_driver_works(get_CrocoDash_case, tmp_path): +def test_driver_works(CrocoDash_case_factory, tmp_path): """ Test that the setup for the forcings works """ - case = get_CrocoDash_case + case = CrocoDash_case_factory(tmp_path / "case") case.configure_forcings( date_range=["2020-01-01 00:00:00", "2020-02-01 00:00:00"], tidal_constituents=["M2"], diff --git a/tests/fixtures/objects.py b/tests/fixtures/objects.py index fee60a68..a8423e8e 100644 --- a/tests/fixtures/objects.py +++ b/tests/fixtures/objects.py @@ -26,7 +26,7 @@ def setup_sample_rm6_expt(tmp_path): return expt -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def get_case_with_cf(CrocoDash_case_factory, tmp_path_factory): case = CrocoDash_case_factory(tmp_path_factory.mktemp(f"case-{uuid4().hex}")) case.configure_forcings( @@ -36,7 +36,19 @@ def 
get_case_with_cf(CrocoDash_case_factory, tmp_path_factory): return case -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") +def get_shareable_CrocoDash_case(CrocoDash_case_factory, tmp_path_factory): + case = CrocoDash_case_factory(tmp_path_factory.mktemp(f"case-{uuid4().hex}")) + case.configure_forcings( + date_range=["2020-01-01 00:00:00", "2020-01-09 00:00:00"], + tidal_constituents=["M2"], + tpxo_elevation_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/h_tpxo9.v1.zarr/", + tpxo_velocity_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/u_tpxo9.v1.zarr/", + ) + return case + + +@pytest.fixture(scope="session") def get_CrocoDash_case(CrocoDash_case_factory, tmp_path_factory): # Set some defaults diff --git a/tests/shareable/test_bundle.py b/tests/shareable/test_bundle.py index c7d8fdb6..746d724a 100644 --- a/tests/shareable/test_bundle.py +++ b/tests/shareable/test_bundle.py @@ -1,6 +1,7 @@ from CrocoDash.shareable.bundle import * import pytest import subprocess +import yaml from pathlib import Path @@ -72,27 +73,20 @@ def test_diff_CESM_cases_alldiff(two_cesm_cases): assert output.xmlchanges_missing == ["JOB_PRIORITY"] -def test_identify_CrocoDashCase_init_args(get_case_with_cf, fake_RCC_empty_case): +def test_load_state_from_crocodash_init_args(get_case_with_cf): case = get_case_with_cf - rcc = fake_RCC_empty_case - rcc.caseroot = case.caseroot - rcc._get_cesmroot() - rcc._read_user_nls() - init_args = rcc._identify_CrocoDashCase_init_args() - print(init_args) + rcc = BundleCrocoDashCase(case.caseroot) + init_args = rcc.init_args assert str(case.inputdir / "ocnice") == str(init_args["inputdir_ocnice"]) - assert str(init_args["supergrid_path"]).startswith(str("ocean_hgrid_pana")) + assert str(init_args["supergrid_path"]).startswith("ocean_hgrid_pana") + assert str(init_args["topo_path"]).startswith("ocean_topog_pana") + assert str(init_args["vgrid_path"]).startswith("ocean_vgrid_pana") + assert "compset" in init_args - assert 
str(init_args["topo_path"]).startswith(str("ocean_topog_pana")) - assert str(init_args["vgrid_path"]).startswith(str("ocean_vgrid_pana")) - - assert init_args["compset"] == "1850_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV_SESP" - - -def test_identify_CrocoDashCase_forcing_config_args( - CrocoDash_case_factory, tmp_path_factory, fake_RCC_empty_case +def test_load_state_from_crocodash_forcing_config( + CrocoDash_case_factory, tmp_path_factory ): case1 = CrocoDash_case_factory(tmp_path_factory.mktemp("forcing_config_args")) case1.configure_forcings( @@ -101,24 +95,13 @@ def test_identify_CrocoDashCase_forcing_config_args( tpxo_elevation_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/h_tpxo9.v1.zarr/", tpxo_velocity_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/u_tpxo9.v1.zarr/", ) - rcc = fake_RCC_empty_case - rcc.caseroot = case1.caseroot - rcc._get_cesmroot() - rcc._read_user_nls() - forcing_config = rcc._identify_CrocoDashCase_forcing_config_args() - # Since this just reads the forcing_config json file in input directory, I'll only check one thing in it - assert "tides" in forcing_config + rcc = BundleCrocoDashCase(case1.caseroot) + assert "tides" in rcc.forcing_config -def test_identify_non_standard_case_information(get_CrocoDash_case): +def test_identify_non_standard_case_information(get_shareable_CrocoDash_case): - case1 = get_CrocoDash_case - case1.configure_forcings( - date_range=["2020-01-01 00:00:00", "2020-01-09 00:00:00"], - tidal_constituents=["M2"], - tpxo_elevation_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/h_tpxo9.v1.zarr/", - tpxo_velocity_filepath="s3://crocodile-cesm/CrocoDash/data/tpxo/u_tpxo9.v1.zarr/", - ) + case1 = get_shareable_CrocoDash_case xml_file = Path(case1.caseroot) / "test.xml" xml_file.write_text("data") @@ -234,16 +217,17 @@ def test_bundle_with_modifications(CrocoDash_case_factory, tmp_path_factory, tmp replay_sh_path = case_bundle / "replay.sh" assert replay_sh_path.exists() - # Check that manifest.json was written - 
json_file = case_bundle / "manifest.json" - assert json_file.exists() - with open(json_file) as f: - saved_output = json.load(f) + # Check that crocodash_case.yaml was written (replaces the old manifest.json) + yaml_file = case_bundle / "crocodash_case.yaml" + assert yaml_file.exists() + with open(yaml_file) as f: + saved_yaml = yaml.safe_load(f) json_file = case_bundle / "non_standard_case_info.json" with open(json_file) as f: differences = json.load(f) - assert "init_args" in saved_output - assert "forcing_config" in saved_output + assert "case" in saved_yaml + assert "grid" in saved_yaml + assert "forcings" in saved_yaml assert differences["xml_files_missing_in_new"] == ["custom_settings.xml"] assert differences["source_mods_missing_files"] == ["src.mom/custom_module.F90"] @@ -306,7 +290,7 @@ def test_read_sourcemods(fake_RCC_empty_case, tmp_path): case._read_sourcemods() # Expected relative paths - expected = {Path("src.drv/file1.txt"), Path("src.mom/file2.txt")} + expected = {"src.drv/file1.txt", "src.mom/file2.txt"} # Assert assert case.sourcemods == expected diff --git a/tests/shareable/test_cli.py b/tests/shareable/test_cli.py index 826d6543..60f73acf 100644 --- a/tests/shareable/test_cli.py +++ b/tests/shareable/test_cli.py @@ -50,7 +50,6 @@ def test_fork_cli(tmp_path): "source_mods": True, "xmlchanges": True, } - args_file = tmp_path / "args.json" with patch( "CrocoDash.shareable.fork.ForkCrocoDashBundle", return_value=mock_forker @@ -72,14 +71,6 @@ def test_fork_cli(tmp_path): "PROJ123", "--plan", json.dumps(plan), - "--compset", - "GOMOM6", - "--extra-configs", - "tides,bgc", - "--remove-configs", - "runoff", - "--extra-forcing-args", - str(args_file), ] ) @@ -90,10 +81,6 @@ def test_fork_cli(tmp_path): new_caseroot=str(tmp_path / "new_case"), new_inputdir=str(tmp_path / "inputdir"), plan=plan, - compset="GOMOM6", - extra_configs=["tides", "bgc"], - remove_configs=["runoff"], - extra_forcing_args_path=str(args_file), ) diff --git 
a/tests/shareable/test_fork.py b/tests/shareable/test_fork.py index bb1fda20..bd991da1 100644 --- a/tests/shareable/test_fork.py +++ b/tests/shareable/test_fork.py @@ -1,7 +1,7 @@ from CrocoDash.shareable.fork import * import json import pytest -from types import SimpleNamespace +from pathlib import Path from unittest.mock import patch from uuid import uuid4 @@ -84,23 +84,51 @@ def test_resolve_copy_plan_with_provided_plan(fake_fcb_empty_case): assert fcb.plan is provided -def test_resolve_compset(fake_fcb_empty_case): - """Test _resolve_compset sets compset on self.""" +def test_configure_yaml_for_forked_case_args(fake_fcb_empty_case, tmp_path): + """Test _configure_yaml_for_forked_case_args correctly patches destination fields.""" fcb = fake_fcb_empty_case - bundle_compset = "1850_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV" - fcb.manifest = BundleManifest( - forcing_config={}, - init_args={"compset": bundle_compset}, + fcb.bundle_location = tmp_path / "bundle" + (fcb.bundle_location / "ocnice").mkdir(parents=True) + + bundle_yaml = { + "case": { + "cesmroot": "/old/cesm", + "machine": "old_machine", + "project": "OLD123", + "caseroot": "/old/case", + "inputdir": "/old/inputdir", + "compset": "1850_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV", + }, + "grid": {"supergrid_path": "/old/ocnice/ocean_hgrid.nc"}, + "topo": { + "source": { + "type": "from_file", + "topo_file_path": "/old/ocnice/ocean_topog.nc", + } + }, + "vgrid": {"type": "from_file", "filename": "/old/ocnice/ocean_vgrid.nc"}, + } + fcb.bundle_yaml = bundle_yaml + + config = fcb._configure_yaml_for_forked_case_args( + cesmroot="/new/cesm", + machine="new_machine", + project_number="NEW123", + new_caseroot="/new/case", + new_inputdir="/new/inputdir", ) - fcb._resolve_compset(None) - - assert fcb.compset == bundle_compset - - new_compset = "2000_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV" - fcb._resolve_compset(new_compset) - - assert fcb.compset == new_compset + assert config["case"]["cesmroot"] == "/new/cesm" + 
assert config["case"]["machine"] == "new_machine" + assert config["case"]["project"] == "NEW123" + assert config["case"]["caseroot"] == "/new/case" + assert config["case"]["inputdir"] == "/new/inputdir" + # Original must be unchanged + assert bundle_yaml["case"]["cesmroot"] == "/old/cesm" + # Grid/topo/vgrid paths are redirected to bundle ocnice + assert "ocean_hgrid.nc" in config["grid"]["supergrid_path"] + assert "ocean_topog.nc" in config["topo"]["source"]["topo_file_path"] + assert "ocean_vgrid.nc" in config["vgrid"]["filename"] def test_build_general_configure_forcing_args(sample_forcing_config): @@ -126,114 +154,6 @@ def test_build_general_configure_forcing_args(sample_forcing_config): assert "marbl_ic_filepath" in args -def test_resolve_forcing_args_no_configs(fake_fcb_empty_case, sample_forcing_config): - """Test _resolve_forcing_args sets configure_forcing_args unchanged when no configs requested.""" - fcb = fake_fcb_empty_case - fcb.manifest = BundleManifest(forcing_config=sample_forcing_config, init_args={}) - fcb.resolved_remove = {} - fcb.requested_configs = [] - - fcb._resolve_forcing_args(None) - - assert fcb.configure_forcing_args == { - "date_range": ["2020-01-01 00:00:00", "2020-01-09 00:00:00"], - "boundaries": ["north"], - "product_name": "GLORYS", - "function_name": "get_glorys_data_script_for_cli", - "tpxo_elevation_filepath": "ASd", - "tpxo_velocity_filepath": "ASd", - "tidal_constituents": ["M2", "K1"], - "marbl_ic_filepath": "qwreqwre", - } - - -def test_resolve_forcing_args_with_json_file( - fake_fcb_empty_case, sample_forcing_config, tmp_path -): - """Test that _resolve_forcing_args loads extra args from a JSON file path.""" - fcb = fake_fcb_empty_case - fcb.manifest = BundleManifest(forcing_config=sample_forcing_config, init_args={}) - fcb.resolved_remove = {} - fcb.requested_configs = ["tides"] - - args_file = tmp_path / "forcing_args.json" - args_file.write_text( - json.dumps( - { - "tidal_constituents": ["M2", "K1"], - 
"tpxo_elevation_filepath": "elev.nc", - "tpxo_velocity_filepath": "vel.nc", - "boundaries": ["north"], - } - ) - ) - - fcb._resolve_forcing_args(str(args_file)) - - assert fcb.configure_forcing_args["tidal_constituents"] == ["M2", "K1"] - - -def test_resolve_forcing_args_missing_required_arg( - fake_fcb_empty_case, sample_forcing_config, tmp_path -): - """Test that _resolve_forcing_args raises ValueError when required args are missing.""" - fcb = fake_fcb_empty_case - fcb.manifest = BundleManifest(forcing_config=sample_forcing_config, init_args={}) - fcb.resolved_remove = {"tides"} # remove tides so its args aren't pre-populated - fcb.requested_configs = ["tides"] - - args_file = tmp_path / "incomplete_args.json" - args_file.write_text(json.dumps({"tidal_constituents": ["M2"]})) - - with pytest.raises(ValueError, match="Missing arg"): - fcb._resolve_forcing_args(str(args_file)) - - -def test_resolve_forcing_configurations(fake_fcb_empty_case, sample_forcing_config): - """Test _resolve_forcing_configurations sets requested and removed configs on self.""" - fcb = fake_fcb_empty_case - fcb.manifest = BundleManifest(forcing_config=sample_forcing_config, init_args={}) - fcb.compset = "2000_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV" - - with patch( - "CrocoDash.shareable.fork.ForcingConfigRegistry.find_required_configurators", - return_value=[], - ): - with patch( - "CrocoDash.shareable.fork.ForcingConfigRegistry.find_valid_configurators", - return_value=[], - ): - with patch("CrocoDash.shareable.fork.ask_string", side_effect=["", "bgc"]): - fcb._resolve_forcing_configurations(None, None) - - assert isinstance(fcb.requested_configs, list) - assert isinstance(fcb.resolved_remove, set) - assert "bgc" in fcb.resolved_remove - - -def test_resolve_forcing_configurations_required_missing( - fake_fcb_empty_case, sample_forcing_config -): - """Test that a required configurator absent from the manifest is added to requested_configs.""" - fcb = fake_fcb_empty_case - # manifest has no 
"bgc" entry - fcb.manifest = BundleManifest(forcing_config={"basic": {}}, init_args={}) - fcb.compset = "2000_DATM%JRA_SLND_SICE_MOM6_SROF_SGLC_SWAV" - - mock_required = SimpleNamespace(name="BGC") - - with patch( - "CrocoDash.shareable.fork.ForcingConfigRegistry.find_required_configurators", - return_value=[mock_required], - ), patch( - "CrocoDash.shareable.fork.ForcingConfigRegistry.find_valid_configurators", - return_value=[], - ): - fcb._resolve_forcing_configurations(extra_configs=[], remove_configs=[]) - - assert "bgc" in fcb.requested_configs - - def test_ask_input_response(): """Test ask_yes_no returns True for yes/y response.""" with patch("builtins.input", return_value="yes"): diff --git a/tests/test_case.py b/tests/test_case.py index 765d8aee..216709c8 100644 --- a/tests/test_case.py +++ b/tests/test_case.py @@ -78,24 +78,8 @@ def test_create_grid_input(get_CrocoDash_case): assert len(files) > 0 -def test_case_expt_smoke(get_CrocoDash_case, tmp_path): - case = get_CrocoDash_case - case.configure_forcings( - date_range=["2020-01-01 00:00:00", "2020-02-01 00:00:00"], - tidal_constituents=["M2"], - tpxo_elevation_filepath=tmp_path, - tpxo_velocity_filepath=tmp_path, - chl_processed_filepath=tmp_path, - boundaries=["north", "south", "east"], - ) - assert case.expt is not None - - -def test_configure_forcings(get_CrocoDash_case, tmp_path): - """ - Test that the setup for the forcings works - """ - case = get_CrocoDash_case +def test_configure_forcings(CrocoDash_case_factory, tmp_path_factory, tmp_path): + case = CrocoDash_case_factory(tmp_path_factory.mktemp(f"case-{uuid4().hex}")) case.configure_forcings( date_range=["2020-01-01 00:00:00", "2020-02-01 00:00:00"], tidal_constituents=["M2"], @@ -105,16 +89,14 @@ def test_configure_forcings(get_CrocoDash_case, tmp_path): boundaries=["north", "south", "east"], ) + assert case.expt is not None assert case.date_range[0].year == 2020 assert case.fcr["tides"].tidal_constituents == ["M2"] assert case.boundaries == 
["north", "south", "east"] -def test_process_forcing(get_CrocoDash_case, tmp_path): - """ - Test that the setup for the forcings works - """ - case = get_CrocoDash_case +def test_process_forcing(CrocoDash_case_factory, tmp_path_factory, tmp_path): + case = CrocoDash_case_factory(tmp_path_factory.mktemp(f"case-{uuid4().hex}")) case.configure_forcings( date_range=["2020-01-01 00:00:00", "2020-02-01 00:00:00"], tidal_constituents=["M2"], @@ -140,8 +122,8 @@ def test_process_forcing(get_CrocoDash_case, tmp_path): ) -def test_update_forcing_variables(get_CrocoDash_case): - case = get_CrocoDash_case +def test_update_forcing_variables(CrocoDash_case_factory, tmp_path_factory): + case = CrocoDash_case_factory(tmp_path_factory.mktemp(f"case-{uuid4().hex}")) search_string = "OBC_NUMBER_OF_SEGMENTS" found_user_nl_mom_adjusted_var = False diff --git a/tests/test_recipe_functions.py b/tests/test_recipe_functions.py new file mode 100644 index 00000000..46bd3dd7 --- /dev/null +++ b/tests/test_recipe_functions.py @@ -0,0 +1,391 @@ +""" +Unit tests for CrocoDash/recipe.py. 
+ +Covers: +- validate_config_structure: valid/invalid variants +- load_config: file I/O + validation round-trip +- build_grid / build_topo / build_vgrid: each source type +- case_to_yaml: reads state files written by Case.__init__ + configure_forcings +- Round-trip: case_to_yaml output is a valid input for create_case_from_yaml +""" + +import json +import pytest +import yaml +from pathlib import Path + +from CrocoDash.recipe import ( + build_grid, + build_topo, + build_vgrid, + case_to_yaml, + load_config, + validate_config_structure, +) +from CrocoDash.grid import Grid +from CrocoDash.topo import Topo +from CrocoDash.vgrid import VGrid + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +MINIMAL_VALID_CONFIG = { + "grid": { + "lenx": 4.0, + "leny": 3.0, + "resolution": 0.1, + "xstart": 278.0, + "ystart": 7.0, + }, + "topo": { + "min_depth": 9.5, + "source": {"type": "flat", "depth": 100.0}, + }, + "vgrid": {"type": "uniform", "nk": 10, "depth": 100.0}, + "case": { + "cesmroot": "/cesm", + "caseroot": "/case", + "inputdir": "/inputdir", + "compset": "CR_JRA", + "machine": "derecho", + }, + "forcings": { + "date_range": ["2020-01-01 00:00:00", "2020-12-31 00:00:00"], + "boundaries": ["north", "south", "east", "west"], + "product_name": "GLORYS", + "function_name": "get_glorys_data_from_rda", + }, +} + + +# --------------------------------------------------------------------------- +# validate_config_structure +# --------------------------------------------------------------------------- + + +def test_validate_valid_config(): + validate_config_structure(MINIMAL_VALID_CONFIG) + + +def test_validate_missing_top_level_sections(): + bad = {"grid": {}, "topo": {}} # no vgrid, no case + with pytest.raises(ValueError, match="missing required top-level"): + validate_config_structure(bad) + + +@pytest.mark.parametrize( + "missing_key", ["cesmroot", 
"caseroot", "inputdir", "compset", "machine"] +) +def test_validate_missing_case_key(missing_key): + config = { + "grid": {}, + "topo": {}, + "vgrid": {"type": "uniform"}, + "case": { + k: "x" + for k in ("cesmroot", "caseroot", "inputdir", "compset", "machine") + if k != missing_key + }, + "forcings": { + "date_range": ["2020-01-01", "2020-02-01"], + "boundaries": ["north"], + "product_name": "GLORYS", + "function_name": "get_glorys", + }, + } + with pytest.raises(ValueError, match=f"case\\.{missing_key}"): + validate_config_structure(config) + + +def test_validate_invalid_topo_type(): + config = {**MINIMAL_VALID_CONFIG, "topo": {"source": {"type": "bogus"}}} + with pytest.raises(ValueError, match="topo\\.source\\.type"): + validate_config_structure(config) + + +def test_validate_invalid_vgrid_type(): + config = {**MINIMAL_VALID_CONFIG, "vgrid": {"type": "bogus"}} + with pytest.raises(ValueError, match="vgrid\\.type"): + validate_config_structure(config) + + +def test_validate_forcings_missing_date_range(): + forcings = { + "boundaries": ["north"], + "product_name": "GLORYS", + "function_name": "get_glorys", + } + config = {**MINIMAL_VALID_CONFIG, "forcings": forcings} + with pytest.raises(ValueError, match="forcings\\.date_range"): + validate_config_structure(config) + + +def test_validate_forcings_bad_date_range_not_list(): + config = { + **MINIMAL_VALID_CONFIG, + "forcings": { + "date_range": "2020-01-01", + "boundaries": ["north"], + "product_name": "GLORYS", + "function_name": "get_glorys", + }, + } + with pytest.raises(ValueError, match="date_range must be a list"): + validate_config_structure(config) + + +def test_validate_forcings_bad_date_range_wrong_length(): + config = { + **MINIMAL_VALID_CONFIG, + "forcings": { + "date_range": ["2020-01-01"], + "boundaries": ["north"], + "product_name": "GLORYS", + "function_name": "get_glorys", + }, + } + with pytest.raises(ValueError, match="date_range must be a list"): + validate_config_structure(config) + + +def 
test_validate_forcings_invalid_boundary(): + config = { + **MINIMAL_VALID_CONFIG, + "forcings": { + "date_range": ["2020-01-01", "2020-02-01"], + "boundaries": ["northwest"], + "product_name": "GLORYS", + "function_name": "get_glorys", + }, + } + with pytest.raises(ValueError, match="Invalid boundary"): + validate_config_structure(config) + + +# --------------------------------------------------------------------------- +# load_config +# --------------------------------------------------------------------------- + + +def test_load_config_valid_file(tmp_path): + config_file = tmp_path / "case.yaml" + config_file.write_text(yaml.dump(MINIMAL_VALID_CONFIG)) + loaded = load_config(config_file) + assert loaded["vgrid"]["type"] == "uniform" + assert loaded["case"]["machine"] == "derecho" + + +def test_load_config_invalid_file_raises(tmp_path): + bad = {"grid": {}, "topo": {}} + config_file = tmp_path / "bad.yaml" + config_file.write_text(yaml.dump(bad)) + with pytest.raises(ValueError): + load_config(config_file) + + +# --------------------------------------------------------------------------- +# build_grid +# --------------------------------------------------------------------------- + + +def test_build_grid_from_params(): + cfg = { + "lenx": 4.0, + "leny": 3.0, + "resolution": 0.5, + "xstart": 278.0, + "ystart": 7.0, + "name": "testgrid", + } + grid = build_grid(cfg) + assert isinstance(grid, Grid) + assert grid.name == "testgrid" + assert grid.lenx == pytest.approx(4.0, rel=0.01) + assert grid.leny == pytest.approx(3.0, rel=0.01) + + +def test_build_grid_from_supergrid_file(gen_grid_topo_vgrid, tmp_path): + orig_grid, _, _ = gen_grid_topo_vgrid + supergrid_path = tmp_path / "ocean_hgrid.nc" + orig_grid.write_supergrid(supergrid_path) + + cfg = {"supergrid_path": str(supergrid_path), "name": "reloaded"} + grid = build_grid(cfg) + assert isinstance(grid, Grid) + assert grid.name == "reloaded" + assert grid.nx == orig_grid.nx + assert grid.ny == orig_grid.ny + + +def 
test_build_grid_from_supergrid_preserves_shape(gen_grid_topo_vgrid, tmp_path): + orig_grid, _, _ = gen_grid_topo_vgrid + supergrid_path = tmp_path / "ocean_hgrid.nc" + orig_grid.write_supergrid(supergrid_path) + + grid = build_grid({"supergrid_path": str(supergrid_path)}) + assert grid.nx == orig_grid.nx + assert grid.ny == orig_grid.ny + + +# --------------------------------------------------------------------------- +# build_topo +# --------------------------------------------------------------------------- + + +def test_build_topo_flat(get_rect_grid): + cfg = {"min_depth": 9.5, "source": {"type": "flat", "depth": 500.0}} + topo = build_topo(cfg, get_rect_grid) + assert isinstance(topo, Topo) + assert topo.max_depth == pytest.approx(500.0, rel=0.01) + assert topo.min_depth == pytest.approx(9.5, rel=0.01) + + +def test_build_topo_from_file(gen_grid_topo_vgrid, tmp_path): + grid, orig_topo, _ = gen_grid_topo_vgrid + topo_path = tmp_path / "ocean_topog.nc" + orig_topo.write_topo(topo_path) + + cfg = { + "min_depth": orig_topo.min_depth, + "source": {"type": "from_file", "topo_file_path": str(topo_path)}, + } + topo = build_topo(cfg, grid) + assert isinstance(topo, Topo) + assert topo.max_depth == pytest.approx(orig_topo.max_depth, rel=0.01) + + +def test_build_topo_unknown_type_raises(get_rect_grid): + cfg = {"min_depth": 9.5, "source": {"type": "unknown"}} + with pytest.raises(ValueError, match="Unknown topo\\.source\\.type"): + build_topo(cfg, get_rect_grid) + + +# --------------------------------------------------------------------------- +# build_vgrid +# --------------------------------------------------------------------------- + + +def test_build_vgrid_uniform(get_rect_grid_and_topo): + _, topo = get_rect_grid_and_topo + cfg = {"type": "uniform", "nk": 10, "depth": 200.0} + vgrid = build_vgrid(cfg, topo) + assert isinstance(vgrid, VGrid) + assert vgrid.nk == 10 + assert vgrid.depth == pytest.approx(200.0, rel=0.01) + + +def 
test_build_vgrid_hyperbolic(get_rect_grid_and_topo): + _, topo = get_rect_grid_and_topo + cfg = {"type": "hyperbolic", "nk": 20, "depth": 1000.0, "ratio": 10.0} + vgrid = build_vgrid(cfg, topo) + assert isinstance(vgrid, VGrid) + assert vgrid.nk == 20 + + +def test_build_vgrid_depth_defaults_to_topo_max_depth(get_rect_grid_and_topo): + _, topo = get_rect_grid_and_topo + cfg = {"type": "uniform", "nk": 5} # no depth key + vgrid = build_vgrid(cfg, topo) + assert vgrid.depth == pytest.approx(topo.max_depth, rel=0.01) + + +def test_build_vgrid_from_file(get_vgrid, tmp_path): + vgrid_path = tmp_path / "vgrid.nc" + get_vgrid.write(vgrid_path) + + cfg = {"type": "from_file", "filename": str(vgrid_path)} + vgrid = build_vgrid(cfg, topo=None) + assert isinstance(vgrid, VGrid) + assert vgrid.nk == get_vgrid.nk + assert vgrid.depth == pytest.approx(get_vgrid.depth, rel=0.01) + + +def test_build_vgrid_unknown_type_raises(get_rect_grid_and_topo): + _, topo = get_rect_grid_and_topo + with pytest.raises(ValueError, match="Unknown vgrid\\.type"): + build_vgrid({"type": "bogus", "nk": 5}, topo) + + +# --------------------------------------------------------------------------- +# case_to_yaml +# --------------------------------------------------------------------------- + + +def test_case_to_yaml_missing_state_file(tmp_path): + with pytest.raises(FileNotFoundError, match="crocodash_state\\.json"): + case_to_yaml(tmp_path / "no_case_here") + + +def test_case_to_yaml_structure(get_CrocoDash_case): + case = get_CrocoDash_case + config = case_to_yaml(case.caseroot) + + assert set(config.keys()) >= {"grid", "topo", "vgrid", "case"} + assert "supergrid_path" in config["grid"] + assert "min_depth" in config["topo"] + assert config["topo"]["source"]["type"] == "from_file" + assert config["vgrid"]["type"] == "from_file" + assert "cesmroot" in config["case"] + assert "caseroot" in config["case"] + assert "inputdir" in config["case"] + assert "compset" in config["case"] + assert "machine" in 
config["case"] + + +def test_case_to_yaml_values_match_case(get_CrocoDash_case): + case = get_CrocoDash_case + config = case_to_yaml(case.caseroot) + + assert config["case"]["compset"] == case.compset_lname + assert config["case"]["machine"] == case.machine + assert config["case"]["caseroot"] == str(case.caseroot) + assert config["case"]["inputdir"] == str(case.inputdir) + assert config["grid"]["supergrid_path"] == case.supergrid_path + + +def test_case_to_yaml_with_forcings(get_case_with_cf): + case = get_case_with_cf + config = case_to_yaml(case.caseroot) + + assert "forcings" in config + assert "date_range" in config["forcings"] + assert "boundaries" in config["forcings"] + assert "product_name" in config["forcings"] + assert "function_name" in config["forcings"] + assert isinstance(config["forcings"]["date_range"], list) + assert len(config["forcings"]["date_range"]) == 2 + + +# --------------------------------------------------------------------------- +# Round-trip: case_to_yaml output is valid input for create_case_from_yaml +# --------------------------------------------------------------------------- + + +def test_case_to_yaml_round_trip_is_valid_config(get_case_with_cf): + """case_to_yaml output must pass validate_config_structure without error.""" + case = get_case_with_cf + config = case_to_yaml(case.caseroot) + validate_config_structure(config) + + +def test_case_to_yaml_round_trip(get_case_with_cf, tmp_path): + """case_to_yaml output can be written to YAML and reloaded identically, including forcings.""" + case = get_case_with_cf + config = case_to_yaml(case.caseroot) + + yaml_path = tmp_path / "round_trip.yaml" + yaml_path.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False)) + reloaded = yaml.safe_load(yaml_path.read_text()) + + assert reloaded["case"]["compset"] == config["case"]["compset"] + assert reloaded["case"]["machine"] == config["case"]["machine"] + assert reloaded["grid"]["supergrid_path"] == 
config["grid"]["supergrid_path"] + assert ( + reloaded["topo"]["source"]["topo_file_path"] + == config["topo"]["source"]["topo_file_path"] + ) + assert reloaded["vgrid"]["filename"] == config["vgrid"]["filename"] + assert reloaded["forcings"]["date_range"] == config["forcings"]["date_range"] + assert reloaded["forcings"]["boundaries"] == config["forcings"]["boundaries"] + assert reloaded["forcings"]["product_name"] == config["forcings"]["product_name"]