Fix BanditClassifier (online-ml#1262)
MaxHalford committed Jun 8, 2023
1 parent f978972 commit afac9ab
Showing 3 changed files with 98 additions and 13 deletions.
22 changes: 10 additions & 12 deletions river/bandit/thompson.py
@@ -29,7 +29,7 @@ class ThompsonSampling(bandit.base.Policy):
     Parameters
     ----------
-    dist
+    reward_obj
         A distribution to sample from.
     burn_in
         The number of steps to use for the burn-in phase. Each arm is given the chance to be pulled
@@ -51,7 +51,7 @@ class ThompsonSampling(bandit.base.Policy):
     >>> _ = env.reset(seed=42)
     >>> _ = env.action_space.seed(123)
-    >>> policy = bandit.ThompsonSampling(dist=proba.Beta(), seed=101)
+    >>> policy = bandit.ThompsonSampling(reward_obj=proba.Beta(), seed=101)
     >>> metric = stats.Sum()
     >>> while True:
@@ -71,22 +71,20 @@ class ThompsonSampling(bandit.base.Policy):
     """

-    def __init__(self, dist: proba.base.Distribution, burn_in=0, seed: int | None = None):
-        super().__init__(reward_obj=dist, burn_in=burn_in)
+    def __init__(
+        self, reward_obj: proba.base.Distribution = None, burn_in=0, seed: int | None = None
+    ):
+        super().__init__(reward_obj=reward_obj, burn_in=burn_in)
         self.seed = seed
         self._rng = random.Random(seed)
-        self._rewards.default_factory = self._clone_dist_with_seed
+        self._rewards.default_factory = self._clone_reward_obj_with_seed

-    def _clone_dist_with_seed(self):
-        return self.dist.clone({"seed": self._rng.randint(0, 2**32)})
-
-    @property
-    def dist(self):
-        return self.reward_obj
+    def _clone_reward_obj_with_seed(self):
+        return self.reward_obj.clone({"seed": self._rng.randint(0, 2**32)})

     def _pull(self, arm_ids):
         return max(arm_ids, key=lambda arm_id: self._rewards[arm_id].sample())

     @classmethod
     def _unit_test_params(cls):
-        yield {"dist": proba.Beta()}
+        yield {"reward_obj": proba.Beta()}
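The rename is a breaking change for callers: anything that passed dist= to ThompsonSampling must now pass reward_obj=. A minimal sketch of the new call, assuming the pull/update interface that river's bandit policies expose and that the class's doctest uses; the three-armed setup and the reward value here are illustrative, not part of the commit:

    from river import bandit, proba

    # After this commit the distribution is passed as `reward_obj` (was `dist`).
    policy = bandit.ThompsonSampling(reward_obj=proba.Beta(), seed=101)

    # One bandit step: pick the arm whose sampled reward estimate is highest,
    # then feed the observed reward back into that arm's Beta posterior.
    arm = policy.pull(range(3))
    policy = policy.update(arm, 1)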
2 changes: 1 addition & 1 deletion river/model_selection/bandit.py
@@ -206,7 +206,7 @@ def learn_one(self, x, y):
         y_pred = (
             model.predict_one(x) if self.metric.requires_labels else model.predict_proba_one(x)
         )
-        self.policy.update(arm_id, y_true=y, y_pred=y_pred)
+        self.policy.update(arm_id, y, y_pred)
         model.learn_one(x, y)

         return self
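The call-site fix drops the keyword names. A plausible reading, inferred rather than stated in the commit, is that the policy forwards its reward arguments to a per-arm reward object, and positional forwarding works no matter how that object names its update() parameters. An illustrative sketch of that pattern, not river's actual Policy base class:

    # Illustrative sketch only: a policy that forwards reward arguments
    # positionally never has to agree with the reward object's parameter names.
    class TinyPolicy:
        def __init__(self, reward_factory):
            self._factory = reward_factory
            self._rewards = {}  # one reward object per arm

        def update(self, arm_id, *reward_args):
            reward = self._rewards.setdefault(arm_id, self._factory())
            reward.update(*reward_args)  # e.g. Accuracy.update(y_true, y_pred)
            return self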
87 changes: 87 additions & 0 deletions river/model_selection/test_bandit.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import importlib
+import inspect
+
+import pytest
+
+from river import (
+    bandit,
+    datasets,
+    evaluate,
+    linear_model,
+    metrics,
+    model_selection,
+    optim,
+    preprocessing,
+)
+
+
+def test_1259():
+    """
+    https://github.com/online-ml/river/issues/1259
+
+    >>> from river import bandit
+    >>> from river import datasets
+    >>> from river import evaluate
+    >>> from river import linear_model
+    >>> from river import metrics
+    >>> from river import model_selection
+    >>> from river import optim
+    >>> from river import preprocessing
+
+    >>> models = [
+    ...     linear_model.LogisticRegression(optimizer=optim.SGD(lr=lr))
+    ...     for lr in [0.0001, 0.001, 1e-05, 0.01]
+    ... ]
+
+    >>> dataset = datasets.Phishing()
+    >>> model = (
+    ...     preprocessing.StandardScaler() |
+    ...     model_selection.BanditClassifier(
+    ...         models,
+    ...         metric=metrics.Accuracy(),
+    ...         policy=bandit.Exp3(
+    ...             gamma=0.5,
+    ...             seed=42
+    ...         )
+    ...     )
+    ... )
+    >>> metric = metrics.Accuracy()
+
+    >>> evaluate.progressive_val_score(dataset, model, metric)
+    Accuracy: 87.20%
+
+    """
+
+
+@pytest.mark.parametrize(
+    "policy",
+    [
+        pytest.param(
+            policy(**params),
+            id=f"{policy.__name__}",
+        )
+        for _, policy in inspect.getmembers(
+            importlib.import_module("river.bandit"),
+            lambda obj: inspect.isclass(obj) and issubclass(obj, bandit.base.Policy),
+        )
+        for params in policy._unit_test_params()
+        if policy.__name__ != "ThompsonSampling"
+    ],
+)
+def test_bandit_classifier_with_each_policy(policy):
+    models = [
+        linear_model.LogisticRegression(optimizer=optim.SGD(lr=lr))
+        for lr in [0.0001, 0.001, 1e-05, 0.01]
+    ]
+
+    dataset = datasets.Phishing()
+    model = preprocessing.StandardScaler() | model_selection.BanditClassifier(
+        models, metric=metrics.Accuracy(), policy=policy
+    )
+    metric = metrics.Accuracy()
+
+    score = evaluate.progressive_val_score(dataset, model, metric)
+    assert score.get() > 0.5

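The parametrized test deliberately skips ThompsonSampling. A plausible reason, inferred from the code rather than stated in the commit, is that its reward object is a probability distribution whose update consumes a single observation, while BanditClassifier feeds each policy two positional arguments (y, y_pred):

    from river import proba

    beta = proba.Beta()
    beta.update(True)  # a distribution consumes one observed value at a time
    # beta.update(y, y_pred) would not fit this one-argument interface, which
    # is a likely reason the metric-driven test leaves ThompsonSampling out.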