Skip to content

Commit 8f983f1

Browse files
Switch to T_DataArray and T_Dataset in concat (#6784)
* Switch to T_DataArray in concat * Switch to T_Dataset in concat * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update concat.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cast types * Update concat.py * Update concat.py * Update concat.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f045401 commit 8f983f1

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

xarray/core/concat.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload
3+
from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload
44

55
import pandas as pd
66

@@ -14,42 +14,41 @@
1414
merge_attrs,
1515
merge_collected,
1616
)
17+
from .types import T_DataArray, T_Dataset
1718
from .variable import Variable
1819
from .variable import concat as concat_vars
1920

2021
if TYPE_CHECKING:
21-
from .dataarray import DataArray
22-
from .dataset import Dataset
2322
from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions
2423

2524

2625
@overload
2726
def concat(
28-
objs: Iterable[Dataset],
29-
dim: Hashable | DataArray | pd.Index,
27+
objs: Iterable[T_Dataset],
28+
dim: Hashable | T_DataArray | pd.Index,
3029
data_vars: ConcatOptions | list[Hashable] = "all",
3130
coords: ConcatOptions | list[Hashable] = "different",
3231
compat: CompatOptions = "equals",
3332
positions: Iterable[Iterable[int]] | None = None,
3433
fill_value: object = dtypes.NA,
3534
join: JoinOptions = "outer",
3635
combine_attrs: CombineAttrsOptions = "override",
37-
) -> Dataset:
36+
) -> T_Dataset:
3837
...
3938

4039

4140
@overload
4241
def concat(
43-
objs: Iterable[DataArray],
44-
dim: Hashable | DataArray | pd.Index,
42+
objs: Iterable[T_DataArray],
43+
dim: Hashable | T_DataArray | pd.Index,
4544
data_vars: ConcatOptions | list[Hashable] = "all",
4645
coords: ConcatOptions | list[Hashable] = "different",
4746
compat: CompatOptions = "equals",
4847
positions: Iterable[Iterable[int]] | None = None,
4948
fill_value: object = dtypes.NA,
5049
join: JoinOptions = "outer",
5150
combine_attrs: CombineAttrsOptions = "override",
52-
) -> DataArray:
51+
) -> T_DataArray:
5352
...
5453

5554

@@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):
402401

403402
# determine dimensional coordinate names and a dict mapping name to DataArray
404403
def _parse_datasets(
405-
datasets: Iterable[Dataset],
404+
datasets: Iterable[T_Dataset],
406405
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:
407406

408407
dims: set[Hashable] = set()
@@ -429,16 +428,16 @@ def _parse_datasets(
429428

430429

431430
def _dataset_concat(
432-
datasets: list[Dataset],
433-
dim: str | DataArray | pd.Index,
431+
datasets: list[T_Dataset],
432+
dim: str | T_DataArray | pd.Index,
434433
data_vars: str | list[str],
435434
coords: str | list[str],
436435
compat: CompatOptions,
437436
positions: Iterable[Iterable[int]] | None,
438437
fill_value: object = dtypes.NA,
439438
join: JoinOptions = "outer",
440439
combine_attrs: CombineAttrsOptions = "override",
441-
) -> Dataset:
440+
) -> T_Dataset:
442441
"""
443442
Concatenate a sequence of datasets along a new or existing dimension
444443
"""
@@ -482,7 +481,8 @@ def _dataset_concat(
482481

483482
# case where concat dimension is a coordinate or data_var but not a dimension
484483
if (dim in coord_names or dim in data_names) and dim not in dim_names:
485-
datasets = [ds.expand_dims(dim) for ds in datasets]
484+
# TODO: Overriding type because .expand_dims has incorrect typing:
485+
datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]
486486

487487
# determine which variables to concatenate
488488
concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -590,7 +590,7 @@ def get_indexes(name):
590590
# preserves original variable order
591591
result_vars[name] = result_vars.pop(name)
592592

593-
result = Dataset(result_vars, attrs=result_attrs)
593+
result = type(datasets[0])(result_vars, attrs=result_attrs)
594594

595595
absent_coord_names = coord_names - set(result.variables)
596596
if absent_coord_names:
@@ -618,16 +618,16 @@ def get_indexes(name):
618618

619619

620620
def _dataarray_concat(
621-
arrays: Iterable[DataArray],
622-
dim: str | DataArray | pd.Index,
621+
arrays: Iterable[T_DataArray],
622+
dim: str | T_DataArray | pd.Index,
623623
data_vars: str | list[str],
624624
coords: str | list[str],
625625
compat: CompatOptions,
626626
positions: Iterable[Iterable[int]] | None,
627627
fill_value: object = dtypes.NA,
628628
join: JoinOptions = "outer",
629629
combine_attrs: CombineAttrsOptions = "override",
630-
) -> DataArray:
630+
) -> T_DataArray:
631631
from .dataarray import DataArray
632632

633633
arrays = list(arrays)
@@ -650,7 +650,8 @@ def _dataarray_concat(
650650
if compat == "identical":
651651
raise ValueError("array names not identical")
652652
else:
653-
arr = arr.rename(name)
653+
# TODO: Overriding type because .rename has incorrect typing:
654+
arr = cast(T_DataArray, arr.rename(name))
654655
datasets.append(arr._to_temp_dataset())
655656

656657
ds = _dataset_concat(

0 commit comments

Comments
 (0)