88
99import numpy as np
1010
11- from pandas ._config import get_option
11+ from pandas ._config import (
12+ get_option ,
13+ using_string_dtype ,
14+ )
1215
1316from pandas ._libs import (
1417 lib ,
@@ -80,8 +83,10 @@ class StringDtype(StorageExtensionDtype):
8083
8184 Parameters
8285 ----------
83- storage : {"python", "pyarrow", "pyarrow_numpy" }, optional
86+ storage : {"python", "pyarrow"}, optional
8487 If not given, the value of ``pd.options.mode.string_storage``.
88+ na_value : {np.nan, pd.NA}, default pd.NA
89+ Whether the dtype follows NaN or NA missing value semantics.
8590
8691 Attributes
8792 ----------
@@ -108,30 +113,67 @@ class StringDtype(StorageExtensionDtype):
108113 # follows NumPy semantics, which uses nan.
109114 @property
110115 def na_value (self ) -> libmissing .NAType | float : # type: ignore[override]
111- if self .storage == "pyarrow_numpy" :
112- return np .nan
113- else :
114- return libmissing .NA
116+ return self ._na_value
115117
116- _metadata = ("storage" ,)
118+ _metadata = ("storage" , "_na_value" ) # type: ignore[assignment]
117119
def __init__(
    self,
    storage: str | None = None,
    na_value: libmissing.NAType | float = libmissing.NA,
) -> None:
    """
    Initialise the dtype from a storage backend and a missing-value sentinel.

    Parameters
    ----------
    storage : {"python", "pyarrow"}, optional
        Backing storage. When omitted, "pyarrow" is used if
        ``using_string_dtype()`` is true, otherwise the value of the
        ``mode.string_storage`` option. The legacy alias "pyarrow_numpy"
        is still accepted and is rewritten to ``storage="pyarrow"`` with
        ``na_value=np.nan``.
    na_value : {np.nan, pd.NA}, default pd.NA
        Sentinel used for missing values. Any float NaN is normalised to
        the ``np.nan`` object.

    Raises
    ------
    ValueError
        If ``storage`` is not "python" or "pyarrow" (after alias handling),
        or if ``na_value`` is neither a NaN float nor ``pd.NA``.
    ImportError
        If ``storage`` is "pyarrow" and ``pa_version_under10p1`` is true.
    """
    # infer defaults
    if storage is None:
        if using_string_dtype():
            storage = "pyarrow"
        else:
            storage = get_option("mode.string_storage")

    if storage == "pyarrow_numpy":
        # TODO raise a deprecation warning
        # Legacy spelling: same storage, NaN missing-value semantics.
        storage = "pyarrow"
        na_value = np.nan

    # validate options
    if storage not in {"python", "pyarrow"}:
        raise ValueError(
            f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
        )
    if storage == "pyarrow" and pa_version_under10p1:
        raise ImportError(
            "pyarrow>=10.0.1 is required for PyArrow backed StringArray."
        )

    if isinstance(na_value, float) and np.isnan(na_value):
        # when passed a NaN value, always set to np.nan to ensure we use
        # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
        na_value = np.nan
    elif na_value is not libmissing.NA:
        # The message must be an f-string, otherwise the literal text
        # "{na_value}" is shown instead of the rejected value.
        raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")

    self.storage = storage
    self._na_value = na_value
156+
def __eq__(self, other: object) -> bool:
    """
    Compare this dtype with another dtype or with a dtype string.

    A string equal to ``self.name`` compares equal immediately; any other
    string is first converted with ``construct_from_string`` (a
    ``TypeError`` from that conversion means "not equal"). Two dtype
    instances are equal when their storage matches and their ``na_value``
    is the very same object.
    """
    # The base-class comparison is overridden because na_value (NA or NaN)
    # cannot be compared with a plain `==`; identity is used instead.
    if isinstance(other, str):
        if other == self.name:
            return True
        try:
            other = self.construct_from_string(other)
        except TypeError:
            return False

    if not isinstance(other, type(self)):
        return False
    if self.storage != other.storage:
        return False
    return self.na_value is other.na_value
170+
def __hash__(self) -> int:
    """Return the hash computed by the parent class."""
    # A class that defines __eq__ without __hash__ becomes unhashable, so
    # the inherited implementation is re-exposed explicitly.
    inherited_hash = super().__hash__()
    return inherited_hash
174+
def __reduce__(self):
    """
    Pickle support: rebuild the dtype by calling ``StringDtype`` with the
    current storage and missing-value sentinel as positional arguments.
    """
    ctor_args = (self.storage, self.na_value)
    return StringDtype, ctor_args
135177
136178 @property
137179 def type (self ) -> type [str ]:
@@ -176,6 +218,7 @@ def construct_from_string(cls, string) -> Self:
176218 elif string == "string[pyarrow]" :
177219 return cls (storage = "pyarrow" )
178220 elif string == "string[pyarrow_numpy]" :
221+ # TODO deprecate
179222 return cls (storage = "pyarrow_numpy" )
180223 else :
181224 raise TypeError (f"Cannot construct a '{ cls .__name__ } ' from '{ string } '" )
@@ -200,7 +243,7 @@ def construct_array_type( # type: ignore[override]
200243
201244 if self .storage == "python" :
202245 return StringArray
203- elif self .storage == "pyarrow" :
246+ elif self .storage == "pyarrow" and self . _na_value is libmissing . NA :
204247 return ArrowStringArray
205248 else :
206249 return ArrowStringArrayNumpySemantics
@@ -212,13 +255,17 @@ def __from_arrow__(
212255 Construct StringArray from pyarrow Array/ChunkedArray.
213256 """
214257 if self .storage == "pyarrow" :
215- from pandas .core .arrays .string_arrow import ArrowStringArray
258+ if self ._na_value is libmissing .NA :
259+ from pandas .core .arrays .string_arrow import ArrowStringArray
260+
261+ return ArrowStringArray (array )
262+ else :
263+ from pandas .core .arrays .string_arrow import (
264+ ArrowStringArrayNumpySemantics ,
265+ )
216266
217- return ArrowStringArray (array )
218- elif self .storage == "pyarrow_numpy" :
219- from pandas .core .arrays .string_arrow import ArrowStringArrayNumpySemantics
267+ return ArrowStringArrayNumpySemantics (array )
220268
221- return ArrowStringArrayNumpySemantics (array )
222269 else :
223270 import pyarrow
224271
0 commit comments