diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py
index dcf8cb5c78536..28d7444f0c7dd 100644
--- a/pandas/core/dtypes/concat.py
+++ b/pandas/core/dtypes/concat.py
@@ -69,6 +69,7 @@ def concat_compat(
     -------
     a single array, preserving the combined dtypes
     """
+
     if len(to_concat) and lib.dtypes_all_equal([obj.dtype for obj in to_concat]):
         # fastpath!
         obj = to_concat[0]
@@ -92,6 +93,34 @@ def concat_compat(
                 to_concat_eas,
                 axis=axis,  # type: ignore[call-arg]
             )
+    # GH#51362: categoricals with differing dtypes (identical dtypes already
+    # took the fastpath above) are unioned rather than cast to object.
+    # With sort_categories=True, union_categoricals raises TypeError if any
+    # input is ordered or the categories dtypes differ; those cases fall
+    # through to the generic logic below.
+    if (
+        len(to_concat)
+        and axis == 0
+        and all(
+            isinstance(arr.dtype, CategoricalDtype) and not arr.dtype.ordered
+            for arr in to_concat
+        )
+    ):
+        # Drop empty arrays before the union, similar to the non_empties logic
+        non_empty_categoricals = [x for x in to_concat if _is_nonempty(x, axis)]
+
+        if not non_empty_categoricals:
+            # All arrays are empty; copy so the result never aliases an input
+            return to_concat[0].copy()
+        if len(non_empty_categoricals) == 1:
+            # Only one non-empty array; copy for the same reason
+            return non_empty_categoricals[0].copy()
+        if lib.dtypes_all_equal(
+            [arr.dtype.categories.dtype for arr in non_empty_categoricals]
+        ):
+            # sort_categories=True has a performance cost, but is needed to keep
+            # pandas/tests/reshape/concat/test_append_common.py passing.
+            return union_categoricals(non_empty_categoricals, sort_categories=True)
 
     # If all arrays are empty, there's nothing to convert, just short-cut to
     # the concatenation, #3121.
diff --git a/pandas/tests/dtypes/test_concat.py b/pandas/tests/dtypes/test_concat.py
index 571e12d0c3303..672c536cd9845 100644
--- a/pandas/tests/dtypes/test_concat.py
+++ b/pandas/tests/dtypes/test_concat.py
@@ -3,7 +3,10 @@
 import pandas.core.dtypes.concat as _concat
 
 import pandas as pd
-from pandas import Series
+from pandas import (
+    DataFrame,
+    Series,
+)
 import pandas._testing as tm
 
 
@@ -14,12 +17,12 @@ def test_concat_mismatched_categoricals_with_empty():
 
     result = _concat.concat_compat([ser1._values, ser2._values])
     expected = pd.concat([ser1, ser2])._values
-    tm.assert_numpy_array_equal(result, expected)
+    tm.assert_categorical_equal(result, expected)
 
 
 def test_concat_single_dataframe_tz_aware():
     # https://github.com/pandas-dev/pandas/issues/25257
-    df = pd.DataFrame(
+    df = DataFrame(
         {"timestamp": [pd.Timestamp("2020-04-08 09:00:00.709949+0000", tz="UTC")]}
     )
     expected = df.copy()
@@ -53,7 +56,7 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
     ser2 = Series(dtype=float)
 
     result = pd.concat([ser1, ser2], axis=1)
-    expected = pd.DataFrame(
+    expected = DataFrame(
         data=[
             (0.0, None),
         ],
@@ -64,3 +67,25 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
         dtype=float,
     )
     tm.assert_frame_equal(result, expected)
+
+
+def test_concat_categorical_dataframes():
+    # GH#51362 concat of categoricals with different categories stays categorical
+    df = DataFrame({"a": [0, 1]}, dtype="category")
+    df2 = DataFrame({"a": [2, 3]}, dtype="category")
+
+    result = pd.concat([df, df2], axis=0)
+
+    expected = DataFrame({"a": [0, 1, 2, 3]}, index=[0, 1, 0, 1], dtype="category")
+    tm.assert_frame_equal(result, expected)
+
+
+def test_concat_categorical_series():
+    # GH#51362 concat of categoricals with different categories stays categorical
+    ser = Series([0, 1], dtype="category")
+    ser2 = Series([2, 3], dtype="category")
+
+    result = pd.concat([ser, ser2], axis=0)
+
+    expected = Series([0, 1, 2, 3], index=[0, 1, 0, 1], dtype="category")
+    tm.assert_series_equal(result, expected)