diff --git a/dask_sql/context.py b/dask_sql/context.py index 0b6f8faf8..17c6d0055 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -845,15 +845,19 @@ def _get_ral(self, sql): def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = True): dc = RelConverter.convert(rel, context=self) - # Optimization might remove some alias projects. Make sure to keep them here. - select_names = [field for field in rel.getRowType().getFieldList()] - if rel.get_current_node_type() == "Explain": return dc if dc is None: return + # Optimization might remove some alias projects. Make sure to keep them here. + select_names = [field for field in rel.getRowType().getFieldList()] + if select_names: + cc = dc.column_container + + select_names = select_names[: len(cc.columns)] + # Use FQ name if not unique and simple name if it is unique. If a join contains the same column # names the output col is prepended with the fully qualified column name field_counts = Counter([field.getName() for field in select_names]) @@ -864,7 +868,6 @@ def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = Tru for field in select_names ] - cc = dc.column_container cc = cc.rename( { df_col: select_name diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 520f14e6d..a1f378197 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import dask.dataframe as dd @@ -30,7 +30,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFra @staticmethod def fix_column_to_row_type( - cc: ColumnContainer, row_type: "RelDataType" + cc: ColumnContainer, row_type: "RelDataType", join_type: Optional[str] = None ) -> ColumnContainer: """ Make sure that the given column container @@ -39,6 +39,8 @@ def fix_column_to_row_type( and will just "blindly" rename the columns. """ field_names = [str(x) for x in row_type.getFieldNames()] + if join_type in ("leftsemi", "leftanti"): + field_names = field_names[: len(cc.columns)] logger.debug(f"Renaming {cc.columns} to {field_names}") cc = cc.rename_handle_duplicates( @@ -84,7 +86,9 @@ def assert_inputs( return [RelConverter.convert(input_rel, context) for input_rel in input_rels] @staticmethod - def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): + def fix_dtype_to_row_type( + dc: DataContainer, row_type: "RelDataType", join_type: Optional[str] = None + ): """ Fix the dtype of the given data container (or: the df within it) to the data type given as argument. @@ -98,9 +102,12 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"): df = dc.df cc = dc.column_container + field_list = row_type.getFieldList() + if join_type in ("leftsemi", "leftanti"): + field_list = field_list[: len(cc.columns)] + field_types = { - str(field.getQualifiedName()): field.getType() - for field in row_type.getFieldList() + str(field.getQualifiedName()): field.getType() for field in field_list } for field_name, field_type in field_types.items(): diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 3aa3774d2..c1c904af6 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -14,6 +14,7 @@ from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rel.logical.filter import filter_or_scalar from dask_sql.physical.rex import RexConverter +from dask_sql.utils import is_cudf_type if TYPE_CHECKING: import dask_sql @@ -45,7 +46,8 @@ class DaskJoinPlugin(BaseRelPlugin): "LEFT": "left", "RIGHT": "right", "FULL": "outer", - "LEFTSEMI": "inner", # TODO: Need research here! This is likely not a true inner join + "LEFTSEMI": "leftsemi", + "LEFTANTI": "leftanti", } def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: @@ -74,6 +76,10 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai join_type = join.getJoinType() join_type = self.JOIN_TYPE_MAPPING[str(join_type)] + # TODO: update with correct implementation of leftsemi for CPU + # https://github.com/dask-contrib/dask-sql/issues/1190 + if join_type == "leftsemi" and not is_cudf_type(df_lhs_renamed): + join_type = "inner" # 3. The join condition can have two forms, that we can understand # (a) a = b @@ -170,14 +176,19 @@ def merge_single_partitions(lhs_partition, rhs_partition): # 6. So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) - correct_column_order = list(df_lhs_renamed.columns) + list( - df_rhs_renamed.columns - ) + if join_type in ("leftsemi", "leftanti"): + correct_column_order = list(df_lhs_renamed.columns) + else: + correct_column_order = list(df_lhs_renamed.columns) + list( + df_rhs_renamed.columns + ) cc = ColumnContainer(df.columns).limit_to(correct_column_order) # and to rename them like the rel specifies row_type = rel.getRowType() field_specifications = [str(f) for f in row_type.getFieldNames()] + if join_type in ("leftsemi", "leftanti"): + field_specifications = field_specifications[: len(cc.columns)] cc = cc.rename( { @@ -185,7 +196,7 @@ def merge_single_partitions(lhs_partition, rhs_partition): for from_col, to_col in zip(cc.columns, field_specifications) } ) - cc = self.fix_column_to_row_type(cc, row_type) + cc = self.fix_column_to_row_type(cc, row_type, join_type) dc = DataContainer(df, cc) # 7. Last but not least we apply any filters by and-chaining together the filters @@ -202,7 +213,7 @@ def merge_single_partitions(lhs_partition, rhs_partition): df = filter_or_scalar(df, filter_condition) dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType(), join_type) # # Rename underlying DataFrame column names back to their original values before returning # df = dc.assign() # dc = DataContainer(df, ColumnContainer(cc.columns)) @@ -227,7 +238,7 @@ def _join_on_columns( [~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on], ) df_lhs_renamed = df_lhs_renamed[df_lhs_filter] - if join_type in ["inner", "left"]: + if join_type in ["inner", "left", "leftanti", "leftsemi"]: df_rhs_filter = reduce( operator.and_, [~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on], @@ -256,12 +267,24 @@ def _join_on_columns( "For more information refer to https://github.com/dask/dask/issues/9851" " and https://github.com/dask/dask/issues/9870" ) - df = df_lhs_with_tmp.merge( - df_rhs_with_tmp, - on=added_columns, - how=join_type, - broadcast=broadcast, - ).drop(columns=added_columns) + if join_type == "leftanti" and not is_cudf_type(df_lhs_with_tmp): + df = df_lhs_with_tmp.merge( + df_rhs_with_tmp, + on=added_columns, + how="left", + broadcast=broadcast, + indicator=True, + ).drop(columns=added_columns) + df = df[df["_merge"] == "left_only"].drop( + columns=["_merge"] + list(df_rhs_with_tmp.columns), errors="ignore" + ) + else: + df = df_lhs_with_tmp.merge( + df_rhs_with_tmp, + on=added_columns, + how=join_type, + broadcast=broadcast, + ).drop(columns=added_columns) return df diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 3b131541c..1169bc947 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -86,6 +86,56 @@ def test_join_left(c): assert_eq(return_df, expected_df, check_index=False) +@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) +def test_join_left_anti(c, gpu): + df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]}) + df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]}) + c.create_table("df_1", df1, gpu=gpu) + c.create_table("df_2", df2, gpu=gpu) + + return_df = c.sql( + """ + SELECT lhs.id, lhs.a + FROM df_1 AS lhs + LEFT ANTI JOIN df_2 AS rhs + ON lhs.id = rhs.id + """ + ) + expected_df = pd.DataFrame( + { + "id": [4], + "a": ["d"], + } + ) + + assert_eq(return_df, expected_df, check_index=False) + + +@pytest.mark.gpu +def test_join_left_semi(c): + df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]}) + df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]}) + c.create_table("df_1", df1, gpu=True) + c.create_table("df_2", df2, gpu=True) + + return_df = c.sql( + """ + SELECT lhs.id, lhs.a + FROM df_1 AS lhs + LEFT SEMI JOIN df_2 AS rhs + ON lhs.id = rhs.id + """ + ) + expected_df = pd.DataFrame( + { + "id": [1, 1, 2], + "a": ["a", "b", "c"], + } + ) + + assert_eq(return_df, expected_df, check_index=False) + + def test_join_right(c): return_df = c.sql( """ diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index b32e9530f..67120df82 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -36,7 +36,6 @@ 77, 80, 86, - 87, 88, 89, 92,