Skip to content

Commit

Permalink
address comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
chelsea-lin committed Aug 7, 2024
1 parent b4cfcd9 commit 17f560e
Showing 1 changed file with 15 additions and 24 deletions.
39 changes: 15 additions & 24 deletions db_dtypes/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@
import pyarrow as pa
import pyarrow.compute

ARROW_CMP_FUNCS = {
"eq": pyarrow.compute.equal,
"ne": pyarrow.compute.not_equal,
"lt": pyarrow.compute.less,
"gt": pyarrow.compute.greater,
"le": pyarrow.compute.less_equal,
"ge": pyarrow.compute.greater_equal,
}


@pd.api.extensions.register_extension_dtype
class JSONDtype(pd.api.extensions.ExtensionDtype):
Expand Down Expand Up @@ -68,11 +59,6 @@ def construct_array_type(cls):
"""Return the array type associated with this dtype."""
return JSONArray

# @staticmethod
# def __from_arrow__(array: typing.Union[pa.Array, pa.ChunkedArray]) -> JSONArray:
# """Convert to JSONArray from an Arrow array."""
# return JSONArray(array)


class JSONArray(arrays.ArrowExtensionArray):
"""Extension array that handles BigQuery JSON data, leveraging a string-based
Expand All @@ -95,26 +81,26 @@ def _box_pa(
cls, value, pa_type: pa.DataType | None = None
) -> pa.Array | pa.ChunkedArray | pa.Scalar:
"""Box value into a pyarrow Array, ChunkedArray or Scalar."""
if pa_type is not None and pa_type != pa.string():
raise ValueError(f"Unsupported type '{pa_type}' for JSONArray")

if isinstance(value, pa.Scalar) or not (
common.is_list_like(value) and not common.is_dict_like(value)
):
return cls._box_pa_scalar(value, pa_type)
return cls._box_pa_array(value, pa_type)
return cls._box_pa_scalar(value)
return cls._box_pa_array(value)

@classmethod
def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
def _box_pa_scalar(cls, value) -> pa.Scalar:
"""Box value into a pyarrow Scalar."""
if isinstance(value, pa.Scalar):
pa_scalar = value
if pd.isna(value):
pa_scalar = pa.scalar(None, type=pa_type)
pa_scalar = pa.scalar(None, type=pa.string())
else:
value = JSONArray._serialize_json(value)
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
pa_scalar = pa.scalar(value, type=pa.string(), from_pandas=True)

if pa_type is not None and pa_scalar.type != pa_type:
pa_scalar = pa_scalar.cast(pa_type)
return pa_scalar

@classmethod
Expand All @@ -131,7 +117,8 @@ def _box_pa_array(
value = [JSONArray._serialize_json(x) for x in value]
pa_array = pa.array(value, type=pa_type, from_pandas=True)
except (pa.ArrowInvalid, pa.ArrowTypeError):
# GH50430: let pyarrow infer type, then cast
# https://github.com/pandas-dev/pandas/pull/50430:
# let pyarrow infer type, then cast
pa_array = pa.array(value, from_pandas=True)

if pa_type is not None and pa_array.type != pa_type:
Expand Down Expand Up @@ -181,8 +168,12 @@ def dtype(self) -> JSONDtype:
return self._dtype

def _cmp_method(self, other, op):
pc_func = ARROW_CMP_FUNCS[op.__name__]
result = pc_func(self._pa_array, self._box_pa(other))
if op.__name__ == "eq":
result = pyarrow.compute.equal(self._pa_array, self._box_pa(other))
elif op.__name__ == "ne":
result = pyarrow.compute.not_equal(self._pa_array, self._box_pa(other))
else:
raise NotImplementedError(f"{op.__name__} not implemented for JSONArray")
return arrays.ArrowExtensionArray(result)

def __getitem__(self, item):
Expand Down

0 comments on commit 17f560e

Please sign in to comment.