9
9
10
10
import numpy as np
11
11
12
- from pandas ._config import get_option
12
+ from pandas ._config import (
13
+ get_option ,
14
+ using_string_dtype ,
15
+ )
13
16
14
17
from pandas ._libs import (
15
18
lib ,
@@ -81,8 +84,10 @@ class StringDtype(StorageExtensionDtype):
81
84
82
85
Parameters
83
86
----------
84
- storage : {"python", "pyarrow", "pyarrow_numpy" }, optional
87
+ storage : {"python", "pyarrow"}, optional
85
88
If not given, the value of ``pd.options.mode.string_storage``.
89
+ na_value : {np.nan, pd.NA}, default pd.NA
90
+ Whether the dtype follows NaN or NA missing value semantics.
86
91
87
92
Attributes
88
93
----------
@@ -113,30 +118,67 @@ class StringDtype(StorageExtensionDtype):
113
118
# follows NumPy semantics, which uses nan.
114
119
@property
115
120
def na_value (self ) -> libmissing .NAType | float : # type: ignore[override]
116
- if self .storage == "pyarrow_numpy" :
117
- return np .nan
118
- else :
119
- return libmissing .NA
121
+ return self ._na_value
120
122
121
- _metadata = ("storage" ,)
123
+ _metadata = ("storage" , "_na_value" ) # type: ignore[assignment]
122
124
123
- def __init__ (self , storage = None ) -> None :
125
+ def __init__ (
126
+ self ,
127
+ storage : str | None = None ,
128
+ na_value : libmissing .NAType | float = libmissing .NA ,
129
+ ) -> None :
130
+ # infer defaults
124
131
if storage is None :
125
- infer_string = get_option ("future.infer_string" )
126
- if infer_string :
127
- storage = "pyarrow_numpy"
132
+ if using_string_dtype ():
133
+ storage = "pyarrow"
128
134
else :
129
135
storage = get_option ("mode.string_storage" )
130
- if storage not in {"python" , "pyarrow" , "pyarrow_numpy" }:
136
+
137
+ if storage == "pyarrow_numpy" :
138
+ # TODO raise a deprecation warning
139
+ storage = "pyarrow"
140
+ na_value = np .nan
141
+
142
+ # validate options
143
+ if storage not in {"python" , "pyarrow" }:
131
144
raise ValueError (
132
- f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
133
- f"Got { storage } instead."
145
+ f"Storage must be 'python' or 'pyarrow'. Got { storage } instead."
134
146
)
135
- if storage in ( "pyarrow" , "pyarrow_numpy" ) and pa_version_under10p1 :
147
+ if storage == "pyarrow" and pa_version_under10p1 :
136
148
raise ImportError (
137
149
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
138
150
)
151
+
152
+ if isinstance (na_value , float ) and np .isnan (na_value ):
153
+ # when passed a NaN value, always set to np.nan to ensure we use
154
+ # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
155
+ na_value = np .nan
156
+ elif na_value is not libmissing .NA :
157
+ raise ValueError ("'na_value' must be np.nan or pd.NA, got {na_value}" )
158
+
139
159
self .storage = storage
160
+ self ._na_value = na_value
161
+
162
+ def __eq__ (self , other : object ) -> bool :
163
+ # we need to override the base class __eq__ because na_value (NA or NaN)
164
+ # cannot be checked with normal `==`
165
+ if isinstance (other , str ):
166
+ if other == self .name :
167
+ return True
168
+ try :
169
+ other = self .construct_from_string (other )
170
+ except TypeError :
171
+ return False
172
+ if isinstance (other , type (self )):
173
+ return self .storage == other .storage and self .na_value is other .na_value
174
+ return False
175
+
176
+ def __hash__ (self ) -> int :
177
+ # need to override __hash__ as well because of overriding __eq__
178
+ return super ().__hash__ ()
179
+
180
+ def __reduce__ (self ):
181
+ return StringDtype , (self .storage , self .na_value )
140
182
141
183
@property
142
184
def type (self ) -> type [str ]:
@@ -181,6 +223,7 @@ def construct_from_string(cls, string) -> Self:
181
223
elif string == "string[pyarrow]" :
182
224
return cls (storage = "pyarrow" )
183
225
elif string == "string[pyarrow_numpy]" :
226
+ # TODO deprecate
184
227
return cls (storage = "pyarrow_numpy" )
185
228
else :
186
229
raise TypeError (f"Cannot construct a '{ cls .__name__ } ' from '{ string } '" )
@@ -205,7 +248,7 @@ def construct_array_type( # type: ignore[override]
205
248
206
249
if self .storage == "python" :
207
250
return StringArray
208
- elif self .storage == "pyarrow" :
251
+ elif self .storage == "pyarrow" and self . _na_value is libmissing . NA :
209
252
return ArrowStringArray
210
253
else :
211
254
return ArrowStringArrayNumpySemantics
@@ -217,13 +260,17 @@ def __from_arrow__(
217
260
Construct StringArray from pyarrow Array/ChunkedArray.
218
261
"""
219
262
if self .storage == "pyarrow" :
220
- from pandas .core .arrays .string_arrow import ArrowStringArray
263
+ if self ._na_value is libmissing .NA :
264
+ from pandas .core .arrays .string_arrow import ArrowStringArray
265
+
266
+ return ArrowStringArray (array )
267
+ else :
268
+ from pandas .core .arrays .string_arrow import (
269
+ ArrowStringArrayNumpySemantics ,
270
+ )
221
271
222
- return ArrowStringArray (array )
223
- elif self .storage == "pyarrow_numpy" :
224
- from pandas .core .arrays .string_arrow import ArrowStringArrayNumpySemantics
272
+ return ArrowStringArrayNumpySemantics (array )
225
273
226
- return ArrowStringArrayNumpySemantics (array )
227
274
else :
228
275
import pyarrow
229
276
0 commit comments