Unblock CI failures from scikit-learn 1.4.0, pandas 2.2.0 #1295

Merged · 8 commits · Jan 22, 2024
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.10.yaml
@@ -28,7 +28,7 @@ dependencies:
   - python=3.10
   - scikit-learn>=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.11.yaml
@@ -28,7 +28,7 @@ dependencies:
   - python=3.11
   - scikit-learn>=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.12.yaml
@@ -29,7 +29,7 @@ dependencies:
   - python=3.12
   - scikit-learn>=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
@@ -28,7 +28,7 @@ dependencies:
   - python=3.9
   - scikit-learn=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/environment-3.10.yaml
@@ -33,7 +33,7 @@ dependencies:
   - python=3.10
   - scikit-learn>=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/environment-3.9.yaml
@@ -33,7 +33,7 @@ dependencies:
   - python=3.9
   - scikit-learn>=1.0.0
   - sphinx
-  - sqlalchemy<2
+  - sqlalchemy
   - tpot>=0.12.0
   # FIXME: https://github.com/fugue-project/fugue/issues/526
   - triad<0.9.2
24 changes: 17 additions & 7 deletions dask_sql/input_utils/hive.py
@@ -30,12 +30,12 @@ class HiveInputPlugin(BaseInputPlugin):
     def is_correct_input(
         self, input_item: Any, table_name: str, format: str = None, **kwargs
    ):
-        is_sqlalchemy_hive = sqlalchemy and isinstance(
-            input_item, sqlalchemy.engine.base.Connection
-        )
         is_hive_cursor = hive and isinstance(input_item, hive.Cursor)
 
-        return is_sqlalchemy_hive or is_hive_cursor or format == "hive"
+        return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == "hive"
+
+    def is_sqlalchemy_hive(self, input_item: Any):
+        return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection)
 
     def to_dc(
         self,
@@ -201,7 +201,11 @@ def _parse_hive_table_description(
         of the DESCRIBE FORMATTED call, which is unfortunately
         in a format not easily readable by machines.
         """
-        cursor.execute(f"USE {schema}")
+        cursor.execute(
+            sqlalchemy.text(f"USE {schema}")
+            if self.is_sqlalchemy_hive(cursor)
+            else f"USE {schema}"
+        )
         if partition:
             # Hive wants quoted, comma separated list of partition keys
             partition = partition.replace("=", '="')
@@ -283,7 +287,11 @@ def _parse_hive_partition_description(
         """
         Extract all partition information for a given table
         """
-        cursor.execute(f"USE {schema}")
+        cursor.execute(
+            sqlalchemy.text(f"USE {schema}")
+            if self.is_sqlalchemy_hive(cursor)
+            else f"USE {schema}"
+        )
         result = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}")
 
         return [row[0] for row in result]
@@ -298,7 +306,9 @@ def _fetch_all_results(
         The former has the fetchall method on the cursor,
         whereas the latter on the executed query.
         """
-        result = cursor.execute(sql)
+        result = cursor.execute(
+            sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql
+        )
 
         try:
             return result.fetchall()
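The `sqlalchemy.text()` branching above tracks a breaking change in SQLAlchemy 2.0: `Connection.execute()` no longer accepts a raw SQL string, while a PyHive cursor still does. A minimal sketch of that behavior, using an in-memory SQLite engine as a stand-in for the Hive connection:

    import sqlalchemy

    engine = sqlalchemy.create_engine("sqlite:https://")  # stand-in for the Hive engine
    with engine.connect() as conn:
        # conn.execute("SELECT 1") raises ObjectNotExecutableError on SQLAlchemy 2.x;
        # textual SQL must be wrapped explicitly (works on both 1.4 and 2.x):
        result = conn.execute(sqlalchemy.text("SELECT 1"))
        print(result.fetchall())  # [(1,)]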
4 changes: 4 additions & 0 deletions tests/integration/fixtures.py
@@ -335,6 +335,10 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
         # as expressions are handled differently
         dask_result.columns = sql_result.columns
 
+        # replace all pd.NA scalars, which are resistant to
+        # check_dtype=False and .astype()
+        dask_result = dask_result.replace({pd.NA: None})
+
         if sort_columns:
             sql_result = sql_result.sort_values(sort_columns)
             dask_result = dask_result.sort_values(sort_columns)
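The new `replace` call works around pandas nullable dtypes surfacing missing values as `pd.NA`, which (as the comment notes) slips past both `check_dtype=False` and an `.astype()` round-trip. A small self-contained sketch of the normalization, with toy frames standing in for the fixture's real SQL and dask-sql results:

    import pandas as pd

    # A nullable-dtype column keeps its pd.NA scalars even through .astype(object)
    dask_result = pd.DataFrame({"x": pd.array([1, None], dtype="Int64")}).astype(object)
    sql_result = pd.DataFrame({"x": [1, None]}, dtype=object)

    # Normalizing pd.NA to None, as the fixture now does, makes the frames comparable
    dask_result = dask_result.replace({pd.NA: None})
    pd.testing.assert_frame_equal(dask_result, sql_result, check_dtype=False)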
36 changes: 27 additions & 9 deletions tests/integration/test_hive.py
@@ -142,25 +142,43 @@ def hive_cursor():
 
     # Create a non-partitioned table
     cursor.execute(
-        f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
+        sqlalchemy.text(
+            f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
+        )
     )
-    cursor.execute("INSERT INTO df (i, j) VALUES (1, 2)")
-    cursor.execute("INSERT INTO df (i, j) VALUES (2, 4)")
+    cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)"))
+    cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)"))
 
     cursor.execute(
-        f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
+        sqlalchemy.text(
+            f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
+    )
+    cursor.execute(
+        sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")
     )
-    cursor.execute("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
-    cursor.execute("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")
 
     cursor.execute(
-        f"""
+        sqlalchemy.text(
+            f"""
         CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING)
         ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}'
     """
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text(
+            "INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)"
+        )
+    )
+    cursor.execute(
+        sqlalchemy.text(
+            "INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)"
+        )
     )
-    cursor.execute("INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)")
-    cursor.execute("INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)")
 
     # The data files are created as root user by default. Change that:
     hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir])
19 changes: 13 additions & 6 deletions tests/integration/test_join.py
@@ -463,7 +463,6 @@ def test_join_reorder(c):
     SELECT a1, b2, c3
     FROM a, b, c
     WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2
-    LIMIT 10
     """
 
     explain_string = c.explain(query)
@@ -491,15 +490,20 @@ def test_join_reorder(c):
     assert explain_string.index(second_join) < explain_string.index(first_join)
 
     result_df = c.sql(query)
-    expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4] * 10})
-    assert_eq(result_df, expected_df)
+    merged_df = df.merge(df2, left_on="a1", right_on="b1").merge(
+        df3, left_on="b2", right_on="c2"
+    )
+    expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][
+        ["a1", "b2", "c3"]
+    ]
+
+    assert_eq(result_df, expected_df, check_index=False)
 
     # By default, join reordering should NOT reorder unfiltered dimension tables
     query = """
     SELECT a1, b2, c3
     FROM a, b, c
     WHERE a1 = b1 AND b2 = c2
-    LIMIT 10
     """
 
     explain_string = c.explain(query)
@@ -510,8 +514,11 @@
     assert explain_string.index(second_join) < explain_string.index(first_join)
 
     result_df = c.sql(query)
-    expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4, 5] * 5})
-    assert_eq(result_df, expected_df)
+    expected_df = df.merge(df2, left_on="a1", right_on="b1").merge(
+        df3, left_on="b2", right_on="c2"
+    )[["a1", "b2", "c3"]]
+
+    assert_eq(result_df, expected_df, check_index=False)
 
 
 @pytest.mark.xfail(
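With `LIMIT 10` gone (without an `ORDER BY` its result set was not deterministic anyway), the expected frames are now derived from the same joins and filters in pandas rather than hard-coded. A self-contained sketch of that pattern with toy data; the real test uses the suite's fixture tables `df`, `df2`, and `df3`:

    import pandas as pd

    # Toy stand-ins for the fixture tables
    df = pd.DataFrame({"a1": [1, 2, 3]})
    df2 = pd.DataFrame({"b1": [1, 2, 3], "b2": [2, 2, 2]})
    df3 = pd.DataFrame({"c2": [2, 2], "c3": [4, 5]})

    # Mirror the SQL: join a/b on a1 = b1, join c on b2 = c2, then apply the WHERE filters
    merged = df.merge(df2, left_on="a1", right_on="b1").merge(
        df3, left_on="b2", right_on="c2"
    )
    expected = merged[(merged["b1"] < 3) & (merged["c3"] < 5)][["a1", "b2", "c3"]]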
8 changes: 7 additions & 1 deletion tests/integration/test_model.py
@@ -5,6 +5,7 @@
 import joblib
 import pandas as pd
 import pytest
+from packaging.version import parse as parseVersion
 
 from tests.utils import assert_eq
 
@@ -17,6 +18,10 @@
     xgboost = None
     dask_cudf = None
 
+sklearn = pytest.importorskip("sklearn")
+
+SKLEARN_GT_130 = parseVersion(sklearn.__version__) >= parseVersion("1.4")
+
 
 def check_trained_model(c, model_name="my_model", df_name="timeseries"):
     sql = f"""
@@ -902,10 +907,10 @@ def test_ml_experiment(c, client):
     )
 
 
+@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130)
 def test_experiment_automl_classifier(c, client):
     tpot = pytest.importorskip("tpot", reason="tpot not installed")
 
-    # currently tested with tpot==
     c.sql(
         """
         CREATE EXPERIMENT my_automl_exp1 WITH (
@@ -927,6 +932,7 @@ def test_experiment_automl_classifier(c, client):
     check_trained_model(c, "my_automl_exp1")
 
 
+@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130)
 def test_experiment_automl_regressor(c, client):
     tpot = pytest.importorskip("tpot", reason="tpot not installed")
 
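`SKLEARN_GT_130` gates the tpot tests with a parsed-version comparison instead of comparing version strings directly, which order lexicographically and get multi-digit components wrong. A quick illustration:

    from packaging.version import parse as parseVersion

    assert "1.10" < "1.9"                              # string comparison: wrong order
    assert parseVersion("1.10") > parseVersion("1.9")  # parsed versions: correct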