Skip to content

Commit f25a09e

Browse files
String dtype: rename the storage options and add na_value keyword in StringDtype() (#59330)
* rename storage option and add na_value keyword * update init * fix propagating na_value to Array class + fix some tests * fix more tests * disallow pyarrow_numpy as option + fix more cases of checking storage to be pyarrow_numpy * restore pyarrow_numpy as option for now * linting * try fix typing * try fix typing * fix dtype equality to take into account the NaN vs NA * fix pickling of dtype * fix test_convert_dtypes * update expected result for dtype='string' * suppress typing error with _metadata attribute
1 parent 56ea76a commit f25a09e

File tree

20 files changed

+176
-110
lines changed

20 files changed

+176
-110
lines changed

pandas/_libs/lib.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ def maybe_convert_objects(ndarray[object] objects,
27022702
if using_string_dtype() and is_string_array(objects, skipna=True):
27032703
from pandas.core.arrays.string_ import StringDtype
27042704

2705-
dtype = StringDtype(storage="pyarrow_numpy")
2705+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
27062706
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
27072707

27082708
elif convert_to_nullable_dtype and is_string_array(objects, skipna=True):

pandas/_testing/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -509,14 +509,14 @@ def shares_memory(left, right) -> bool:
509509
if (
510510
isinstance(left, ExtensionArray)
511511
and is_string_dtype(left.dtype)
512-
and left.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
512+
and left.dtype.storage == "pyarrow" # type: ignore[attr-defined]
513513
):
514514
# https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
515515
left = cast("ArrowExtensionArray", left)
516516
if (
517517
isinstance(right, ExtensionArray)
518518
and is_string_dtype(right.dtype)
519-
and right.dtype.storage in ("pyarrow", "pyarrow_numpy") # type: ignore[attr-defined]
519+
and right.dtype.storage == "pyarrow" # type: ignore[attr-defined]
520520
):
521521
right = cast("ArrowExtensionArray", right)
522522
left_pa_data = left._pa_array

pandas/core/arrays/arrow/array.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,8 @@ def __getitem__(self, item: PositionalIndexer):
575575
if isinstance(item, np.ndarray):
576576
if not len(item):
577577
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
578-
if self._dtype.name == "string" and self._dtype.storage in (
579-
"pyarrow",
580-
"pyarrow_numpy",
581-
):
578+
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
579+
# TODO(infer_string) should this be large_string?
582580
pa_dtype = pa.string()
583581
else:
584582
pa_dtype = self._dtype.pyarrow_dtype

pandas/core/arrays/string_.py

+68-21
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import numpy as np
1111

12-
from pandas._config import get_option
12+
from pandas._config import (
13+
get_option,
14+
using_string_dtype,
15+
)
1316

1417
from pandas._libs import (
1518
lib,
@@ -81,8 +84,10 @@ class StringDtype(StorageExtensionDtype):
8184
8285
Parameters
8386
----------
84-
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
87+
storage : {"python", "pyarrow"}, optional
8588
If not given, the value of ``pd.options.mode.string_storage``.
89+
na_value : {np.nan, pd.NA}, default pd.NA
90+
Whether the dtype follows NaN or NA missing value semantics.
8691
8792
Attributes
8893
----------
@@ -113,30 +118,67 @@ class StringDtype(StorageExtensionDtype):
113118
# follows NumPy semantics, which uses nan.
114119
@property
115120
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
116-
if self.storage == "pyarrow_numpy":
117-
return np.nan
118-
else:
119-
return libmissing.NA
121+
return self._na_value
120122

121-
_metadata = ("storage",)
123+
_metadata = ("storage", "_na_value") # type: ignore[assignment]
122124

123-
def __init__(self, storage=None) -> None:
125+
def __init__(
126+
self,
127+
storage: str | None = None,
128+
na_value: libmissing.NAType | float = libmissing.NA,
129+
) -> None:
130+
# infer defaults
124131
if storage is None:
125-
infer_string = get_option("future.infer_string")
126-
if infer_string:
127-
storage = "pyarrow_numpy"
132+
if using_string_dtype():
133+
storage = "pyarrow"
128134
else:
129135
storage = get_option("mode.string_storage")
130-
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
136+
137+
if storage == "pyarrow_numpy":
138+
# TODO raise a deprecation warning
139+
storage = "pyarrow"
140+
na_value = np.nan
141+
142+
# validate options
143+
if storage not in {"python", "pyarrow"}:
131144
raise ValueError(
132-
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
133-
f"Got {storage} instead."
145+
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
134146
)
135-
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1:
147+
if storage == "pyarrow" and pa_version_under10p1:
136148
raise ImportError(
137149
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
138150
)
151+
152+
if isinstance(na_value, float) and np.isnan(na_value):
153+
# when passed a NaN value, always set to np.nan to ensure we use
154+
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
155+
na_value = np.nan
156+
elif na_value is not libmissing.NA:
157+
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
158+
139159
self.storage = storage
160+
self._na_value = na_value
161+
162+
def __eq__(self, other: object) -> bool:
163+
# we need to override the base class __eq__ because na_value (NA or NaN)
164+
# cannot be checked with normal `==`
165+
if isinstance(other, str):
166+
if other == self.name:
167+
return True
168+
try:
169+
other = self.construct_from_string(other)
170+
except TypeError:
171+
return False
172+
if isinstance(other, type(self)):
173+
return self.storage == other.storage and self.na_value is other.na_value
174+
return False
175+
176+
def __hash__(self) -> int:
177+
# need to override __hash__ as well because of overriding __eq__
178+
return super().__hash__()
179+
180+
def __reduce__(self):
181+
return StringDtype, (self.storage, self.na_value)
140182

141183
@property
142184
def type(self) -> type[str]:
@@ -181,6 +223,7 @@ def construct_from_string(cls, string) -> Self:
181223
elif string == "string[pyarrow]":
182224
return cls(storage="pyarrow")
183225
elif string == "string[pyarrow_numpy]":
226+
# TODO deprecate
184227
return cls(storage="pyarrow_numpy")
185228
else:
186229
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
@@ -205,7 +248,7 @@ def construct_array_type( # type: ignore[override]
205248

206249
if self.storage == "python":
207250
return StringArray
208-
elif self.storage == "pyarrow":
251+
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
209252
return ArrowStringArray
210253
else:
211254
return ArrowStringArrayNumpySemantics
@@ -217,13 +260,17 @@ def __from_arrow__(
217260
Construct StringArray from pyarrow Array/ChunkedArray.
218261
"""
219262
if self.storage == "pyarrow":
220-
from pandas.core.arrays.string_arrow import ArrowStringArray
263+
if self._na_value is libmissing.NA:
264+
from pandas.core.arrays.string_arrow import ArrowStringArray
265+
266+
return ArrowStringArray(array)
267+
else:
268+
from pandas.core.arrays.string_arrow import (
269+
ArrowStringArrayNumpySemantics,
270+
)
221271

222-
return ArrowStringArray(array)
223-
elif self.storage == "pyarrow_numpy":
224-
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
272+
return ArrowStringArrayNumpySemantics(array)
225273

226-
return ArrowStringArrayNumpySemantics(array)
227274
else:
228275
import pyarrow
229276

pandas/core/arrays/string_arrow.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
131131
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
132132
_dtype: StringDtype # type: ignore[assignment]
133133
_storage = "pyarrow"
134+
_na_value: libmissing.NAType | float = libmissing.NA
134135

135136
def __init__(self, values) -> None:
136137
_chk_pyarrow_available()
@@ -140,7 +141,7 @@ def __init__(self, values) -> None:
140141
values = pc.cast(values, pa.large_string())
141142

142143
super().__init__(values)
143-
self._dtype = StringDtype(storage=self._storage)
144+
self._dtype = StringDtype(storage=self._storage, na_value=self._na_value)
144145

145146
if not pa.types.is_large_string(self._pa_array.type) and not (
146147
pa.types.is_dictionary(self._pa_array.type)
@@ -187,10 +188,7 @@ def _from_sequence(
187188

188189
if dtype and not (isinstance(dtype, str) and dtype == "string"):
189190
dtype = pandas_dtype(dtype)
190-
assert isinstance(dtype, StringDtype) and dtype.storage in (
191-
"pyarrow",
192-
"pyarrow_numpy",
193-
)
191+
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
194192

195193
if isinstance(scalars, BaseMaskedArray):
196194
# avoid costly conversion to object dtype in ensure_string_array and
@@ -597,7 +595,8 @@ def _rank(
597595

598596

599597
class ArrowStringArrayNumpySemantics(ArrowStringArray):
600-
_storage = "pyarrow_numpy"
598+
_storage = "pyarrow"
599+
_na_value = np.nan
601600

602601
@classmethod
603602
def _result_converter(cls, values, na=None):

pandas/core/construction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def sanitize_array(
574574
if isinstance(data, str) and using_string_dtype() and original_dtype is None:
575575
from pandas.core.arrays.string_ import StringDtype
576576

577-
dtype = StringDtype("pyarrow_numpy")
577+
dtype = StringDtype("pyarrow", na_value=np.nan)
578578
data = construct_1d_arraylike_from_scalar(data, len(index), dtype)
579579

580580
return data
@@ -608,7 +608,7 @@ def sanitize_array(
608608
elif data.dtype.kind == "U" and using_string_dtype():
609609
from pandas.core.arrays.string_ import StringDtype
610610

611-
dtype = StringDtype(storage="pyarrow_numpy")
611+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
612612
subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype)
613613

614614
if subarr is data and copy:

pandas/core/dtypes/cast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
801801
if using_string_dtype():
802802
from pandas.core.arrays.string_ import StringDtype
803803

804-
dtype = StringDtype(storage="pyarrow_numpy")
804+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
805805

806806
elif isinstance(val, (np.datetime64, dt.datetime)):
807807
try:

pandas/core/indexes/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5453,9 +5453,10 @@ def equals(self, other: Any) -> bool:
54535453

54545454
if (
54555455
isinstance(self.dtype, StringDtype)
5456-
and self.dtype.storage == "pyarrow_numpy"
5456+
and self.dtype.na_value is np.nan
54575457
and other.dtype != self.dtype
54585458
):
5459+
# TODO(infer_string) can we avoid this special case?
54595460
# special case for object behavior
54605461
return other.equals(self.astype(object))
54615462

pandas/core/internals/construction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def ndarray_to_mgr(
302302
nb = new_block_2d(values, placement=bp, refs=refs)
303303
block_values = [nb]
304304
elif dtype is None and values.dtype.kind == "U" and using_string_dtype():
305-
dtype = StringDtype(storage="pyarrow_numpy")
305+
dtype = StringDtype(storage="pyarrow", na_value=np.nan)
306306

307307
obj_columns = list(values)
308308
block_values = [

pandas/core/reshape/encoding.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
from pandas._libs import missing as libmissing
1314
from pandas._libs.sparse import IntIndex
1415

1516
from pandas.core.dtypes.common import (
@@ -256,7 +257,7 @@ def _get_dummies_1d(
256257
dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment]
257258
elif (
258259
isinstance(input_dtype, StringDtype)
259-
and input_dtype.storage != "pyarrow_numpy"
260+
and input_dtype.na_value is libmissing.NA
260261
):
261262
dtype = pandas_dtype("boolean") # type: ignore[assignment]
262263
else:

pandas/core/reshape/merge.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2677,8 +2677,7 @@ def _factorize_keys(
26772677

26782678
elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype:
26792679
if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or (
2680-
isinstance(lk.dtype, StringDtype)
2681-
and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"]
2680+
isinstance(lk.dtype, StringDtype) and lk.dtype.storage == "pyarrow"
26822681
):
26832682
import pyarrow as pa
26842683
import pyarrow.compute as pc

pandas/core/tools/numeric.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
import numpy as np
99

10-
from pandas._libs import lib
10+
from pandas._libs import (
11+
lib,
12+
missing as libmissing,
13+
)
1114
from pandas.util._validators import check_dtype_backend
1215

1316
from pandas.core.dtypes.cast import maybe_downcast_numeric
@@ -218,7 +221,7 @@ def to_numeric(
218221
coerce_numeric=coerce_numeric,
219222
convert_to_masked_nullable=dtype_backend is not lib.no_default
220223
or isinstance(values_dtype, StringDtype)
221-
and not values_dtype.storage == "pyarrow_numpy",
224+
and values_dtype.na_value is libmissing.NA,
222225
)
223226

224227
if new_mask is not None:
@@ -229,7 +232,7 @@ def to_numeric(
229232
dtype_backend is not lib.no_default
230233
and new_mask is None
231234
or isinstance(values_dtype, StringDtype)
232-
and not values_dtype.storage == "pyarrow_numpy"
235+
and values_dtype.na_value is libmissing.NA
233236
):
234237
new_mask = np.zeros(values.shape, dtype=np.bool_)
235238

pandas/io/_util.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
import numpy as np
6+
57
from pandas.compat._optional import import_optional_dependency
68

79
import pandas as pd
@@ -32,6 +34,6 @@ def arrow_string_types_mapper() -> Callable:
3234
pa = import_optional_dependency("pyarrow")
3335

3436
return {
35-
pa.string(): pd.StringDtype(storage="pyarrow_numpy"),
36-
pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"),
37+
pa.string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
38+
pa.large_string(): pd.StringDtype(storage="pyarrow", na_value=np.nan),
3739
}.get

0 commit comments

Comments
 (0)