diff --git a/src/ai_models/model.py b/src/ai_models/model.py index e643b5a..01eda32 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -43,8 +43,7 @@ def __exit__(self, *args): class ArchiveCollector: - UNIQUE = {"date", "hdate", "time", - "referenceDate", "type", "stream", "expver"} + UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"} def __init__(self) -> None: self.expect = 0 @@ -56,8 +55,7 @@ def add(self, field): self.request[k].add(str(v)) if k in self.UNIQUE: if len(self.request[k]) > 1: - raise ValueError( - f"Field {field} has different values for {k}: {self.request[k]}") + raise ValueError(f"Field {field} has different values for {k}: {self.request[k]}") class Model: @@ -162,8 +160,7 @@ def json_default(obj): raise TypeError print( - json.dumps(json_requests, separators=( - ",", ":"), default=json_default, sort_keys=True), + json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True), file=f, ) @@ -173,8 +170,7 @@ def download_assets(self, **kwargs): if not os.path.exists(asset): os.makedirs(os.path.dirname(asset), exist_ok=True) LOG.info("Downloading %s", asset) - download(self.download_url.format( - file=file), asset + ".download") + download(self.download_url.format(file=file), asset + ".download") os.rename(asset + ".download", asset) @property @@ -447,8 +443,7 @@ def _requests(self): def filter_constant(request): # We check for 'sfc' because param 'z' can be ambiguous if request.get("levtype") == "sfc": - param = set(self.constant_fields) & set( - request.get("param", [])) + param = set(self.constant_fields) & set(request.get("param", [])) if param: request["param"] = list(param) return True @@ -459,8 +454,7 @@ def filter_prognostic(request): # TODO: We assume here that prognostic fields are # the ones that are not constant. This may not always be true if request.get("levtype") == "sfc": - param = set(request.get("param", [])) - \ - set(self.constant_fields) + param = set(request.get("param", [])) - set(self.constant_fields) if param: request["param"] = list(param) return True @@ -502,8 +496,7 @@ def peek_into_checkpoint(self, path): def parse_model_args(self, args): if args: - raise NotImplementedError( - f"This model does not accept arguments {args}") + raise NotImplementedError(f"This model does not accept arguments {args}") def provenance(self): from .provenance import gather_provenance_info @@ -597,8 +590,7 @@ def write_input_fields( """ template = base64.b64decode(template) - accumulations_template = ekd.from_source( - "memory", template)[0] + accumulations_template = ekd.from_source("memory", template)[0] for param in accumulations: self.write(