Skip to content

Commit

Permalink
Update test_product.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
MaxHalford committed Jun 5, 2023
1 parent 2444d4c commit 2bad2b5
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions river/compose/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def test_issue_1253():
>>> import numpy as np
>>> import pandas as pd
- >>> from river import compose, preprocessing
- >>> from sklearn import datasets
+ >>> from river import compat, compose, preprocessing
+ >>> from sklearn import datasets, linear_model
>>> np.random.seed(1000)
>>> X, y = datasets.make_regression(n_samples=5_000, n_features=2)
Expand All @@ -86,6 +86,32 @@ def test_issue_1253():
>>> XT.sparse.to_dense().memory_usage().sum() // 1000
4455
>>> X, y = datasets.make_regression(n_samples=6, n_features=2)
>>> X = pd.DataFrame(X)
>>> X.columns = ['feat_1','feat_2']
>>> X['cat'] = np.random.randint(1, 3, X.shape[0])
>>> y = pd.Series(y)
>>> group1 = compose.Select('cat') | preprocessing.OneHotEncoder()
>>> group2 = compose.Select('feat_2') | preprocessing.StandardScaler()
>>> sparsify = lambda X: X.astype({
... key: pd.SparseDtype(X.dtypes[key].type, fill_value=0)
... for key in X.dtypes.keys()
... })
>>> model = (
... (group1 + group1 * group2) |
... compose.FuncTransformer(sparsify) |
... compat.convert_sklearn_to_river(linear_model.SGDRegressor(max_iter=3))
... )
>>> _ = model.predict_many(X)
>>> model.transform_many(X)
cat_1*feat_2 cat_2*feat_2 cat_1 cat_2
0 -1.196841 0.000000 1 0
1 1.304619 0.000000 1 0
2 -1.294091 0.000000 1 0
3 0.287426 0.000000 1 0
4 -0.143960 0.000000 1 0
5 0.000000 1.042847 0 1
"""


Expand Down

0 comments on commit 2bad2b5

Please sign in to comment.