From c456202dfe4ac9ef9ffc45f3ce8ee0074eb5ffa8 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 06:22:30 -0800 Subject: [PATCH 1/8] Pin sklearn to <1.4 --- continuous_integration/docker/conda.txt | 2 +- continuous_integration/docker/main.dockerfile | 2 +- continuous_integration/environment-3.10.yaml | 2 +- continuous_integration/environment-3.11.yaml | 2 +- continuous_integration/environment-3.12.yaml | 2 +- continuous_integration/gpuci/environment-3.10.yaml | 2 +- continuous_integration/gpuci/environment-3.9.yaml | 2 +- pyproject.toml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/continuous_integration/docker/conda.txt b/continuous_integration/docker/conda.txt index 270c2febd..0b3378e27 100644 --- a/continuous_integration/docker/conda.txt +++ b/continuous_integration/docker/conda.txt @@ -16,7 +16,7 @@ uvicorn>=0.13.4 pyarrow>=6.0.2 prompt_toolkit>=3.0.8 pygments>=2.7.1 -scikit-learn>=1.0.0 +scikit-learn>=1.0,<1.4 intake>=0.6.0 pre-commit>=2.11.1 black=22.10.0 diff --git a/continuous_integration/docker/main.dockerfile b/continuous_integration/docker/main.dockerfile index 78cd46938..f6a2ade72 100644 --- a/continuous_integration/docker/main.dockerfile +++ b/continuous_integration/docker/main.dockerfile @@ -27,7 +27,7 @@ RUN mamba install -y \ tabulate \ # additional dependencies "pyarrow>=6.0.2" \ - "scikit-learn>=1.0.0" \ + "scikit-learn>=1.0,<1.4" \ "intake>=0.6.0" \ && conda clean -ay diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index b0557a915..f888fef74 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -26,7 +26,7 @@ dependencies: - pytest-xdist - pytest - python=3.10 -- scikit-learn>=1.0.0 +- scikit-learn>=1.0,<1.4 - sphinx - sqlalchemy<2 - tpot>=0.12.0 diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index 1bcf46d45..07c98f590 100644 --- a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -26,7 +26,7 @@ dependencies: - pytest-xdist - pytest - python=3.11 -- scikit-learn>=1.0.0 +- scikit-learn>=1.0,<1.4 - sphinx - sqlalchemy<2 - tpot>=0.12.0 diff --git a/continuous_integration/environment-3.12.yaml b/continuous_integration/environment-3.12.yaml index 18a67409b..288e6af9c 100644 --- a/continuous_integration/environment-3.12.yaml +++ b/continuous_integration/environment-3.12.yaml @@ -27,7 +27,7 @@ dependencies: - pytest-xdist - pytest - python=3.12 -- scikit-learn>=1.0.0 +- scikit-learn>=1.0,<1.4 - sphinx - sqlalchemy<2 - tpot>=0.12.0 diff --git a/continuous_integration/gpuci/environment-3.10.yaml b/continuous_integration/gpuci/environment-3.10.yaml index 2420e949f..74d520546 100644 --- a/continuous_integration/gpuci/environment-3.10.yaml +++ b/continuous_integration/gpuci/environment-3.10.yaml @@ -31,7 +31,7 @@ dependencies: - pytest-xdist - pytest - python=3.10 -- scikit-learn>=1.0.0 +- scikit-learn>=1.0,<1.4 - sphinx - sqlalchemy<2 - tpot>=0.12.0 diff --git a/continuous_integration/gpuci/environment-3.9.yaml b/continuous_integration/gpuci/environment-3.9.yaml index f88cf57c7..b09f6b292 100644 --- a/continuous_integration/gpuci/environment-3.9.yaml +++ b/continuous_integration/gpuci/environment-3.9.yaml @@ -31,7 +31,7 @@ dependencies: - pytest-xdist - pytest - python=3.9 -- scikit-learn>=1.0.0 +- scikit-learn>=1.0,<1.4 - sphinx - sqlalchemy<2 - tpot>=0.12.0 diff --git a/pyproject.toml b/pyproject.toml index 75ec4519f..978c058e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dev = [ "mock>=4.0.3", "sphinx>=3.2.1", "pyarrow>=6.0.2", - "scikit-learn>=1.0.0", + "scikit-learn>=1.0,<1.4", "intake>=0.6.0", "pre-commit", "black==22.10.0", From 0b85c0fd4de332417cf512ee5ff88c64d4262584 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 06:35:28 -0800 Subject: [PATCH 2/8] Unpin sqlalchemy<2 --- continuous_integration/environment-3.10.yaml | 2 +- continuous_integration/environment-3.11.yaml | 2 +- continuous_integration/environment-3.12.yaml | 2 +- continuous_integration/environment-3.9.yaml | 2 +- continuous_integration/gpuci/environment-3.10.yaml | 2 +- continuous_integration/gpuci/environment-3.9.yaml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index f888fef74..9667a06a7 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.10 - scikit-learn>=1.0,<1.4 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index 07c98f590..f8d7eb83d 100644 --- a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.11 - scikit-learn>=1.0,<1.4 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.12.yaml b/continuous_integration/environment-3.12.yaml index 288e6af9c..e3adbb97b 100644 --- a/continuous_integration/environment-3.12.yaml +++ b/continuous_integration/environment-3.12.yaml @@ -29,7 +29,7 @@ dependencies: - python=3.12 - scikit-learn>=1.0,<1.4 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index a627318c1..f9f8e9ebf 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.9 - scikit-learn=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/gpuci/environment-3.10.yaml b/continuous_integration/gpuci/environment-3.10.yaml index 74d520546..9ea5c4a00 100644 --- a/continuous_integration/gpuci/environment-3.10.yaml +++ b/continuous_integration/gpuci/environment-3.10.yaml @@ -33,7 +33,7 @@ dependencies: - python=3.10 - scikit-learn>=1.0,<1.4 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/gpuci/environment-3.9.yaml b/continuous_integration/gpuci/environment-3.9.yaml index b09f6b292..83b616f5c 100644 --- a/continuous_integration/gpuci/environment-3.9.yaml +++ b/continuous_integration/gpuci/environment-3.9.yaml @@ -33,7 +33,7 @@ dependencies: - python=3.9 - scikit-learn>=1.0,<1.4 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 From a8ec65dca8a1edb6f8b0ecd7d7e5f4648b947004 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 07:45:24 -0800 Subject: [PATCH 3/8] Refactor pyhive input/tests for sqlalchemy 2 --- dask_sql/input_utils/hive.py | 24 ++++++++++++++++------- tests/integration/test_hive.py | 36 +++++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py index 14bc547f0..b65e4d5ce 100644 --- a/dask_sql/input_utils/hive.py +++ b/dask_sql/input_utils/hive.py @@ -30,12 +30,12 @@ class HiveInputPlugin(BaseInputPlugin): def is_correct_input( self, input_item: Any, table_name: str, format: str = None, **kwargs ): - is_sqlalchemy_hive = sqlalchemy and isinstance( - input_item, sqlalchemy.engine.base.Connection - ) is_hive_cursor = hive and isinstance(input_item, hive.Cursor) - return is_sqlalchemy_hive or is_hive_cursor or format == "hive" + return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == "hive" + + def is_sqlalchemy_hive(self, input_item: Any): + return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection) def to_dc( self, @@ -201,7 +201,11 @@ def _parse_hive_table_description( of the DESCRIBE FORMATTED call, which is unfortunately in a format not easily readable by machines. """ - cursor.execute(f"USE {schema}") + cursor.execute( + sqlalchemy.text(f"USE {schema}") + if self.is_sqlalchemy_hive(cursor) + else f"USE {schema}" + ) if partition: # Hive wants quoted, comma separated list of partition keys partition = partition.replace("=", '="') @@ -283,7 +287,11 @@ def _parse_hive_partition_description( """ Extract all partition informaton for a given table """ - cursor.execute(f"USE {schema}") + cursor.execute( + sqlalchemy.text(f"USE {schema}") + if self.is_sqlalchemy_hive(cursor) + else f"USE {schema}" + ) result = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}") return [row[0] for row in result] @@ -298,7 +306,9 @@ def _fetch_all_results( The former has the fetchall method on the cursor, whereas the latter on the executed query. """ - result = cursor.execute(sql) + result = cursor.execute( + sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql + ) try: return result.fetchall() diff --git a/tests/integration/test_hive.py b/tests/integration/test_hive.py index 1a86082c1..17f4c1a98 100644 --- a/tests/integration/test_hive.py +++ b/tests/integration/test_hive.py @@ -142,25 +142,43 @@ def hive_cursor(): # Create a non-partitioned column cursor.execute( - f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'" + sqlalchemy.text( + f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'" + ) ) - cursor.execute("INSERT INTO df (i, j) VALUES (1, 2)") - cursor.execute("INSERT INTO df (i, j) VALUES (2, 4)") + cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)")) + cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)")) cursor.execute( - f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'" + sqlalchemy.text( + f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'" + ) + ) + cursor.execute( + sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)") + ) + cursor.execute( + sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)") ) - cursor.execute("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)") - cursor.execute("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)") cursor.execute( - f""" + sqlalchemy.text( + f""" CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}' """ + ) + ) + cursor.execute( + sqlalchemy.text( + "INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)" + ) + ) + cursor.execute( + sqlalchemy.text( + "INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)" + ) ) - cursor.execute("INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)") - cursor.execute("INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)") # The data files are created as root user by default. Change that: hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir]) From ba6bee35b4793dfa22d2310701e3565e3e167c97 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 08:11:29 -0800 Subject: [PATCH 4/8] Use astype to normalize dtypes in _assert_query_gives_same_result --- tests/integration/fixtures.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 90b6f3828..d9ee2c858 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -335,6 +335,10 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): # as expressions are handled differently dask_result.columns = sql_result.columns + # nullable dtypes make it such that `check_dtype=False` isn't sufficient to + # normalize Postgres & Dask results + dask_result = dask_result.astype(sql_result.dtypes.to_dict()) + if sort_columns: sql_result = sql_result.sort_values(sort_columns) dask_result = dask_result.sort_values(sort_columns) @@ -342,7 +346,9 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): sql_result = sql_result.reset_index(drop=True) dask_result = dask_result.reset_index(drop=True) - assert_eq(sql_result, dask_result, check_dtype=False, **kwargs) + dask_result = dask_result.astype(sql_result.dtypes.to_dict()) + + assert_eq(sql_result, dask_result, **kwargs) return _assert_query_gives_same_result From 3f96a5c95c3f09f956d2ac6e9053cd638482d0c0 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 08:50:01 -0800 Subject: [PATCH 5/8] Refine pd.NA normalization in _assert_query_gives_same_result --- tests/integration/fixtures.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index d9ee2c858..cd4e38928 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -335,9 +335,9 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): # as expressions are handled differently dask_result.columns = sql_result.columns - # nullable dtypes make it such that `check_dtype=False` isn't sufficient to - # normalize Postgres & Dask results - dask_result = dask_result.astype(sql_result.dtypes.to_dict()) + # replace all pd.NA scalars, which are resistent to + # check_dype=False and .astype() + dask_result = dask_result.replace({pd.NA: None}) if sort_columns: sql_result = sql_result.sort_values(sort_columns) @@ -346,9 +346,7 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): sql_result = sql_result.reset_index(drop=True) dask_result = dask_result.reset_index(drop=True) - dask_result = dask_result.astype(sql_result.dtypes.to_dict()) - - assert_eq(sql_result, dask_result, **kwargs) + assert_eq(sql_result, dask_result, check_dtype=False, **kwargs) return _assert_query_gives_same_result From bf770692408437fdcbefc6ce38ea6a5988fc4076 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 11:05:56 -0800 Subject: [PATCH 6/8] Explicitly compute pandas result in test_join_reorder --- tests/integration/test_join.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 3f19a3211..e47721108 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -463,7 +463,6 @@ def test_join_reorder(c): SELECT a1, b2, c3 FROM a, b, c WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2 - LIMIT 10 """ explain_string = c.explain(query) @@ -491,15 +490,20 @@ def test_join_reorder(c): assert explain_string.index(second_join) < explain_string.index(first_join) result_df = c.sql(query) - expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4] * 10}) - assert_eq(result_df, expected_df) + merged_df = df.merge(df2, left_on="a1", right_on="b1").merge( + df3, left_on="b2", right_on="c2" + ) + expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][ + ["a1", "b2", "c3"] + ] + + assert_eq(result_df, expected_df, check_index=False) # By default, join reordering should NOT reorder unfiltered dimension tables query = """ SELECT a1, b2, c3 FROM a, b, c WHERE a1 = b1 AND b2 = c2 - LIMIT 10 """ explain_string = c.explain(query) @@ -510,8 +514,11 @@ def test_join_reorder(c): assert explain_string.index(second_join) < explain_string.index(first_join) result_df = c.sql(query) - expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4, 5] * 5}) - assert_eq(result_df, expected_df) + expected_df = df.merge(df2, left_on="a1", right_on="b1").merge( + df3, left_on="b2", right_on="c2" + )[["a1", "b2", "c3"]] + + assert_eq(result_df, expected_df, check_index=False) @pytest.mark.xfail( From d3c2e5d4caeac8ae7f544145474f644c16155cb7 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 12:04:01 -0800 Subject: [PATCH 7/8] xfail tpot tests, unpin sklearn --- continuous_integration/docker/conda.txt | 2 +- continuous_integration/docker/main.dockerfile | 2 +- continuous_integration/environment-3.10.yaml | 2 +- continuous_integration/environment-3.11.yaml | 2 +- continuous_integration/environment-3.12.yaml | 2 +- continuous_integration/gpuci/environment-3.10.yaml | 2 +- continuous_integration/gpuci/environment-3.9.yaml | 2 +- pyproject.toml | 2 +- tests/integration/test_model.py | 10 +++++++++- 9 files changed, 17 insertions(+), 9 deletions(-) diff --git a/continuous_integration/docker/conda.txt b/continuous_integration/docker/conda.txt index 0b3378e27..270c2febd 100644 --- a/continuous_integration/docker/conda.txt +++ b/continuous_integration/docker/conda.txt @@ -16,7 +16,7 @@ uvicorn>=0.13.4 pyarrow>=6.0.2 prompt_toolkit>=3.0.8 pygments>=2.7.1 -scikit-learn>=1.0,<1.4 +scikit-learn>=1.0.0 intake>=0.6.0 pre-commit>=2.11.1 black=22.10.0 diff --git a/continuous_integration/docker/main.dockerfile b/continuous_integration/docker/main.dockerfile index f6a2ade72..78cd46938 100644 --- a/continuous_integration/docker/main.dockerfile +++ b/continuous_integration/docker/main.dockerfile @@ -27,7 +27,7 @@ RUN mamba install -y \ tabulate \ # additional dependencies "pyarrow>=6.0.2" \ - "scikit-learn>=1.0,<1.4" \ + "scikit-learn>=1.0.0" \ "intake>=0.6.0" \ && conda clean -ay diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index 9667a06a7..912e2c54e 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -26,7 +26,7 @@ dependencies: - pytest-xdist - pytest - python=3.10 -- scikit-learn>=1.0,<1.4 +- scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index f8d7eb83d..cd77ac8d5 100644 --- a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -26,7 +26,7 @@ dependencies: - pytest-xdist - pytest - python=3.11 -- scikit-learn>=1.0,<1.4 +- scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 diff --git a/continuous_integration/environment-3.12.yaml b/continuous_integration/environment-3.12.yaml index e3adbb97b..53b52e629 100644 --- a/continuous_integration/environment-3.12.yaml +++ b/continuous_integration/environment-3.12.yaml @@ -27,7 +27,7 @@ dependencies: - pytest-xdist - pytest - python=3.12 -- scikit-learn>=1.0,<1.4 +- scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 diff --git a/continuous_integration/gpuci/environment-3.10.yaml b/continuous_integration/gpuci/environment-3.10.yaml index 9ea5c4a00..6d567d498 100644 --- a/continuous_integration/gpuci/environment-3.10.yaml +++ b/continuous_integration/gpuci/environment-3.10.yaml @@ -31,7 +31,7 @@ dependencies: - pytest-xdist - pytest - python=3.10 -- scikit-learn>=1.0,<1.4 +- scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 diff --git a/continuous_integration/gpuci/environment-3.9.yaml b/continuous_integration/gpuci/environment-3.9.yaml index 83b616f5c..1e2c50efb 100644 --- a/continuous_integration/gpuci/environment-3.9.yaml +++ b/continuous_integration/gpuci/environment-3.9.yaml @@ -31,7 +31,7 @@ dependencies: - pytest-xdist - pytest - python=3.9 -- scikit-learn>=1.0,<1.4 +- scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 diff --git a/pyproject.toml b/pyproject.toml index 978c058e2..75ec4519f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dev = [ "mock>=4.0.3", "sphinx>=3.2.1", "pyarrow>=6.0.2", - "scikit-learn>=1.0,<1.4", + "scikit-learn>=1.0.0", "intake>=0.6.0", "pre-commit", "black==22.10.0", diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 973802fe4..e2a609797 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -6,6 +6,9 @@ import pandas as pd import pytest +from packaging.version import parse as parseVersion + + from tests.utils import assert_eq try: @@ -17,6 +20,10 @@ xgboost = None dask_cudf = None +sklearn = pytest.importorskip("sklearn") + +SKLEARN_GT_130 = parseVersion(sklearn.__version__) >= parseVersion("1.4") + def check_trained_model(c, model_name="my_model", df_name="timeseries"): sql = f""" @@ -902,10 +909,10 @@ def test_ml_experiment(c, client): ) +@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130) def test_experiment_automl_classifier(c, client): tpot = pytest.importorskip("tpot", reason="tpot not installed") - # currently tested with tpot== c.sql( """ CREATE EXPERIMENT my_automl_exp1 WITH ( @@ -927,6 +934,7 @@ def test_experiment_automl_classifier(c, client): check_trained_model(c, "my_automl_exp1") +@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130) def test_experiment_automl_regressor(c, client): tpot = pytest.importorskip("tpot", reason="tpot not installed") From d0df6c51041fba34b60d1bfd4aabc7588c2b85e8 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 22 Jan 2024 12:04:20 -0800 Subject: [PATCH 8/8] Linting --- tests/integration/test_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index e2a609797..c341965ce 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -5,10 +5,8 @@ import joblib import pandas as pd import pytest - from packaging.version import parse as parseVersion - from tests.utils import assert_eq try: