44 changes: 36 additions & 8 deletions python/pyspark/sql/conversion.py
@@ -93,13 +93,19 @@ def _need_converter(

@overload
@staticmethod
def _create_converter(dataType: DataType, nullable: bool = True) -> Callable:
def _create_converter(
dataType: DataType, nullable: bool = True, *, int_to_decimal_coercion_enabled: bool = False
) -> Callable:
pass

@overload
@staticmethod
def _create_converter(
dataType: DataType, nullable: bool = True, *, none_on_identity: bool = True
dataType: DataType,
nullable: bool = True,
*,
none_on_identity: bool = True,
int_to_decimal_coercion_enabled: bool = False,
) -> Optional[Callable]:
pass

@@ -109,6 +115,7 @@ def _create_converter(
nullable: bool = True,
*,
none_on_identity: bool = False,
int_to_decimal_coercion_enabled: bool = False,
) -> Optional[Callable]:
assert dataType is not None and isinstance(dataType, DataType)
assert isinstance(nullable, bool)
@@ -135,7 +142,10 @@ def convert_null(value: Any) -> Any:

field_convs = [
LocalDataToArrowConversion._create_converter(
field.dataType, field.nullable, none_on_identity=True
field.dataType,
field.nullable,
none_on_identity=True,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
for field in dataType.fields
]
@@ -189,7 +199,10 @@ def convert_struct(value: Any) -> Any:

elif isinstance(dataType, ArrayType):
element_conv = LocalDataToArrowConversion._create_converter(
dataType.elementType, dataType.containsNull, none_on_identity=True
dataType.elementType,
dataType.containsNull,
none_on_identity=True,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)

if element_conv is None:
@@ -218,10 +231,15 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = LocalDataToArrowConversion._create_converter(
dataType.keyType, nullable=False
dataType.keyType,
nullable=False,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
value_conv = LocalDataToArrowConversion._create_converter(
dataType.valueType, dataType.valueContainsNull, none_on_identity=True
dataType.valueType,
dataType.valueContainsNull,
none_on_identity=True,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)

if value_conv is None:
@@ -295,6 +313,9 @@ def convert_decimal(value: Any) -> Any:
raise PySparkValueError(f"input for {dataType} must not be None")
return None
else:
if int_to_decimal_coercion_enabled and isinstance(value, int):
value = decimal.Decimal(value)

assert isinstance(value, decimal.Decimal)
if value.is_nan():
if not nullable:
@@ -325,7 +346,10 @@ def convert_string(value: Any) -> Any:
udt: UserDefinedType = dataType

conv = LocalDataToArrowConversion._create_converter(
udt.sqlType(), nullable=nullable, none_on_identity=True
udt.sqlType(),
nullable=nullable,
none_on_identity=True,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)

if conv is None:
@@ -426,7 +450,11 @@ def to_row(item: Any) -> tuple:
if len_column_names > 0:
column_convs = [
LocalDataToArrowConversion._create_converter(
field.dataType, field.nullable, none_on_identity=True
field.dataType,
field.nullable,
none_on_identity=True,
# Default to False for general data conversion
int_to_decimal_coercion_enabled=False,
)
for field in schema.fields
]
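The substantive change in this file is the new branch in `convert_decimal`: when `int_to_decimal_coercion_enabled` is set, a plain Python `int` is wrapped in `decimal.Decimal` before the existing `isinstance(value, decimal.Decimal)` assert runs. Below is a minimal standalone sketch of that behavior; it is a simplified stand-in, not the actual PySpark converter, and it assumes pyarrow is installed.

```python
import decimal
import pyarrow as pa

def convert_decimal(value, nullable=True, *, int_to_decimal_coercion_enabled=False):
    # Simplified stand-in for the DecimalType branch of
    # LocalDataToArrowConversion._create_converter.
    if value is None:
        if not nullable:
            raise ValueError("input for DecimalType must not be None")
        return None
    if int_to_decimal_coercion_enabled and isinstance(value, int):
        # New behavior: coerce plain ints to Decimal before the assert below.
        value = decimal.Decimal(value)
    assert isinstance(value, decimal.Decimal)
    return None if value.is_nan() else value

# With coercion enabled, ints returned by a UDF survive the decimal conversion;
# without it, convert_decimal(123) would fail the isinstance assert.
converted = [convert_decimal(v, int_to_decimal_coercion_enabled=True) for v in (123, 456, None)]
print(pa.array(converted, type=pa.decimal128(10, 2)))
```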
15 changes: 12 additions & 3 deletions python/pyspark/sql/pandas/serializers.py
@@ -776,15 +776,19 @@ class ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
A timezone to respect when handling timestamp values
safecheck : bool
If True, conversion from Arrow to Pandas checks for overflow/truncation
input_types : bool
If True, then Pandas DataFrames will get columns by name
input_types : list
List of input data types for the UDF
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow.
This has performance penalties.
"""

def __init__(
self,
timezone,
safecheck,
input_types,
int_to_decimal_coercion_enabled=False,
):
super().__init__(
timezone=timezone,
@@ -793,6 +797,7 @@ def __init__(
arrow_cast=True,
)
self._input_types = input_types
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled

def load_stream(self, stream):
"""
@@ -843,7 +848,11 @@ def dump_stream(self, iterator, stream):
import pyarrow as pa

def create_array(results, arrow_type, spark_type):
conv = LocalDataToArrowConversion._create_converter(spark_type, none_on_identity=True)
conv = LocalDataToArrowConversion._create_converter(
spark_type,
none_on_identity=True,
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
)
converted = [conv(res) for res in results] if conv is not None else results
try:
return pa.array(converted, type=arrow_type)
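For reference, the `create_array` helper in `dump_stream` above follows this pattern: obtain the optional per-type converter, map it over the UDF results, then build the Arrow array with the target type. The following is a rough standalone approximation with simplified names, not the serializer's exact code.

```python
from decimal import Decimal
import pyarrow as pa

def create_array(results, arrow_type, conv=None):
    # conv plays the role of the converter returned by
    # LocalDataToArrowConversion._create_converter(..., none_on_identity=True);
    # None means the results are already Arrow-compatible.
    converted = [conv(res) for res in results] if conv is not None else results
    return pa.array(converted, type=arrow_type)

def to_decimal(v):
    # Stand-in converter mirroring the int-to-decimal coercion path.
    return Decimal(v) if isinstance(v, int) else v

print(create_array([1, 2, 3], pa.decimal128(25, 1), to_decimal))
```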
55 changes: 55 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -180,6 +180,61 @@ def test_type_coercion_string_to_numeric(self):
with self.assertRaises(PythonException):
df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()

def test_arrow_udf_int_to_decimal_coercion(self):
from decimal import Decimal

with self.sql_conf(
{"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False}
):
df = self.spark.range(0, 3)

@udf(returnType="decimal(10,2)", useArrow=True)
def int_to_decimal_udf(val):
values = [123, 456, 789]
return values[int(val) % len(values)]

# Test with coercion enabled
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
):
result = df.select(int_to_decimal_udf("id").alias("decimal_val")).collect()
self.assertEqual(result[0]["decimal_val"], Decimal("123.00"))
self.assertEqual(result[1]["decimal_val"], Decimal("456.00"))
self.assertEqual(result[2]["decimal_val"], Decimal("789.00"))

# Test with coercion disabled (should fail)
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
with self.assertRaisesRegex(
PythonException, "An exception was thrown from the Python worker"
):
df.select(int_to_decimal_udf("id").alias("decimal_val")).collect()

@udf(returnType="decimal(25,1)", useArrow=True)
def high_precision_udf(val):
values = [1, 2, 3]
return values[int(val) % len(values)]

# Test high precision decimal with coercion enabled
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
):
result = df.select(high_precision_udf("id").alias("decimal_val")).collect()
self.assertEqual(len(result), 3)
self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))

# Test high precision decimal with coercion disabled (should fail)
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
with self.assertRaisesRegex(
PythonException, "An exception was thrown from the Python worker"
):
df.select(high_precision_udf("id").alias("decimal_val")).collect()

def test_err_return_type(self):
with self.assertRaises(PySparkNotImplementedError) as pe:
udf(lambda x: x, VarcharType(10), useArrow=True)
4 changes: 3 additions & 1 deletion python/pyspark/worker.py
@@ -2364,7 +2364,9 @@ def read_udfs(pickleSer, infile, eval_type):
input_types = [
f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))
]
ser = ArrowBatchUDFSerializer(timezone, safecheck, input_types)
ser = ArrowBatchUDFSerializer(
timezone, safecheck, input_types, int_to_decimal_coercion_enabled
)
else:
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
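End to end, the tests above suggest the user-facing flow looks roughly like the sketch below. This is a hedged usage example, not part of the change itself: the conf names are taken from the test file, and a running Spark session with Arrow-optimized Python UDFs is assumed.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()
# Conf names as used in test_arrow_udf_int_to_decimal_coercion above.
spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", False)
spark.conf.set("spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled", True)

@udf(returnType="decimal(10,2)", useArrow=True)
def int_to_decimal_udf(val):
    # Returns a plain int; with the conf enabled, the converter coerces it to
    # Decimal instead of failing the isinstance assert.
    return [123, 456, 789][int(val) % 3]

spark.range(3).select(int_to_decimal_udf("id").alias("decimal_val")).show()
```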