-
Notifications
You must be signed in to change notification settings - Fork 0
/
ucb.py
57 lines (47 loc) · 2.06 KB
/
ucb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import math
from typing import Tuple
import numpy as np
from numpy.random import Generator
from delayed_bandit.policies.policy import Policy
class UCB(Policy):
    def __init__(self, num_arms: int, alpha: float, rng: Generator):
        """
        Create an Upper Confidence Bound policy.

        Based on Hoeffding's inequality to build per-arm confidence
        intervals; ``alpha`` controls the exploration-exploitation
        trade-off (larger alpha means wider radii, i.e. more exploration).

        :param num_arms: number of arms of the bandit
        :param alpha: exploration coefficient inside the confidence radius
        :param rng: randomness source used only for tie-breaking
        """
        self._num_arms = num_arms
        self._alpha = alpha
        self._rng = rng
        # Arm returned by the most recent choice(); -1 before the first call.
        self._current_arm = -1
        # Per-arm sum of observed rewards.
        self.cumulative_rewards = np.zeros(num_arms, dtype=np.float32)
        # Per-arm play (reward-feedback) counts.
        self.arms_stats = np.zeros(num_arms, dtype=np.int32)

    def choice(self, t: int) -> int:
        """
        Choose an arm to play at (0-based) round t.

        The first ``num_arms`` rounds play each arm once so every UCB
        index is well-defined afterwards; then the arm with the highest
        index is played, with uniform random tie-breaking.
        """
        if t < self._num_arms:
            # Warm-up phase: round-robin over all arms.
            self._current_arm = t
            return self._current_arm
        arms = np.arange(self._num_arms)
        indexes = np.array([self._index(t=t, arm=i) for i in arms])
        # All arms attaining the maximal index; break ties at random.
        best_arms = arms[np.flatnonzero(indexes == indexes.max())]
        # Cast: rng.choice returns a NumPy scalar, annotation promises int.
        self._current_arm = int(self._rng.choice(best_arms))
        return self._current_arm

    def feed_reward(self, t: int, arm: int, reward: float):
        """
        Record the reward for the arm chosen at round t.

        :raises ValueError: if ``arm`` is not the arm last returned by choice()
        """
        if arm != self._current_arm:
            raise ValueError(f"Expected the reward for arm {self._current_arm}, but got for {arm}")
        self.cumulative_rewards[arm] += reward
        self.arms_stats[arm] += 1

    def empirically_best_arm(self) -> Tuple[int, float]:
        """
        Return the arm with the highest empirical mean reward and that mean.

        If no arm has been played yet, a uniformly random arm with mean
        0.0 is returned.
        """
        played = np.flatnonzero(self.arms_stats)
        # Bug fix: detect "no data yet" from play counts, not from reward
        # sums -- arms played with all-zero (or cancelling) rewards used to
        # be mistaken for "never played", yielding a random answer.
        if played.size == 0:
            return int(self._rng.choice(self._num_arms)), 0.0
        means = self.cumulative_rewards[played] / self.arms_stats[played]
        arm = int(played[np.argmax(means)])
        return arm, float(self.cumulative_rewards[arm] / self.arms_stats[arm])

    def _index(self, t, arm) -> float:
        """UCB index of ``arm`` at round t: empirical mean + Hoeffding radius."""
        mean = self.cumulative_rewards[arm] / self.arms_stats[arm]
        confidence_radius = math.sqrt((self._alpha * math.log(t + 1)) / self.arms_stats[arm])
        return mean + confidence_radius

    def name(self) -> str:
        """Human-readable policy name including its alpha parameter."""
        return f"UCB(alpha={self._alpha})"