Skip to content

Commit 87a30ce

Browse files
authored
Do not add scalar coords from the target grid to the regridding output (#418)
1 parent fceddc8 commit 87a30ce

6 files changed

+32
-16
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ repos:
2626
rev: 7.1.1
2727
hooks:
2828
- id: flake8
29+
2930
- repo: https://github.com/PyCQA/isort
3031
rev: 6.0.0
3132
hooks:

CHANGES.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ What's new
44
0.8.9 (unreleased)
55
------------------
66
* Destroy grids explicitly once weights are computed. Do not store them in `grid_in` and `grid_out` attributes. This fixes segmentation faults introduced by the memory fix of last version. By `Pascal Bourgault <https://github.com/aulemahal>`_.
7+
* Do not add scalar coordinates of the target grid to the regridded output (:issue:`417`, :pull:`418`). `xe.Regridder.out_coords` is now a dataset instead of a dictionary. By `Pascal Bourgault <https://github.com/aulemahal>`_.
78

89
0.8.8 (2024-11-01)
910
------------------

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ tag_regex = "^(?P<prefix>v)?(?P<version>[^\\+]+)(?P<suffix>.*)?$"
6161
[tool.black]
6262
line-length = 100
6363
target-version = [
64-
'py310',
64+
'py311',
6565
]
6666
skip-string-normalization = true
6767

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ ignore =
77
max-line-length = 100
88
max-complexity = 18
99
select = B,C,E,F,W,T4,B9
10-
extend-ignore = E203,E501,E402,W605
10+
extend-ignore = E203,E501,E402,W503,W605

xesmf/frontend.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -962,18 +962,23 @@ def __init__(
962962
self.out_horiz_dims = (lat_out.dims[0], lon_out.dims[0])
963963

964964
if isinstance(ds_out, Dataset):
965-
self.out_coords = {
966-
name: crd
967-
for name, crd in ds_out.coords.items()
968-
if set(self.out_horiz_dims).issuperset(crd.dims)
969-
}
965+
out_coords = ds_out.coords.to_dataset()
970966
grid_mapping = {
971967
var.attrs['grid_mapping']
972968
for var in ds_out.data_vars.values()
973969
if 'grid_mapping' in var.attrs
974970
}
975-
if grid_mapping:
976-
self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out})
971+
# to keep : grid_mappings and non-scalar coords that have the spatial dims
972+
self.out_coords = out_coords.drop_vars(
973+
[
974+
name
975+
for name, crd in out_coords.coords.items()
976+
if not (
977+
(name in grid_mapping)
978+
or (len(crd.dims) > 0 and set(self.out_horiz_dims).issuperset(crd.dims))
979+
)
980+
]
981+
)
977982
else:
978983
self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out}
979984

@@ -1055,10 +1060,14 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
10551060
chunks = out_chunks | in_chunks
10561061

10571062
# Rename coords to avoid issues in xr.map_blocks
1058-
for coord in list(self.out_coords.keys()):
1059-
# If coords and dims are the same, renaming has already been done.
1060-
if coord not in self.out_horiz_dims:
1061-
ds_out = ds_out.rename({coord: coord + '_out'})
1063+
# If coords and dims are the same, renaming has already been done.
1064+
ds_out = ds_out.rename(
1065+
{
1066+
coord: coord + '_out'
1067+
for coord in self.out_coords.coords.keys()
1068+
if coord not in self.out_horiz_dims
1069+
}
1070+
)
10621071

10631072
weights_dims = ('y_out', 'x_out', 'y_in', 'x_in')
10641073
templ = sps.zeros((self.shape_out + self.shape_in))
@@ -1102,7 +1111,7 @@ def _format_xroutput(self, out, new_dims=None):
11021111
# rename dimension name to match output grid
11031112
out = out.rename({nd: od for nd, od in zip(new_dims, self.out_horiz_dims)})
11041113

1105-
out = out.assign_coords(**self.out_coords)
1114+
out = out.assign_coords(self.out_coords.coords)
11061115
out.attrs['regrid_method'] = self.method
11071116

11081117
if self.sequence_out:

xesmf/tests/test_frontend.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -732,17 +732,22 @@ def test_regrid_dataset_extracoords():
732732
x=np.arange(24),
733733
y=np.arange(20), # coords to be transfered
734734
latitude_longitude=xr.DataArray(), # grid_mapping
735-
bogus=ds_out.lev * ds_out.lon, # coord not to be transfered
735+
bogus=ds_out.lev * ds_out.lon, # coords not to be transfered
736+
scalar1=1, #
737+
scalar2=1, #
736738
)
737739
ds_out2['data_ref'].attrs['grid_mapping'] = 'latitude_longitude'
738740
ds_out2['data4D_ref'].attrs['grid_mapping'] = 'latitude_longitude'
739741

742+
ds_in2 = ds_in.assign_coords(scalar2=5)
740743
regridder = xe.Regridder(ds_in, ds_out2, 'conservative')
741-
ds_result = regridder(ds_in)
744+
ds_result = regridder(ds_in2)
742745

743746
assert 'x' in ds_result.coords
744747
assert 'y' in ds_result.coords
745748
assert 'bogus' not in ds_result.coords
749+
assert 'scalar1' not in ds_result.coords
750+
assert ds_result.scalar2 == 5
746751
assert 'latitude_longitude' in ds_result.coords
747752

748753

0 commit comments

Comments
 (0)