Skip to content

Commit 0fe75a5

Browse files
committed
temp
1 parent bc36a7d commit 0fe75a5

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

python/pyspark/sql/connect/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
129129
ret = pb2.DataType()
130130
if isinstance(data_type, NullType):
131131
ret.null.CopyFrom(pb2.DataType.NULL())
132+
elif isinstance(data_type, CharType):
133+
ret.char.length = data_type.length
134+
elif isinstance(data_type, VarcharType):
135+
ret.var_char.length = data_type.length
132136
elif isinstance(data_type, StringType):
133137
ret.string.collation = data_type.collation
134138
elif isinstance(data_type, BooleanType):

python/pyspark/sql/tests/test_udf.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,50 @@ def my_udf(input_val):
14021402
result_type = df_result.schema["result"].dataType
14031403
self.assertEqual(result_type, StringType("fr"))
14041404

1405+
def test_udf_with_char_varchar_return_type(self):
    """Check that a UDF declared with a char/varchar return type is rejected.

    Each case pairs a DDL return-type string with a value for the UDF to
    return. The cases cover bare ``char``/``varchar`` as well as the same
    types nested inside arrays, maps and structs. For every case, defining
    the UDF and running it through ``show()`` must raise an exception whose
    message matches one of the two alternatives in the regex below.
    """
    (char_type, char_value) = ("char(10)", "a")
    (varchar_type, varchar_value) = ("varchar(8)", "a")
    (array_with_char_type, array_with_char_type_value) = ("array<char(5)>", ["a", "b"])
    (array_with_varchar_type, array_with_varchar_value) = ("array<varchar(12)>", ["a", "b"])
    (map_type, map_value) = (f"map<{char_type}, {varchar_type}>", {"a": "b"})
    (struct_type, struct_value) = (
        f"struct<f1: {char_type}, f2: {varchar_type}>",
        {"f1": "a", "f2": "b"},
    )

    pairs = [
        (char_type, char_value),
        (varchar_type, varchar_value),
        (array_with_char_type, array_with_char_type_value),
        (array_with_varchar_type, array_with_varchar_value),
        (map_type, map_value),
        (struct_type, struct_value),
        # The next two cases supply the value as a formatted string rather
        # than as a Python container.
        (
            f"struct<f1: {array_with_char_type}, f2: {array_with_varchar_type}, "
            f"f3: {map_type}>",
            f"{{'f1': {array_with_char_type_value}, 'f2': {array_with_varchar_value}, "
            f"'f3': {map_value}}}",
        ),
        (
            f"map<{array_with_char_type}, {array_with_varchar_type}>",
            f"{{{array_with_char_type_value}: {array_with_varchar_value}}}",
        ),
        (f"array<{struct_type}>", [struct_value, struct_value]),
    ]

    for return_type, return_value in pairs:
        # subTest labels a failure with the offending return type and lets
        # the remaining cases run instead of stopping at the first failure.
        with self.subTest(return_type=return_type):
            with self.assertRaisesRegex(
                Exception,
                "(Please use a different output data type for your UDF or DataFrame|"
                "Invalid return type with Arrow-optimized Python UDF)",
            ):

                # The UDF is defined inside the assertRaisesRegex block, so
                # an error raised at decoration time is also caught.
                @udf(return_type)
                def my_udf():
                    return return_value

                self.spark.range(1).select(my_udf().alias("result")).show()
1448+
14051449

14061450
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
14071451
@classmethod

0 commit comments

Comments
 (0)