diff --git a/.gitignore b/.gitignore
index d951f3fb9cbad..f2b86e857cf66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,3 +141,5 @@ doc/source/savefig/
 # Pyodide/WASM related files #
 ##############################
 /.pyodide-xbuildenv-*
+
+pandas-env/
diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 0c3f535df9ce2..ea62abb1eea01 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -3771,6 +3771,7 @@ def to_csv(
         decimal: str = ".",
         errors: OpenFileErrors = "strict",
         storage_options: StorageOptions | None = None,
+        preserve_complex: bool = False,
     ) -> str | None:
         r"""
         Write object to a comma-separated values (csv) file.
@@ -3858,6 +3859,12 @@ def to_csv(
 
         {storage_options}
 
+        preserve_complex : bool, default False
+            If True, values in object-dtype columns that are NumPy arrays or
+            lists are written as JSON array strings, so that
+            :func:`read_csv` with ``preserve_complex=True`` can restore them.
+            If False (default), standard CSV behavior is used.
+
         Returns
         -------
         None or str
@@ -3938,6 +3945,7 @@ def to_csv(
             doublequote=doublequote,
             escapechar=escapechar,
             storage_options=storage_options,
+            preserve_complex=preserve_complex,
         )
 
     # ----------------------------------------------------------------------
diff --git a/pandas/io/formats/csvs.py b/pandas/io/formats/csvs.py
index 75bcb51ef4be2..fd7a8ddce700d 100644
--- a/pandas/io/formats/csvs.py
+++ b/pandas/io/formats/csvs.py
@@ -11,6 +11,7 @@
     Sequence,
 )
 import csv as csvlib
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -75,6 +76,7 @@ def __init__(
         doublequote: bool = True,
         escapechar: str | None = None,
         storage_options: StorageOptions | None = None,
+        preserve_complex: bool = False,
     ) -> None:
         self.fmt = formatter
 
@@ -85,6 +87,7 @@ def __init__(
         self.compression: CompressionOptions = compression
         self.mode = mode
         self.storage_options = storage_options
+        self.preserve_complex = preserve_complex
 
         self.sep = sep
         self.index_label = self._initialize_index_label(index_label)
@@ -98,6 +101,41 @@ def __init__(
         self.cols = self._initialize_columns(cols)
         self.chunksize = self._initialize_chunksize(chunksize)
 
+        if self.preserve_complex:
+            self._encode_array_columns()
+
+    def _encode_array_columns(self) -> None:
+        """
+        Replace ndarray/list values in object columns with JSON array strings.
+
+        Works on a shallow copy of ``self.obj`` so that the DataFrame passed
+        in by the caller is not modified by writing it.
+        """
+        self.obj = self.obj.copy(deep=False)
+        for loc in range(self.obj.shape[1]):
+            # Positional access also copes with duplicate column labels.
+            values = self.obj.iloc[:, loc]
+            if values.dtype != object:
+                continue
+            # Inspect every value, not just the first: the column may be
+            # empty or may start with a missing value.
+            is_array = values.map(lambda x: isinstance(x, (np.ndarray, list)))
+            if is_array.any():
+                self.obj.isetitem(loc, values.map(self._array_to_json))
+
+    @staticmethod
+    def _array_to_json(value: object) -> object:
+        """
+        Return ``value`` as a JSON array string if it is an ndarray or a list.
+
+        Any other value is returned unchanged.
+        """
+        if isinstance(value, np.ndarray):
+            return json.dumps(value.tolist())
+        if isinstance(value, list):
+            return json.dumps(value)
+        return value
+
     @property
     def na_rep(self) -> str:
         return self.fmt.na_rep
diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py
index b7fbc4e5e22b7..b82320c5ac3ac 100644
--- a/pandas/io/formats/format.py
+++ b/pandas/io/formats/format.py
@@ -970,6 +970,7 @@ def to_csv(
         escapechar: str | None = None,
         errors: str = "strict",
         storage_options: StorageOptions | None = None,
+        preserve_complex: bool = False,
     ) -> str | None:
         """
         Render dataframe as comma-separated file.
@@ -999,6 +1000,7 @@ def to_csv(
             doublequote=doublequote,
             escapechar=escapechar,
             storage_options=storage_options,
+            preserve_complex=preserve_complex,
             formatter=self.fmt,
         )
         csv_formatter.save()
diff --git a/pandas/io/parsers/readers.py b/pandas/io/parsers/readers.py
index 67193f930b4dc..f68ffc9f029b3 100644
--- a/pandas/io/parsers/readers.py
+++ b/pandas/io/parsers/readers.py
@@ -11,6 +11,7 @@
     defaultdict,
 )
 import csv
+import json
 import sys
 from textwrap import fill
 from typing import (
@@ -453,6 +454,14 @@ class _read_shared(TypedDict, Generic[HashableT], total=False):
 
 {storage_options}
 
+preserve_complex : bool, default False
+    If True, object columns whose non-null values all look like JSON arrays
+    (e.g. written by ``DataFrame.to_csv`` with ``preserve_complex=True``)
+    are parsed back into NumPy arrays. This is only applied by
+    :func:`read_csv`, and only when a :class:`DataFrame` is returned (not
+    with ``iterator`` or ``chunksize``). If False (default), standard CSV
+    behavior is used.
+
 dtype_backend : {{'numpy_nullable', 'pyarrow'}}
     Back-end data type applied to the resultant :class:`DataFrame`
     (still experimental). If not specified, the default behavior
@@ -831,6 +840,7 @@ def read_csv(
     memory_map: bool = False,
     float_precision: Literal["high", "legacy", "round_trip"] | None = None,
     storage_options: StorageOptions | None = None,
+    preserve_complex: bool = False,
     dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
 ) -> DataFrame | TextFileReader:
     # locals() should never be modified
@@ -850,7 +860,48 @@ def read_csv(
     )
     kwds.update(kwds_defaults)
 
-    return _read(filepath_or_buffer, kwds)
+    # ``preserve_complex`` is applied as a post-processing step below, so it is
+    # not forwarded to the parser.
+    kwds.pop("preserve_complex", None)
+    result = _read(filepath_or_buffer, kwds)
+    # With ``iterator`` or ``chunksize`` a TextFileReader is returned and no
+    # array restoration takes place.
+    if preserve_complex and isinstance(result, DataFrame):
+        _restore_complex_arrays(result)
+    return result
+
+
+def _restore_complex_arrays(df: DataFrame) -> None:
+    """
+    Convert object columns of JSON array strings in ``df`` to NumPy arrays.
+
+    A column is converted, in place, when every non-null value is a string
+    such as ``"[0.1, 0.2, 0.3]"`` that parses as a JSON array. Columns whose
+    bracketed text is not valid JSON (e.g. ``"[0.1 0.2 0.3]"``) are left
+    unchanged instead of raising.
+    """
+
+    def looks_like_json_array(x: object) -> bool:
+        return isinstance(x, str) and x.startswith("[") and x.endswith("]")
+
+    def parse(x: object) -> object:
+        # Missing values are not strings and pass through untouched.
+        return np.array(json.loads(x)) if isinstance(x, str) else x
+
+    for loc in range(df.shape[1]):
+        column = df.iloc[:, loc]
+        if column.dtype != object:
+            continue
+        nonnull = column.dropna()
+        if nonnull.empty or not nonnull.map(looks_like_json_array).all():
+            continue
+        try:
+            restored = column.map(parse)
+        except ValueError:
+            # json.JSONDecodeError is a ValueError; recent NumPy also raises
+            # ValueError for ragged nested lists.
+            continue
+        df.isetitem(loc, restored)
 
 
 @overload
@@ -967,6 +1018,7 @@ def read_table(
     memory_map: bool = False,
     float_precision: Literal["high", "legacy", "round_trip"] | None = None,
     storage_options: StorageOptions | None = None,
+    preserve_complex: bool = False,
     dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
 ) -> DataFrame | TextFileReader:
     # locals() should never be modified
diff --git a/scripts/tests/test_csv.py b/scripts/tests/test_csv.py
new file mode 100644
index 0000000000000..aa8cfa44334cc
--- /dev/null
+++ b/scripts/tests/test_csv.py
@@ -0,0 +1,91 @@
+import os
+import tempfile
+
+import numpy as np
+
+import pandas as pd
+
+
+def test_preserve_numpy_arrays_in_csv():
+    print("\nRunning: test_preserve_numpy_arrays_in_csv")
+    df = pd.DataFrame({
+        "id": [1, 2],
+        "embedding": [
+            np.array([0.1, 0.2, 0.3]),
+            np.array([0.4, 0.5, 0.6]),
+        ],
+    })
+
+    # Write into a temporary directory: unlike a still-open NamedTemporaryFile,
+    # such a path can be reopened by name on every platform.
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "data.csv")
+        df.to_csv(path, index=False, preserve_complex=True)
+        df_loaded = pd.read_csv(path, preserve_complex=True)
+
+    assert isinstance(
+        df_loaded["embedding"][0], np.ndarray
+    ), "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"
+    assert isinstance(
+        df["embedding"][0], np.ndarray
+    ), "Test Failed: to_csv modified the caller's DataFrame!"
+
+    print("PASS: test_preserve_numpy_arrays_in_csv")
+
+
+def test_preserve_numpy_arrays_in_csv_empty_dataframe():
+    print("\nRunning: test_preserve_numpy_arrays_in_csv_empty_dataframe")
+    df = pd.DataFrame({"embedding": []})
+    expected = "embedding\n"
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "data.csv")
+        df.to_csv(path, index=False, preserve_complex=True)
+        with open(path, encoding="utf-8") as f:
+            result = f.read()
+
+    msg = (
+        f"CSV output mismatch for empty DataFrame.\n"
+        f"Got:\n{result}\nExpected:\n{expected}"
+    )
+    assert result == expected, msg
+    print("PASS: test_preserve_numpy_arrays_in_csv_empty_dataframe")
+
+
+def test_preserve_numpy_arrays_in_csv_mixed_dtypes():
+    print("\nRunning: test_preserve_numpy_arrays_in_csv_mixed_dtypes")
+    df = pd.DataFrame({
+        "id": [101, 102],
+        "name": ["alice", "bob"],
+        "scores": [
+            np.array([95.5, 88.0]),
+            np.array([76.0, 90.5]),
+        ],
+        "age": [25, 30],
+    })
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "data.csv")
+        df.to_csv(path, index=False, preserve_complex=True)
+        df_loaded = pd.read_csv(path, preserve_complex=True)
+
+    err_scores = "Failed: 'scores' column not deserialized as np.ndarray."
+    assert isinstance(df_loaded["scores"][0], np.ndarray), err_scores
+    assert df_loaded["id"].dtype == np.int64, (
+        "Failed: 'id' should still be int."
+    )
+    assert df_loaded["name"].dtype == object, (
+        "Failed: 'name' should still be object/string."
+    )
+    assert df_loaded["age"].dtype == np.int64, (
+        "Failed: 'age' should still be int."
+    )
+
+    print("PASS: test_preserve_numpy_arrays_in_csv_mixed_dtypes")
+
+
+if __name__ == "__main__":
+    test_preserve_numpy_arrays_in_csv()
+    test_preserve_numpy_arrays_in_csv_empty_dataframe()
+    test_preserve_numpy_arrays_in_csv_mixed_dtypes()
+    print("\nDone.")
diff --git a/scripts/tests/test_numpy_array.csv b/scripts/tests/test_numpy_array.csv
new file mode 100644
index 0000000000000..32cb6ac05cf37
--- /dev/null
+++ b/scripts/tests/test_numpy_array.csv
@@ -0,0 +1,3 @@
+id,embedding
+1,[0.1 0.2 0.3]
+2,[0.4 0.5 0.6]
diff --git a/test_numpy_array.csv b/test_numpy_array.csv
new file mode 100644
index 0000000000000..3054301980ea7
--- /dev/null
+++ b/test_numpy_array.csv
@@ -0,0 +1,3 @@
+id,embedding
+1,"[0.1, 0.2, 0.3]"
+2,"[0.4, 0.5, 0.6]"