-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Avoid coercing to numpy in as_shared_dtypes
#8714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
c6f4e3a
1467c4c
d9931ef
c067f7d
5092aaa
a884ba8
86e6bf8
630629c
45808d8
e625c67
6833d66
bcd02bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -217,23 +217,27 @@ def asarray(data, xp=np): | |||||||||||||||||||||||||||
return data if is_duck_array(data) else xp.asarray(data) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def as_shared_dtype(scalars_or_arrays, xp=np): | ||||||||||||||||||||||||||||
"""Cast a arrays to a shared dtype using xarray's type promotion rules.""" | ||||||||||||||||||||||||||||
array_type_cupy = array_type("cupy") | ||||||||||||||||||||||||||||
if array_type_cupy and any( | ||||||||||||||||||||||||||||
isinstance(x, array_type_cupy) for x in scalars_or_arrays | ||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||
import cupy as cp | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] | ||||||||||||||||||||||||||||
def as_duck_array(data, xp=np): | ||||||||||||||||||||||||||||
if is_duck_array(data): | ||||||||||||||||||||||||||||
return data | ||||||||||||||||||||||||||||
elif hasattr(data, "get_duck_array"): | ||||||||||||||||||||||||||||
# must be a lazy indexing class wrapping a duck array | ||||||||||||||||||||||||||||
return data.get_duck_array() | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this idea always work? What if it steps down through a lazy decoder class that changes the dtype... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Those should be going through xarray/xarray/coding/variables.py Lines 52 to 64 in c9ba2be
so you should be fine. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'm getting confused as to how this all works now... Don't I want to be computing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As of now, So that means we need to read from disk, which you do with It will get more complicated when we do lazy concatenation in Xarray, then we'd need to lazily infer dtypes and apply a lazy astype. |
||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously this |
||||||||||||||||||||||||||||
# Pass arrays directly instead of dtypes to result_type so scalars | ||||||||||||||||||||||||||||
# get handled properly. | ||||||||||||||||||||||||||||
# Note that result_type() safely gets the dtype from dask arrays without | ||||||||||||||||||||||||||||
# evaluating them. | ||||||||||||||||||||||||||||
out_type = dtypes.result_type(*arrays) | ||||||||||||||||||||||||||||
return [astype(x, out_type, copy=False) for x in arrays] | ||||||||||||||||||||||||||||
array_type_cupy = array_type("cupy") | ||||||||||||||||||||||||||||
if array_type_cupy and any(isinstance(data, array_type_cupy)): | ||||||||||||||||||||||||||||
import cupy as cp | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
return asarray(data, xp=cp) | ||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
return asarray(data, xp=xp) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def as_shared_dtype(scalars_or_arrays, xp=np): | ||||||||||||||||||||||||||||
"""Cast arrays to a shared dtype using xarray's type promotion rules.""" | ||||||||||||||||||||||||||||
duckarrays = [as_duck_array(obj, xp=xp) for obj in scalars_or_arrays] | ||||||||||||||||||||||||||||
out_type = dtypes.result_type(*duckarrays) | ||||||||||||||||||||||||||||
return [astype(x, out_type, copy=False) for x in duckarrays] | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def broadcast_to(array, shape): | ||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.