diff --git a/cordex/__init__.py b/cordex/__init__.py index 639a83c..903ba10 100644 --- a/cordex/__init__.py +++ b/cordex/__init__.py @@ -2,7 +2,14 @@ from . import regions, tables, tutorial from .accessor import CordexDataArrayAccessor, CordexDatasetAccessor # noqa -from .domain import cordex_domain, create_dataset, domain, domain_info, vertices, rewrite_coords +from .domain import ( + cordex_domain, + create_dataset, + domain, + domain_info, + vertices, + rewrite_coords, +) from .tables import domains, ecmwf from .transform import ( map_crs, diff --git a/cordex/domain.py b/cordex/domain.py index 78a7f2d..1fc388f 100644 --- a/cordex/domain.py +++ b/cordex/domain.py @@ -594,8 +594,11 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea ds : xr.Dataset The dataset with rewritten coordinates. """ - if domain_id is None and ds.cf["grid_mapping"].grid_mapping_name == "rotated_latitude_longitude": - domain_id = ds.cx.domain_id + if ( + domain_id is None + and ds.cf["grid_mapping"].grid_mapping_name == "rotated_latitude_longitude" + ): + domain_id = ds.cx.domain_id if domain_id: grid_info = domain_info(domain_id) dx = grid_info["dlon"] @@ -619,25 +622,28 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea if coords == "xy" or coords == "all": ds = ds.cf.reindex(X=xn, Y=yn, method=method) - + if coords == "lonlat" or coords == "all": try: - trg_dims=(ds.cf["longitude"].name, ds.cf["latitude"].name) - overwrite=True + trg_dims = (ds.cf["longitude"].name, ds.cf["latitude"].name) + overwrite = True except KeyError: - trg_dims=("lon", "lat") - overwrite=False + trg_dims = ("lon", "lat") + overwrite = False dst = transform_coords(ds, trg_dims=trg_dims) if overwrite is False: - ds = ds.assign_coords({trg_dims[0]: dst[trg_dims[0]], trg_dims[1]: dst[trg_dims[1]]}) + ds = ds.assign_coords( + {trg_dims[0]: dst[trg_dims[0]], trg_dims[1]: dst[trg_dims[1]]} + ) ds[trg_dims[0]].attrs = cf.vocabulary[mip_era]["coords"][trg_dims[0]] ds[trg_dims[1]].attrs = cf.vocabulary[mip_era]["coords"][trg_dims[1]] else: ds[trg_dims[0]][:] = dst[trg_dims[0]] ds[trg_dims[1]][:] = dst[trg_dims[1]] - + return ds + def _crop_to_domain(ds, domain_id, drop=True): domain = cordex_domain(domain_id) x_mask = ds.cf["X"].round(8).isin(domain.cf["X"])