diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index 6ce652d..d5b05dc 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -50,7 +50,7 @@ # To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal # of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0. if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"): - from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype + from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype else: JSONArray = None JSONDtype = None @@ -359,7 +359,7 @@ def __sub__(self, other): ) -if not JSONArray or not JSONDtype or not ArrowJSONType: +if not JSONArray or not JSONDtype: __all__ = [ "__version__", "DateArray", @@ -370,11 +370,12 @@ def __sub__(self, other): else: __all__ = [ "__version__", - "ArrowJSONType", "DateArray", "DateDtype", "JSONDtype", "JSONArray", + "JSONArrowType", + "JSONArrowScalar", "TimeArray", "TimeDtype", ] diff --git a/db_dtypes/json.py b/db_dtypes/json.py index ddbf7c7..d08d1cb 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -221,6 +221,8 @@ def __getitem__(self, item): value = self.pa_data[item] if isinstance(value, pa.ChunkedArray): return type(self)(value) + elif isinstance(value, pa.ExtensionScalar): + return value.as_py() else: scalar = JSONArray._deserialize_json(value.as_py()) if scalar is None: @@ -259,20 +261,23 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray: return result -class ArrowJSONType(pa.ExtensionType): +class JSONArrowScalar(pa.ExtensionScalar): + def as_py(self): + return JSONArray._deserialize_json(self.value.as_py() if self.value else None) + + +class JSONArrowType(pa.ExtensionType): """Arrow extension type for the `dbjson` Pandas extension type.""" def __init__(self) -> None: super().__init__(pa.string(), "dbjson") def __arrow_ext_serialize__(self) -> bytes: - # No parameters are necessary return b"" @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType: - # return an instance of this subclass - return ArrowJSONType() + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType: + return JSONArrowType() def __hash__(self) -> int: return hash(str(self)) @@ -280,7 +285,10 @@ def __hash__(self) -> int: def to_pandas_dtype(self): return JSONDtype() + def __arrow_ext_scalar_class__(self): + return JSONArrowScalar + # Register the type to be included in RecordBatches, sent over IPC and received in # another Python process. -pa.register_extension_type(ArrowJSONType()) +pa.register_extension_type(JSONArrowType()) diff --git a/tests/compliance/json/test_json_compliance.py b/tests/compliance/json/test_json_compliance.py index 2a8e69a..9a0d0ef 100644 --- a/tests/compliance/json/test_json_compliance.py +++ b/tests/compliance/json/test_json_compliance.py @@ -22,10 +22,6 @@ import pytest -class TestJSONArrayAccumulate(base.BaseAccumulateTests): - pass - - class TestJSONArrayCasting(base.BaseCastingTests): def test_astype_str(self, data): # Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method. diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index 750ddbc..949f1bd 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import math import numpy as np import pandas as pd @@ -37,7 +38,7 @@ "null_field": None, "order": { "items": ["book", "pen", "computer"], - "total": 15.99, + "total": 15, "address": {"street": "123 Main St", "city": "Anytown"}, }, }, @@ -117,35 +118,115 @@ def test_as_numpy_array(): pd._testing.assert_equal(result, expected) -def test_arrow_json_storage_type(): - arrow_json_type = db_dtypes.ArrowJSONType() +def test_json_arrow_storage_type(): + arrow_json_type = db_dtypes.JSONArrowType() assert arrow_json_type.extension_name == "dbjson" assert pa.types.is_string(arrow_json_type.storage_type) -def test_arrow_json_constructors(): - storage_array = pa.array( - ["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() - ) - arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array) +def test_json_arrow_constructors(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + storage_array = pa.array(data, type=pa.string()) + + arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array) assert isinstance(arr_1, pa.ExtensionArray) - arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array) + arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array) assert isinstance(arr_2, pa.ExtensionArray) assert arr_1 == arr_2 -def test_arrow_json_to_pandas(): - storage_array = pa.array( - [None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string() - ) - arr = db_dtypes.ArrowJSONType().wrap_array(storage_array) +def test_json_arrow_to_pandas(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) s = arr.to_pandas() assert isinstance(s.dtypes, db_dtypes.JSONDtype) - assert pd.isna(s[0]) - assert s[1] == 0 - assert s[2] == "str" - assert s[3]["b"] == 2 - assert s[4]["a"] == [1, 2, 3] + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert pd.isna(s[6]) + + +def test_json_arrow_to_pylist(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) + + s = arr.to_pylist() + assert isinstance(s, list) + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert s[6] is None + + +def test_json_arrow_record_batch(): + data = [ + json.dumps(value, sort_keys=True, separators=(",", ":")) + for value in JSON_DATA.values() + ] + arr = pa.array(data, type=db_dtypes.JSONArrowType()) + batch = pa.RecordBatch.from_arrays([arr], ["json_col"]) + sink = pa.BufferOutputStream() + + with pa.RecordBatchStreamWriter(sink, batch.schema) as writer: + writer.write_batch(batch) + + buf = sink.getvalue() + + with pa.ipc.open_stream(buf) as reader: + result = reader.read_all() + + json_col = result.column("json_col") + assert isinstance(json_col.type, db_dtypes.JSONArrowType) + + s = json_col.to_pylist() + + assert isinstance(s, list) + assert s[0] + assert s[1] == 100 + assert math.isclose(s[2], 0.98) + assert s[3] == "hello world" + assert math.isclose(s[4][0], 0.1) + assert math.isclose(s[4][1], 0.2) + assert s[5] == { + "null_field": None, + "order": { + "items": ["book", "pen", "computer"], + "total": 15, + "address": {"street": "123 Main St", "city": "Anytown"}, + }, + } + assert s[6] is None