[SPARK-46044][PYTHON][TESTS] Improve test coverage of udf.py
### What changes were proposed in this pull request?
Improve test coverage of udf.py
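
The added cases focus on argument validation in `UserDefinedFunction` and `spark.udf.register` (error classes `NOT_CALLABLE`, `NOT_DATATYPE_OR_STR`, and `NOT_INT`). As a minimal, illustrative sketch (not part of the PR), the misuse the new tests assert on can be reproduced outside the test harness; no SparkSession is needed, since the validation happens in the constructor:

```python
# Illustrative only -- mirrors the argument errors covered by the new tests.
from pyspark.errors import PySparkTypeError
from pyspark.sql.types import StringType
from pyspark.sql.udf import UserDefinedFunction

try:
    # func must be callable; a plain string triggers NOT_CALLABLE
    UserDefinedFunction("x", StringType())
except PySparkTypeError as e:
    print(e.getErrorClass())  # NOT_CALLABLE

try:
    # returnType must be a DataType or a DDL string, not an int
    UserDefinedFunction(lambda x: x, 1)
except PySparkTypeError as e:
    print(e.getErrorClass())  # NOT_DATATYPE_OR_STR
```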

### Why are the changes needed?
A subtask of [SPARK-46041](https://issues.apache.org/jira/browse/SPARK-46041) to improve test coverage.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Test changes only.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#43947 from xinrong-meng/test_udf.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
xinrong-meng authored and HyukjinKwon committed Dec 5, 2023
1 parent c9df53f commit 5d47aae
Showing 2 changed files with 58 additions and 1 deletion.
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -86,6 +86,12 @@ def test_non_existed_udf_with_sql_context(self):
def test_udf_registration_returns_udf_on_sql_context(self):
super().test_udf_registration_returns_udf_on_sql_context()

def test_err_udf_registration(self):
self.check_err_udf_registration()

def test_err_udf_init(self):
self.check_err_udf_init()


if __name__ == "__main__":
import unittest
53 changes: 52 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -377,12 +377,17 @@ def test_udf_with_order_by_and_limit(self):
def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())

self.assertListEqual(
df.selectExpr("add_three(id) AS plus_three").collect(),
df.select(add_three("id").alias("plus_three")).collect(),
)

add_three_str = self.spark.udf.register("add_three_str", lambda x: x + 3)
self.assertListEqual(
df.selectExpr("add_three_str(id) AS plus_three").collect(),
df.select(add_three_str("id").alias("plus_three")).collect(),
)

def test_udf_registration_returns_udf_on_sql_context(self):
df = self.spark.range(10)

@@ -425,6 +430,20 @@ def test_register_java_udaf(self):
).first()
self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())

def test_err_udf_registration(self):
with QuietTest(self.sc):
self.check_err_udf_registration()

def check_err_udf_registration(self):
with self.assertRaises(PySparkTypeError) as pe:
self.spark.udf.register("f", UserDefinedFunction("x", StringType()), "int")

self.check_error(
exception=pe.exception,
error_class="NOT_CALLABLE",
message_parameters={"arg_name": "func", "arg_type": "str"},
)

def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegex(
@@ -1027,6 +1046,38 @@ def test_python_udf_segfault(self):

self.spark.range(1).select(udf(lambda x: ctypes.string_at(0))("id")).collect()

def test_err_udf_init(self):
with QuietTest(self.sc):
self.check_err_udf_init()

def check_err_udf_init(self):
with self.assertRaises(PySparkTypeError) as pe:
UserDefinedFunction("x", StringType())

self.check_error(
exception=pe.exception,
error_class="NOT_CALLABLE",
message_parameters={"arg_name": "func", "arg_type": "str"},
)

with self.assertRaises(PySparkTypeError) as pe:
UserDefinedFunction(lambda x: x, 1)

self.check_error(
exception=pe.exception,
error_class="NOT_DATATYPE_OR_STR",
message_parameters={"arg_name": "returnType", "arg_type": "int"},
)

with self.assertRaises(PySparkTypeError) as pe:
UserDefinedFunction(lambda x: x, StringType(), evalType="SQL_BATCHED_UDF")

self.check_error(
exception=pe.exception,
error_class="NOT_INT",
message_parameters={"arg_name": "evalType", "arg_type": "str"},
)


class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
