[SPARK-45523][PYTHON] Refactor the null-checking to have shortcuts
### What changes were proposed in this pull request?

This is a follow-up of apache#43356.

Refactor the null-checking to have shortcuts.

### Why are the changes needed?

The null-check can take shortcuts in some cases: when a column's declared type contains no non-nullable component, no per-value checking is needed for that column, and when that holds for every column the row-level check can be skipped entirely.
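
As a rough illustration of the idea (a simplified sketch, not the actual `worker.py` code): the hypothetical `build_checker` below handles only array and leaf types, and returns `None` whenever a column's declared type contains no non-nullable component, so that column's check can be skipped.

```python
from typing import Any, Callable, Optional

from pyspark.sql.types import ArrayType, DataType, IntegerType


def build_checker(data_type: DataType, nullable: bool) -> Optional[Callable[[Any], None]]:
    """Return a per-value checker, or None when no value could violate nullability."""
    element_checker = None
    if isinstance(data_type, ArrayType):
        element_checker = build_checker(data_type.elementType, data_type.containsNull)

    if nullable and element_checker is None:
        # Shortcut: this column and everything nested inside it accepts None,
        # so no per-row work is needed for it at all.
        return None

    def check(value: Any) -> None:
        if value is None:
            if not nullable:
                raise ValueError("None in a non-nullable column")
        elif element_checker is not None and isinstance(value, list):
            for element in value:
                element_checker(element)

    return check


# A fully nullable column needs no checker at all ...
assert build_checker(ArrayType(IntegerType(), containsNull=True), nullable=True) is None
# ... while a non-nullable element type still gets one.
assert build_checker(ArrayType(IntegerType(), containsNull=False), nullable=True) is not None
```

The patch applies the same idea to struct and map types as well, building one checker per top-level result column and returning `None` for the whole row-level check when every column can take the shortcut.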

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#43492 from ueshin/issues/SPARK-45523/nullcheck.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
ueshin committed Oct 24, 2023
1 parent 7ef96ee commit 68c0f64
Showing 1 changed file with 129 additions and 82 deletions.
211 changes: 129 additions & 82 deletions python/pyspark/worker.py
@@ -24,7 +24,7 @@
import time
from inspect import getfullargspec
import json
from typing import Any, Callable, Iterable, Iterator
from typing import Any, Callable, Iterable, Iterator, Optional
import faulthandler

from pyspark.accumulators import _accumulatorRegistry
@@ -58,7 +58,6 @@
MapType,
Row,
StringType,
StructField,
StructType,
_create_row,
_parse_datatype_json_string,
@@ -700,7 +699,7 @@ def read_udtf(pickleSer, infile, eval_type):
)

return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not type(return_type) == StructType:
if not isinstance(return_type, StructType):
raise PySparkRuntimeError(
f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
)
@@ -845,70 +844,112 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
"the query again."
)

# This determines which result columns have nullable types.
def check_nullable_column(i: int, data_type: DataType, nullable: bool) -> None:
if not nullable:
nullable_columns.add(i)
elif isinstance(data_type, ArrayType):
check_nullable_column(i, data_type.elementType, data_type.containsNull)
elif isinstance(data_type, StructType):
for subfield in data_type.fields:
check_nullable_column(i, subfield.dataType, subfield.nullable)
elif isinstance(data_type, MapType):
check_nullable_column(i, data_type.valueType, data_type.valueContainsNull)

nullable_columns: set[int] = set()
for i, field in enumerate(return_type.fields):
check_nullable_column(i, field.dataType, field.nullable)

# Compares each UDTF output row against the output schema for this particular UDTF call,
# raising an error if the two are incompatible.
def check_output_row_against_schema(row: Any, expected_schema: StructType) -> None:
for result_column_index in nullable_columns:

def check_for_none_in_non_nullable_column(
value: Any, data_type: DataType, nullable: bool
) -> None:
if value is None and not nullable:
raise PySparkRuntimeError(
error_class="UDTF_EXEC_ERROR",
message_parameters={
"method_name": "eval' or 'terminate",
"error": f"Column {result_column_index} within a returned row had a "
+ "value of None, either directly or within array/struct/map "
+ "subfields, but the corresponding column type was declared as "
+ "non-nullable; please update the UDTF to return a non-None value at "
+ "this location or otherwise declare the column type as nullable.",
},
)
elif (
isinstance(data_type, ArrayType)
and isinstance(value, list)
and not data_type.containsNull
):
for sub_value in value:
check_for_none_in_non_nullable_column(
sub_value, data_type.elementType, data_type.containsNull
)
elif isinstance(data_type, StructType) and isinstance(value, Row):
for i in range(len(value)):
check_for_none_in_non_nullable_column(
value[i], data_type[i].dataType, data_type[i].nullable
)
elif isinstance(data_type, MapType) and isinstance(value, dict):
for map_key, map_value in value.items():
check_for_none_in_non_nullable_column(
map_key, data_type.keyType, nullable=False
)
check_for_none_in_non_nullable_column(
map_value, data_type.valueType, data_type.valueContainsNull
)
def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
def raise_(result_column_index):
raise PySparkRuntimeError(
error_class="UDTF_EXEC_ERROR",
message_parameters={
"method_name": "eval' or 'terminate",
"error": f"Column {result_column_index} within a returned row had a "
+ "value of None, either directly or within array/struct/map "
+ "subfields, but the corresponding column type was declared as "
+ "non-nullable; please update the UDTF to return a non-None value at "
+ "this location or otherwise declare the column type as nullable.",
},
)

field: StructField = expected_schema[result_column_index]
if row is not None:
check_for_none_in_non_nullable_column(
list(row)[result_column_index], field.dataType, field.nullable
)
def checker(data_type: DataType, result_column_index: int):
if isinstance(data_type, ArrayType):
element_checker = checker(data_type.elementType, result_column_index)
contains_null = data_type.containsNull

if element_checker is None and contains_null:
return None

def check_array(arr):
if isinstance(arr, list):
for e in arr:
if e is None:
if not contains_null:
raise_(result_column_index)
elif element_checker is not None:
element_checker(e)

return check_array

elif isinstance(data_type, MapType):
key_checker = checker(data_type.keyType, result_column_index)
value_checker = checker(data_type.valueType, result_column_index)
value_contains_null = data_type.valueContainsNull

if value_checker is None and value_contains_null:

def check_map(map):
if isinstance(map, dict):
for k, v in map.items():
if k is None:
raise_(result_column_index)
elif key_checker is not None:
key_checker(k)

else:

def check_map(map):
if isinstance(map, dict):
for k, v in map.items():
if k is None:
raise_(result_column_index)
elif key_checker is not None:
key_checker(k)
if v is None:
if not value_contains_null:
raise_(result_column_index)
elif value_checker is not None:
value_checker(v)

return check_map

elif isinstance(data_type, StructType):
field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
nullables = [f.nullable for f in data_type]

if all(c is None for c in field_checkers) and all(nullables):
return None

def check_struct(struct):
if isinstance(struct, tuple):
for value, checker, nullable in zip(struct, field_checkers, nullables):
if value is None:
if not nullable:
raise_(result_column_index)
elif checker is not None:
checker(value)

return check_struct

else:
return None

field_checkers = [
checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
]
nullables = [f.nullable for f in return_type]

if all(c is None for c in field_checkers) and all(nullables):
return None

def check(row):
if isinstance(row, tuple):
for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)):
if value is None:
if not nullable:
raise_(i)
elif checker is not None:
checker(value)

return check

check_output_row_against_schema = build_null_checker(return_type)

if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:

@@ -948,8 +989,6 @@ def verify_result(result):
verify_pandas_result(
result, return_type, assign_cols_by_name=False, truncate_return_schema=False
)
for result_tuple in result.itertuples():
check_output_row_against_schema(list(result_tuple), return_type)
return result

# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -965,28 +1004,36 @@ def func(*a: Any) -> Any:
def check_return_value(res):
# Check whether the result of an arrow UDTF is iterable before
# using it to construct a pandas DataFrame.
if res is not None and not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={
"type": type(res).__name__,
"func": f.__name__,
},
)
if res is not None:
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={
"type": type(res).__name__,
"func": f.__name__,
},
)
if check_output_row_against_schema is not None:
for row in res:
if row is not None:
check_output_row_against_schema(row)
yield row
else:
yield from res

def evaluate(*args: pd.Series):
if len(args) == 0:
res = func()
check_return_value(res)
yield verify_result(pd.DataFrame(res)), arrow_return_type
yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
res = func(*row)
check_return_value(res)
yield verify_result(pd.DataFrame(res)), arrow_return_type
yield verify_result(
pd.DataFrame(check_return_value(res))
), arrow_return_type

return evaluate

@@ -1043,8 +1090,8 @@ def verify_and_convert_result(result):
"func": f.__name__,
},
)

check_output_row_against_schema(result, return_type)
if check_output_row_against_schema is not None:
check_output_row_against_schema(result)
return toInternal(result)

# Evaluate the function and return a tuple back to the executor.