Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve Complex Data Types for to_csv #61157

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,5 @@ doc/source/savefig/
# Pyodide/WASM related files #
##############################
/.pyodide-xbuildenv-*

pandas-env/
7 changes: 7 additions & 0 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3771,6 +3771,7 @@ def to_csv(
decimal: str = ".",
errors: OpenFileErrors = "strict",
storage_options: StorageOptions | None = None,
preserve_complex: bool = False,
) -> str | None:
r"""
Write object to a comma-separated values (csv) file.
Expand Down Expand Up @@ -3858,6 +3859,11 @@ def to_csv(

{storage_options}

preserve_complex : bool, default False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As commented in the issue, you can use the dtype argument in read_csv to read complex values already. I'm negative on this approach.

If True, arrays (e.g. NumPy arrays) or complex data are serialized and
reconstructed in a custom manner. If False (default), standard CSV
behavior is used.

Returns
-------
None or str
Expand Down Expand Up @@ -3938,6 +3944,7 @@ def to_csv(
doublequote=doublequote,
escapechar=escapechar,
storage_options=storage_options,
preserve_complex=preserve_complex,
)

# ----------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions pandas/io/formats/csvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Sequence,
)
import csv as csvlib
import json
import os
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
doublequote: bool = True,
escapechar: str | None = None,
storage_options: StorageOptions | None = None,
preserve_complex: bool = False,
) -> None:
self.fmt = formatter

Expand All @@ -85,6 +87,7 @@ def __init__(
self.compression: CompressionOptions = compression
self.mode = mode
self.storage_options = storage_options
self.preserve_complex = preserve_complex

self.sep = sep
self.index_label = self._initialize_index_label(index_label)
Expand All @@ -98,6 +101,19 @@ def __init__(
self.cols = self._initialize_columns(cols)
self.chunksize = self._initialize_chunksize(chunksize)

if self.preserve_complex:
for col in self.obj.columns:
if self.obj[col].dtype == "O":
first_val = self.obj[col].iloc[0]
if isinstance(first_val, (np.ndarray, list)):
self.obj[col] = self.obj[col].apply(
lambda x: json.dumps(x.tolist())
if isinstance(x, np.ndarray)
else json.dumps(x)
if isinstance(x, list)
else x
)

@property
def na_rep(self) -> str:
return self.fmt.na_rep
Expand Down
2 changes: 2 additions & 0 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,7 @@ def to_csv(
escapechar: str | None = None,
errors: str = "strict",
storage_options: StorageOptions | None = None,
preserve_complex: bool = False,
) -> str | None:
"""
Render dataframe as comma-separated file.
Expand Down Expand Up @@ -999,6 +1000,7 @@ def to_csv(
doublequote=doublequote,
escapechar=escapechar,
storage_options=storage_options,
preserve_complex=preserve_complex,
formatter=self.fmt,
)
csv_formatter.save()
Expand Down
39 changes: 38 additions & 1 deletion pandas/io/parsers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
defaultdict,
)
import csv
import json
import sys
from textwrap import fill
from typing import (
Expand Down Expand Up @@ -47,6 +48,7 @@
pandas_dtype,
)

import pandas as pd
from pandas import Series
from pandas.core.frame import DataFrame
from pandas.core.indexes.api import RangeIndex
Expand Down Expand Up @@ -453,6 +455,11 @@ class _read_shared(TypedDict, Generic[HashableT], total=False):

{storage_options}

preserve_complex : bool, default False
If True, arrays (e.g. NumPy arrays) or complex data are serialized and
reconstructed in a custom manner. If False (default), standard CSV
behavior is used.

dtype_backend : {{'numpy_nullable', 'pyarrow'}}
Back-end data type applied to the resultant :class:`DataFrame`
(still experimental). If not specified, the default behavior
Expand Down Expand Up @@ -831,6 +838,7 @@ def read_csv(
memory_map: bool = False,
float_precision: Literal["high", "legacy", "round_trip"] | None = None,
storage_options: StorageOptions | None = None,
preserve_complex: bool = False,
dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
) -> DataFrame | TextFileReader:
# locals() should never be modified
Expand All @@ -850,7 +858,35 @@ def read_csv(
)
kwds.update(kwds_defaults)

return _read(filepath_or_buffer, kwds)
df_or_reader = _read(filepath_or_buffer, kwds)
# If DataFrame, parse columns containing JSON arrays if preserve_complex=True
if preserve_complex and isinstance(df_or_reader, DataFrame):
_restore_complex_arrays(df_or_reader)

return df_or_reader


def _restore_complex_arrays(df: DataFrame) -> None:
"""
Converted bracketed JSON strings in df back to NumPy arrays.
eg. "[0.1, 0.2, 0.3]" --> parse into NumPy array.
"""

def looks_like_json_array(x: str) -> bool:
return x.startswith("[") and x.endswith("]")

for col in df.columns:
if df[col].dtype == "object":
nonnull = df[col].dropna()
if (
len(nonnull) > 0
and nonnull.apply(
lambda x: isinstance(x, str) and looks_like_json_array(x)
).all()
):
df[col] = df[col].apply(
lambda x: np.array(json.loads(x)) if pd.notnull(x) else x
)


@overload
Expand Down Expand Up @@ -967,6 +1003,7 @@ def read_table(
memory_map: bool = False,
float_precision: Literal["high", "legacy", "round_trip"] | None = None,
storage_options: StorageOptions | None = None,
preserve_complex: bool = False,
dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
) -> DataFrame | TextFileReader:
# locals() should never be modified
Expand Down
85 changes: 85 additions & 0 deletions scripts/tests/test_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import tempfile

import numpy as np

import pandas as pd


def test_preserve_numpy_arrays_in_csv():
print("\nRunning: test_preserve_numpy_arrays_in_csv")
df = pd.DataFrame({
"id": [1, 2],
"embedding": [
np.array([0.1, 0.2, 0.3]),
np.array([0.4, 0.5, 0.6]),
],
})

with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
df_loaded = pd.read_csv(path, preserve_complex=True)

assert isinstance(
df_loaded["embedding"][0], np.ndarray
), "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"

print("PASS: test_preserve_numpy_arrays_in_csv")


def test_preserve_numpy_arrays_in_csv_empty_dataframe():
print("\nRunning: test_preserve_numpy_arrays_in_csv_empty_dataframe")
df = pd.DataFrame({"embedding": []})
expected = "embedding\n"

with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
with open(path, encoding="utf-8") as f:
result = f.read()

msg = (
f"CSV output mismatch for empty DataFrame.\n"
f"Got:\n{result}\nExpected:\n{expected}"
)
assert result == expected, msg
print("PASS: test_preserve_numpy_arrays_in_csv_empty_dataframe")


def test_preserve_numpy_arrays_in_csv_mixed_dtypes():
print("\nRunning: test_preserve_numpy_arrays_in_csv_mixed_dtypes")
df = pd.DataFrame({
"id": [101, 102],
"name": ["alice", "bob"],
"scores": [
np.array([95.5, 88.0]),
np.array([76.0, 90.5]),
],
"age": [25, 30],
})

with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
df_loaded = pd.read_csv(path, preserve_complex=True)

err_scores = "Failed: 'scores' column not deserialized as np.ndarray."
assert isinstance(df_loaded["scores"][0], np.ndarray), err_scores
assert df_loaded["id"].dtype == np.int64, (
"Failed: 'id' should still be int."
)
assert df_loaded["name"].dtype == object, (
"Failed: 'name' should still be object/string."
)
assert df_loaded["age"].dtype == np.int64, (
"Failed: 'age' should still be int."
)

print("PASS: test_preserve_numpy_arrays_in_csv_mixed_dtypes")


if __name__ == "__main__":
test_preserve_numpy_arrays_in_csv()
test_preserve_numpy_arrays_in_csv_empty_dataframe()
test_preserve_numpy_arrays_in_csv_mixed_dtypes()
print("\nDone.")
3 changes: 3 additions & 0 deletions scripts/tests/test_numpy_array.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
id,embedding
1,[0.1 0.2 0.3]
2,[0.4 0.5 0.6]
3 changes: 3 additions & 0 deletions test_numpy_array.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
id,embedding
1,"[0.1, 0.2, 0.3]"
2,"[0.4, 0.5, 0.6]"
Loading