Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #58421: Index[timestamp[pyarrow]].union with itself return object type #61219

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,10 @@ MultiIndex
- :func:`MultiIndex.get_level_values` accessing a :class:`DatetimeIndex` does not carry the frequency attribute along (:issue:`58327`, :issue:`57949`)
- Bug in :class:`DataFrame` arithmetic operations in case of unaligned MultiIndex columns (:issue:`60498`)
- Bug in :class:`DataFrame` arithmetic operations with :class:`Series` in case of unaligned MultiIndex (:issue:`61009`)
- Bug in :func:`concat` where :class:`MultiIndex` levels with extension dtypes such as ``timestamp[pyarrow]`` were silently coerced to ``object`` instead of preserving their original dtype (:issue:`58421`)
- Bug in :meth:`MultiIndex.from_tuples` causing wrong output with input of type tuples having NaN values (:issue:`60695`, :issue:`60988`)


I/O
^^^
- Bug in :class:`DataFrame` and :class:`Series` ``repr`` of :py:class:`collections.abc.Mapping`` elements. (:issue:`57915`)
Expand Down
43 changes: 38 additions & 5 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pandas.core.dtypes.common import (
is_bool,
is_extension_array_dtype,
is_scalar,
)
from pandas.core.dtypes.concat import concat_compat
Expand All @@ -36,6 +37,7 @@
factorize_from_iterables,
)
import pandas.core.common as com
from pandas.core.construction import array as pd_array
from pandas.core.indexes.api import (
Index,
MultiIndex,
Expand Down Expand Up @@ -824,7 +826,20 @@ def _get_sample_object(


def _concat_indexes(indexes) -> Index:
return indexes[0].append(indexes[1:])
# try to preserve extension types such as timestamp[pyarrow]
values = []
for idx in indexes:
values.extend(idx._values if hasattr(idx, "_values") else idx)

# use the first index as a sample to infer the desired dtype
sample = indexes[0]
try:
# this helps preserve extension types like timestamp[pyarrow]
arr = pd_array(values, dtype=sample.dtype)
except Exception:
arr = pd_array(values) # fallback

return Index(arr)


def validate_unique_levels(levels: list[Index]) -> None:
Expand Down Expand Up @@ -881,14 +896,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde

concat_index = _concat_indexes(indexes)

# these go at the end
if isinstance(concat_index, MultiIndex):
levels.extend(concat_index.levels)
codes_list.extend(concat_index.codes)
else:
codes, categories = factorize_from_iterable(concat_index)
levels.append(categories)
codes_list.append(codes)
# handle the case where the resulting index is a flat Index
# but contains tuples (i.e., a collapsed MultiIndex)
if isinstance(concat_index[0], tuple):
# retrieve the original dtypes
original_dtypes = [lvl.dtype for lvl in indexes[0].levels]

unzipped = list(zip(*concat_index))
for i, level_values in enumerate(unzipped):
# reconstruct each level using original dtype
arr = pd_array(level_values, dtype=original_dtypes[i])
level_codes, _ = factorize_from_iterable(arr)
levels.append(ensure_index(arr))
codes_list.append(level_codes)
else:
# simple indexes factorize directly
codes, categories = factorize_from_iterable(concat_index)
values = getattr(concat_index, "_values", concat_index)
if is_extension_array_dtype(values):
levels.append(values)
else:
levels.append(categories)
codes_list.append(codes)

if len(names) == len(levels):
names = list(names)
Expand Down
51 changes: 51 additions & 0 deletions pandas/tests/frame/methods/test_concat_arrow_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

import pandas as pd

# Column -> pyarrow-backed dtype mapping shared by the frames built in the
# test below; "time" uses the extension dtype whose preservation is under test.
schema = {
    "id": "int64[pyarrow]",
    "time": "timestamp[s][pyarrow]",
    "value": "float[pyarrow]",
}


@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
def test_concat_preserves_pyarrow_timestamp(dtype):
    """
    GH#58421: concat with keys on frames indexed by a pyarrow timestamp
    level must return a MultiIndex that keeps the extension dtype instead
    of coercing it to object.
    """
    # Skip cleanly (rather than error) on environments without pyarrow,
    # matching pandas test-suite convention.
    pytest.importorskip("pyarrow")

    dfA = (
        pd.DataFrame(
            [
                (0, "2021-01-01 00:00:00", 5.3),
                (1, "2021-01-01 00:01:00", 5.4),
                (2, "2021-01-01 00:01:00", 5.4),
                (3, "2021-01-01 00:02:00", 5.5),
            ],
            columns=schema,
        )
        .astype(schema)
        .set_index(["id", "time"])
    )

    dfB = (
        pd.DataFrame(
            [
                (1, "2022-01-01 08:00:00", 6.3),
                (2, "2022-01-01 08:01:00", 6.4),
                (3, "2022-01-01 08:02:00", 6.5),
            ],
            columns=schema,
        )
        .astype(schema)
        .set_index(["id", "time"])
    )

    df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])

    # check whether df.index is a MultiIndex
    assert isinstance(df.index, pd.MultiIndex), (
        f"Expected MultiIndex, but received {type(df.index)}"
    )

    # Verifying special dtype timestamp[s][pyarrow] stays intact after concat
    assert df.index.levels[2].dtype == dtype, (
        f"Expected {dtype}, but received {df.index.levels[2].dtype}"
    )
Loading