1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING , Any , Hashable , Iterable , overload
3
+ from typing import TYPE_CHECKING , Any , Hashable , Iterable , cast , overload
4
4
5
5
import pandas as pd
6
6
14
14
merge_attrs ,
15
15
merge_collected ,
16
16
)
17
+ from .types import T_DataArray , T_Dataset
17
18
from .variable import Variable
18
19
from .variable import concat as concat_vars
19
20
20
21
if TYPE_CHECKING :
21
- from .dataarray import DataArray
22
- from .dataset import Dataset
23
22
from .types import CombineAttrsOptions , CompatOptions , ConcatOptions , JoinOptions
24
23
25
24
26
25
@overload
27
26
def concat (
28
- objs : Iterable [Dataset ],
29
- dim : Hashable | DataArray | pd .Index ,
27
+ objs : Iterable [T_Dataset ],
28
+ dim : Hashable | T_DataArray | pd .Index ,
30
29
data_vars : ConcatOptions | list [Hashable ] = "all" ,
31
30
coords : ConcatOptions | list [Hashable ] = "different" ,
32
31
compat : CompatOptions = "equals" ,
33
32
positions : Iterable [Iterable [int ]] | None = None ,
34
33
fill_value : object = dtypes .NA ,
35
34
join : JoinOptions = "outer" ,
36
35
combine_attrs : CombineAttrsOptions = "override" ,
37
- ) -> Dataset :
36
+ ) -> T_Dataset :
38
37
...
39
38
40
39
41
40
@overload
42
41
def concat (
43
- objs : Iterable [DataArray ],
44
- dim : Hashable | DataArray | pd .Index ,
42
+ objs : Iterable [T_DataArray ],
43
+ dim : Hashable | T_DataArray | pd .Index ,
45
44
data_vars : ConcatOptions | list [Hashable ] = "all" ,
46
45
coords : ConcatOptions | list [Hashable ] = "different" ,
47
46
compat : CompatOptions = "equals" ,
48
47
positions : Iterable [Iterable [int ]] | None = None ,
49
48
fill_value : object = dtypes .NA ,
50
49
join : JoinOptions = "outer" ,
51
50
combine_attrs : CombineAttrsOptions = "override" ,
52
- ) -> DataArray :
51
+ ) -> T_DataArray :
53
52
...
54
53
55
54
@@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):
402
401
403
402
# determine dimensional coordinate names and a dict mapping name to DataArray
404
403
def _parse_datasets (
405
- datasets : Iterable [Dataset ],
404
+ datasets : Iterable [T_Dataset ],
406
405
) -> tuple [dict [Hashable , Variable ], dict [Hashable , int ], set [Hashable ], set [Hashable ]]:
407
406
408
407
dims : set [Hashable ] = set ()
@@ -429,16 +428,16 @@ def _parse_datasets(
429
428
430
429
431
430
def _dataset_concat (
432
- datasets : list [Dataset ],
433
- dim : str | DataArray | pd .Index ,
431
+ datasets : list [T_Dataset ],
432
+ dim : str | T_DataArray | pd .Index ,
434
433
data_vars : str | list [str ],
435
434
coords : str | list [str ],
436
435
compat : CompatOptions ,
437
436
positions : Iterable [Iterable [int ]] | None ,
438
437
fill_value : object = dtypes .NA ,
439
438
join : JoinOptions = "outer" ,
440
439
combine_attrs : CombineAttrsOptions = "override" ,
441
- ) -> Dataset :
440
+ ) -> T_Dataset :
442
441
"""
443
442
Concatenate a sequence of datasets along a new or existing dimension
444
443
"""
@@ -482,7 +481,8 @@ def _dataset_concat(
482
481
483
482
# case where concat dimension is a coordinate or data_var but not a dimension
484
483
if (dim in coord_names or dim in data_names ) and dim not in dim_names :
485
- datasets = [ds .expand_dims (dim ) for ds in datasets ]
484
+ # TODO: Overriding type because .expand_dims has incorrect typing:
485
+ datasets = [cast (T_Dataset , ds .expand_dims (dim )) for ds in datasets ]
486
486
487
487
# determine which variables to concatenate
488
488
concat_over , equals , concat_dim_lengths = _calc_concat_over (
@@ -590,7 +590,7 @@ def get_indexes(name):
590
590
# preserves original variable order
591
591
result_vars [name ] = result_vars .pop (name )
592
592
593
- result = Dataset (result_vars , attrs = result_attrs )
593
+ result = type ( datasets [ 0 ]) (result_vars , attrs = result_attrs )
594
594
595
595
absent_coord_names = coord_names - set (result .variables )
596
596
if absent_coord_names :
@@ -618,16 +618,16 @@ def get_indexes(name):
618
618
619
619
620
620
def _dataarray_concat (
621
- arrays : Iterable [DataArray ],
622
- dim : str | DataArray | pd .Index ,
621
+ arrays : Iterable [T_DataArray ],
622
+ dim : str | T_DataArray | pd .Index ,
623
623
data_vars : str | list [str ],
624
624
coords : str | list [str ],
625
625
compat : CompatOptions ,
626
626
positions : Iterable [Iterable [int ]] | None ,
627
627
fill_value : object = dtypes .NA ,
628
628
join : JoinOptions = "outer" ,
629
629
combine_attrs : CombineAttrsOptions = "override" ,
630
- ) -> DataArray :
630
+ ) -> T_DataArray :
631
631
from .dataarray import DataArray
632
632
633
633
arrays = list (arrays )
@@ -650,7 +650,8 @@ def _dataarray_concat(
650
650
if compat == "identical" :
651
651
raise ValueError ("array names not identical" )
652
652
else :
653
- arr = arr .rename (name )
653
+ # TODO: Overriding type because .rename has incorrect typing:
654
+ arr = cast (T_DataArray , arr .rename (name ))
654
655
datasets .append (arr ._to_temp_dataset ())
655
656
656
657
ds = _dataset_concat (
0 commit comments