Commit

add JSONArrowScalar

chelsea-lin committed Jan 16, 2025
1 parent 6a7e82d commit 4475f9c
Showing 4 changed files with 118 additions and 32 deletions.
7 changes: 4 additions & 3 deletions db_dtypes/__init__.py
@@ -50,7 +50,7 @@
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype
from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
else:
JSONArray = None
JSONDtype = None
@@ -359,7 +359,7 @@ def __sub__(self, other):
)


if not JSONArray or not JSONDtype or not ArrowJSONType:
if not JSONArray or not JSONDtype:
__all__ = [
"__version__",
"DateArray",
@@ -370,11 +370,12 @@ def __sub__(self, other):
else:
__all__ = [
"__version__",
"ArrowJSONType",
"DateArray",
"DateDtype",
"JSONDtype",
"JSONArray",
"JSONArrowType",
"JSONArrowScalar",
"TimeArray",
"TimeDtype",
]
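
Because the JSON names above are only exported when pandas 1.5.0 or later is importable, downstream code may want to probe for them before use. A minimal sketch of such a guard (not part of this commit; the fallback branch is illustrative only):

import db_dtypes

# JSONDtype is set to None on pandas < 1.5.0, so check before constructing the dtype.
if getattr(db_dtypes, "JSONDtype", None):
    json_dtype = db_dtypes.JSONDtype()
else:
    json_dtype = None  # hypothetical fallback: treat JSON columns as plain strings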
20 changes: 14 additions & 6 deletions db_dtypes/json.py
@@ -221,6 +221,8 @@ def __getitem__(self, item):
value = self.pa_data[item]
if isinstance(value, pa.ChunkedArray):
return type(self)(value)
elif isinstance(value, pa.ExtensionScalar):
return value.as_py()
else:
scalar = JSONArray._deserialize_json(value.as_py())
if scalar is None:
@@ -259,28 +261,34 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
return result


class ArrowJSONType(pa.ExtensionType):
class JSONArrowScalar(pa.ExtensionScalar):
def as_py(self):
return JSONArray._deserialize_json(self.value.as_py() if self.value else None)


class JSONArrowType(pa.ExtensionType):
"""Arrow extension type for the `dbjson` Pandas extension type."""

def __init__(self) -> None:
super().__init__(pa.string(), "dbjson")

def __arrow_ext_serialize__(self) -> bytes:
# No parameters are necessary
return b""

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType:
# return an instance of this subclass
return ArrowJSONType()
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
return JSONArrowType()

def __hash__(self) -> int:
return hash(str(self))

def to_pandas_dtype(self):
return JSONDtype()

def __arrow_ext_scalar_class__(self):
return JSONArrowScalar


# Register the type to be included in RecordBatches, sent over IPC and received in
# another Python process.
pa.register_extension_type(ArrowJSONType())
pa.register_extension_type(JSONArrowType())
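
The new __arrow_ext_scalar_class__ hook is what makes the scalar class useful: element access and to_pylist() on a JSONArrowType array now go through JSONArrowScalar.as_py, which deserializes the stored JSON string instead of returning it verbatim. A minimal sketch of the resulting behavior (not taken from the diff; assumes a pyarrow version that honors the scalar-class hook):

import json

import pyarrow as pa

import db_dtypes

# Wrap plain JSON strings in the registered extension type.
storage = pa.array([json.dumps({"a": 1}), json.dumps([1, 2, 3])], type=pa.string())
arr = db_dtypes.JSONArrowType().wrap_array(storage)

# Indexing returns a JSONArrowScalar, so as_py() yields Python objects,
# not the raw JSON text held in the string storage.
assert arr[0].as_py() == {"a": 1}
assert arr.to_pylist() == [{"a": 1}, [1, 2, 3]]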
4 changes: 0 additions & 4 deletions tests/compliance/json/test_json_compliance.py
@@ -22,10 +22,6 @@
import pytest


class TestJSONArrayAccumulate(base.BaseAccumulateTests):
pass


class TestJSONArrayCasting(base.BaseCastingTests):
def test_astype_str(self, data):
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
119 changes: 100 additions & 19 deletions tests/unit/test_json.py
@@ -13,6 +13,7 @@
# limitations under the License.

import json
import math

import numpy as np
import pandas as pd
@@ -37,7 +38,7 @@
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15.99,
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
},
@@ -117,35 +118,115 @@ def test_as_numpy_array():
pd._testing.assert_equal(result, expected)


def test_arrow_json_storage_type():
arrow_json_type = db_dtypes.ArrowJSONType()
def test_json_arrow_storage_type():
arrow_json_type = db_dtypes.JSONArrowType()
assert arrow_json_type.extension_name == "dbjson"
assert pa.types.is_string(arrow_json_type.storage_type)


def test_arrow_json_constructors():
storage_array = pa.array(
["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
)
arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array)
def test_json_arrow_constructors():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
storage_array = pa.array(data, type=pa.string())

arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array)
assert isinstance(arr_1, pa.ExtensionArray)

arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array)
arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array)
assert isinstance(arr_2, pa.ExtensionArray)

assert arr_1 == arr_2


def test_arrow_json_to_pandas():
storage_array = pa.array(
[None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
)
arr = db_dtypes.ArrowJSONType().wrap_array(storage_array)
def test_json_arrow_to_pandas():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())

s = arr.to_pandas()
assert isinstance(s.dtypes, db_dtypes.JSONDtype)
assert pd.isna(s[0])
assert s[1] == 0
assert s[2] == "str"
assert s[3]["b"] == 2
assert s[4]["a"] == [1, 2, 3]
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert pd.isna(s[6])


def test_json_arrow_to_pylist():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())

s = arr.to_pylist()
assert isinstance(s, list)
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert s[6] is None


def test_json_arrow_record_batch():
data = [
json.dumps(value, sort_keys=True, separators=(",", ":"))
for value in JSON_DATA.values()
]
arr = pa.array(data, type=db_dtypes.JSONArrowType())
batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
sink = pa.BufferOutputStream()

with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
writer.write_batch(batch)

buf = sink.getvalue()

with pa.ipc.open_stream(buf) as reader:
result = reader.read_all()

json_col = result.column("json_col")
assert isinstance(json_col.type, db_dtypes.JSONArrowType)

s = json_col.to_pylist()

assert isinstance(s, list)
assert s[0]
assert s[1] == 100
assert math.isclose(s[2], 0.98)
assert s[3] == "hello world"
assert math.isclose(s[4][0], 0.1)
assert math.isclose(s[4][1], 0.2)
assert s[5] == {
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
}
assert s[6] is None
