Skip to content

Commit

Permalink
more bandit tests
Browse files · Browse the repository at this point in the history
  • Loading branch information
MaxHalford committed Jun 8, 2023
1 parent f978972 commit 2090e40
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 3 deletions.
6 changes: 4 additions & 2 deletions river/bandit/exp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def update(self, arm_id, *reward_args, **reward_kwargs):

@classmethod
def _unit_test_params(cls):
    # Hyper-parameter configurations exercised by River's generic unit-test
    # harness: each yielded dict is unpacked into the policy constructor.
    # The values cover the full valid range of ``gamma`` ([0, 1]), including
    # both int and float spellings of the endpoints.
    # NOTE(review): this span was scraped from a diff, so it may interleave
    # added and removed lines — confirm the final set against the repository.
    yield {"gamma": 0.0}
    yield {"gamma": 0}
    yield {"gamma": 0.1}
    yield {"gamma": 0.5}
    yield {"gamma": 1.0}
    yield {"gamma": 0.9}
    yield {"gamma": 1}
2 changes: 1 addition & 1 deletion river/model_selection/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def learn_one(self, x, y):
for arm_id in self._pick_arms():
model = self[arm_id]
y_pred = model.predict_one(x)
self.policy.update(arm_id, y_true=y, y_pred=y_pred)
self.policy.update(arm_id, y, y_pred)
model.learn_one(x, y)

return self
Expand Down
118 changes: 118 additions & 0 deletions river/model_selection/test_bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import importlib
import inspect

import pytest

from river import (
bandit,
datasets,
evaluate,
linear_model,
metrics,
model_selection,
optim,
preprocessing,
)


def test_1259():
    """Non-regression test for issue #1259.

    https://github.com/online-ml/river/issues/1259

    The doctest below builds a ``BanditClassifier`` driven by an ``Exp3``
    policy and checks that a full progressive validation run completes with
    the expected accuracy.

    NOTE(review): the function body is only this docstring, so the check
    runs via doctest collection — presumably pytest is configured to collect
    doctests in this project; confirm.

    >>> from river import bandit
    >>> from river import datasets
    >>> from river import evaluate
    >>> from river import linear_model
    >>> from river import metrics
    >>> from river import model_selection
    >>> from river import optim
    >>> from river import preprocessing

    >>> models = [
    ...     linear_model.LogisticRegression(optimizer=optim.SGD(lr=lr))
    ...     for lr in [0.0001, 0.001, 1e-05, 0.01]
    ... ]

    >>> dataset = datasets.Phishing()
    >>> model = (
    ...     preprocessing.StandardScaler() |
    ...     model_selection.BanditClassifier(
    ...         models,
    ...         metric=metrics.Accuracy(),
    ...         policy=bandit.Exp3(
    ...             gamma=0.5,
    ...             seed=42
    ...         )
    ...     )
    ... )
    >>> metric = metrics.Accuracy()

    >>> evaluate.progressive_val_score(dataset, model, metric)
    Accuracy: 87.20%

    """


@pytest.mark.parametrize(
    "policy",
    [
        pytest.param(
            policy(**params),
            # Include the params in the ID: a policy's _unit_test_params()
            # may yield several configurations, and a bare class name would
            # produce duplicate pytest IDs that pytest disambiguates with
            # opaque numeric suffixes.
            id=f"{policy.__name__}({params})",
        )
        for _, policy in inspect.getmembers(
            importlib.import_module("river.bandit"),
            lambda obj: inspect.isclass(obj) and issubclass(obj, bandit.base.Policy),
        )
        for params in policy._unit_test_params()
        # NOTE(review): ThompsonSampling is excluded — presumably its
        # constructor needs arguments the generic params don't provide;
        # confirm before extending.
        if policy.__name__ != "ThompsonSampling"
    ],
)
def test_bandit_classifier_with_each_policy(policy):
    """Every bandit policy should be usable for classifier selection.

    Runs a progressive validation over the Phishing dataset with a
    BanditClassifier choosing between logistic regressions, and checks the
    resulting accuracy beats random guessing (0.5).
    """
    models = [
        linear_model.LogisticRegression(optimizer=optim.SGD(lr=lr))
        for lr in [0.0001, 0.001, 1e-05, 0.01]
    ]

    dataset = datasets.Phishing()
    model = preprocessing.StandardScaler() | model_selection.BanditClassifier(
        models, metric=metrics.Accuracy(), policy=policy
    )
    metric = metrics.Accuracy()

    score = evaluate.progressive_val_score(dataset, model, metric)
    assert score.get() > 0.5


@pytest.mark.parametrize(
    "policy",
    [
        pytest.param(policy_cls(**kwargs), id=policy_cls.__name__)
        for _, policy_cls in inspect.getmembers(
            importlib.import_module("river.bandit"),
            lambda obj: inspect.isclass(obj) and issubclass(obj, bandit.base.Policy),
        )
        # The filter only depends on the class, so it can sit between the
        # two for-clauses without changing which params are generated.
        if policy_cls.__name__ not in {"ThompsonSampling", "Exp3"}
        for kwargs in policy_cls._unit_test_params()
    ],
)
def test_bandit_regressor_with_each_policy(policy):
    """Every applicable bandit policy should drive regressor selection.

    Progressively validates a BanditRegressor choosing between linear
    regressions on the TrumpApproval dataset and checks the MSE stays
    below a loose sanity bound.
    """
    learning_rates = [0.0001, 0.001, 1e-05, 0.01]
    models = [
        linear_model.LinearRegression(optimizer=optim.SGD(lr=eta))
        for eta in learning_rates
    ]

    pipeline = preprocessing.StandardScaler() | model_selection.BanditRegressor(
        models, metric=metrics.MSE(), policy=policy
    )

    result = evaluate.progressive_val_score(
        datasets.TrumpApproval(), pipeline, metrics.MSE()
    )
    assert result.get() < 300

0 comments on commit 2090e40

Please sign in to comment.