import collections
import logging
import numpy as np
import platform
import random
from typing import List, Dict
# Importing ray before psutil ensures we use ray's bundled version of psutil.
import ray # noqa F401
import psutil # noqa E402
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.rnn_sequencing import \
timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
from ray.util.debug import log_once
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"
logger = logging.getLogger(__name__)
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
"""Warn if the configured replay buffer size is too large."""
if log_once("replay_buffer_size"):
item_size = item.size_bytes()
psutil_mem = psutil.virtual_memory()
total_gb = psutil_mem.total / 1e9
mem_size = num_items * item_size / 1e9
msg = ("Estimated max memory usage for replay buffer is {} GB "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, num_items, item.count, item_size, total_gb))
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
logger.warning(msg)
else:
logger.info(msg)
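# Illustrative arithmetic for the thresholds above (made-up numbers): 50,000
# stored batches of ~100 KB each give an estimate of ~5 GB; on a 16 GB machine
# that exceeds 0.2 * 16 GB = 3.2 GB, so a warning is logged. An estimate above
# the total available memory would raise a ValueError instead.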
@DeveloperAPI
class ReplayBuffer:
@DeveloperAPI
def __init__(self, size: int):
"""Create Prioritized Replay buffer.
Args:
size (int): Max number of timesteps to store in the FIFO buffer.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
self._hit_count = np.zeros(size)
self._eviction_started = False
self._num_timesteps_added = 0
self._num_timesteps_added_wrap = 0
self._num_timesteps_sampled = 0
self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
self._est_size_bytes = 0
def __len__(self) -> int:
return len(self._storage)
@DeveloperAPI
def add(self, item: SampleBatchType, weight: float) -> None:
warn_replay_buffer_size(
item=item, num_items=self._maxsize / item.count)
assert item.count > 0, item
self._num_timesteps_added += item.count
self._num_timesteps_added_wrap += item.count
if self._next_idx >= len(self._storage):
self._storage.append(item)
self._est_size_bytes += item.size_bytes()
else:
self._storage[self._next_idx] = item
# Wrap around storage as a circular buffer once we hit maxsize.
if self._num_timesteps_added_wrap >= self._maxsize:
self._eviction_started = True
self._num_timesteps_added_wrap = 0
self._next_idx = 0
else:
self._next_idx += 1
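        # Once eviction has started, record the sample-hit count of the slot
        # that will be overwritten next, then reset that slot's counter.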
if self._eviction_started:
self._evicted_hit_stats.push(self._hit_count[self._next_idx])
self._hit_count[self._next_idx] = 0
def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
out.decompress_if_needed()
return out
@DeveloperAPI
def sample(self, num_items: int) -> SampleBatchType:
"""Sample a batch of experiences.
Args:
num_items (int): Number of items to sample from this buffer.
Returns:
SampleBatchType: concatenated batch of items.
"""
idxes = [
random.randint(0,
len(self._storage) - 1) for _ in range(num_items)
]
        self._num_timesteps_sampled += num_items
return self._encode_sample(idxes)
@DeveloperAPI
def stats(self, debug=False) -> dict:
data = {
"added_count": self._num_timesteps_added,
"sampled_count": self._num_timesteps_sampled,
"est_size_bytes": self._est_size_bytes,
"num_entries": len(self._storage),
}
if debug:
data.update(self._evicted_hit_stats.stats())
return data
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
@DeveloperAPI
def __init__(self, size: int, alpha: float):
"""Create Prioritized Replay buffer.
Args:
size (int): Max number of items to store in the FIFO buffer.
alpha (float): how much prioritization is used
(0 - no prioritization, 1 - full prioritization).
See also:
ReplayBuffer.__init__()
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0
self._alpha = alpha
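        # The sum/min segment trees below require a power-of-two capacity, so
        # round the requested size up to the next power of two.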
it_capacity = 1
while it_capacity < size:
it_capacity *= 2
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
self._prio_change_stats = WindowStat("reprio", 1000)
@DeveloperAPI
def add(self, item: SampleBatchType, weight: float) -> None:
idx = self._next_idx
super(PrioritizedReplayBuffer, self).add(item, weight)
if weight is None:
weight = self._max_priority
self._it_sum[idx] = weight**self._alpha
self._it_min[idx] = weight**self._alpha
def _sample_proportional(self, num_items: int) -> List[int]:
res = []
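        # Each draw picks a uniform "mass" in [0, total priority) and maps it
        # to the first index whose priority prefix-sum exceeds that mass (an
        # O(log n) segment-tree lookup), so items are sampled with probability
        # proportional to their priority.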
for _ in range(num_items):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage))
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
@DeveloperAPI
def sample(self, num_items: int, beta: float) -> SampleBatchType:
"""Sample a batch of experiences and return priority weights, indices.
Args:
num_items (int): Number of items to sample from this buffer.
beta (float): To what degree to use importance weights
(0 - no corrections, 1 - full correction).
Returns:
            SampleBatchType: Concatenated batch of items including "weights"
                and "batch_indexes" fields denoting the importance-sampling
                weight and the original buffer index of each sampled
                transition.
"""
assert beta >= 0.0
idxes = self._sample_proportional(num_items)
weights = []
batch_indexes = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage))**(-beta)
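        # Importance-sampling correction: each item gets weight
        # (N * P(item)) ** -beta, normalized by the largest possible weight
        # (that of the lowest-priority item) so that all weights are <= 1.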
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage))**(-beta)
count = self._storage[idx].count
# If zero-padded, count will not be the actual batch size of the
# data.
if isinstance(self._storage[idx], SampleBatch) and \
self._storage[idx].zero_padded:
actual_size = self._storage[idx].max_seq_len
else:
actual_size = count
weights.extend([weight / max_weight] * actual_size)
batch_indexes.extend([idx] * actual_size)
self._num_timesteps_sampled += count
batch = self._encode_sample(idxes)
# Note: prioritization is not supported in lockstep replay mode.
if isinstance(batch, SampleBatch):
batch["weights"] = np.array(weights)
batch["batch_indexes"] = np.array(batch_indexes)
return batch
@DeveloperAPI
def update_priorities(self, idxes: List[int],
priorities: List[float]) -> None:
"""Update priorities of sampled transitions.
sets priority of transition at index idxes[i] in buffer
to priorities[i].
Parameters
----------
idxes: [int]
List of idxes of sampled transitions
priorities: [float]
List of updated priorities corresponding to
transitions at the sampled idxes denoted by
variable `idxes`.
"""
# Making sure we don't pass in e.g. a torch tensor.
assert isinstance(idxes, (list, np.ndarray)), \
"ERROR: `idxes` is not a list or np.ndarray, but " \
"{}!".format(type(idxes).__name__)
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
delta = priority**self._alpha - self._it_sum[idx]
self._prio_change_stats.push(delta)
self._it_sum[idx] = priority**self._alpha
self._it_min[idx] = priority**self._alpha
self._max_priority = max(self._max_priority, priority)
@DeveloperAPI
def stats(self, debug: bool = False) -> Dict:
parent = ReplayBuffer.stats(self, debug)
if debug:
parent.update(self._prio_change_stats.stats())
return parent
# Visible for testing.
_local_replay_buffer = None
class LocalReplayBuffer(ParallelIteratorWorker):
"""A replay buffer shard storing data for all policies (in multiagent setup).
Ray actors are single-threaded, so for scalability, multiple replay actors
may be created to increase parallelism."""
def __init__(self,
num_shards: int = 1,
learning_starts: int = 1000,
buffer_size: int = 10000,
replay_batch_size: int = 1,
prioritized_replay_alpha: float = 0.6,
prioritized_replay_beta: float = 0.4,
prioritized_replay_eps: float = 1e-6,
replay_mode: str = "independent",
replay_sequence_length: int = 1,
replay_burn_in: int = 0,
replay_zero_init_states: bool = True):
"""Initializes a LocalReplayBuffer instance.
Args:
num_shards (int): The number of buffer shards that exist in total
(including this one).
learning_starts (int): Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
return None).
buffer_size (int): The size of the buffer. Note that when
`replay_sequence_length` > 1, this is the number of sequences
(not single timesteps) stored.
replay_batch_size (int): The batch size to be sampled (in
timesteps). Note that if `replay_sequence_length` > 1,
`self.replay_batch_size` will be set to the number of
sequences sampled (B).
prioritized_replay_alpha (float): Alpha parameter for a prioritized
replay buffer.
prioritized_replay_beta (float): Beta parameter for a prioritized
replay buffer.
prioritized_replay_eps (float): Epsilon parameter for a prioritized
replay buffer.
            replay_mode (str): One of "independent" or "lockstep". Determines
                whether, in the multiagent case, sampling is done across all
                agents/policies in lockstep or independently per policy.
replay_sequence_length (int): The sequence length (T) of a single
sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in (int): The burn-in length in case
                `replay_sequence_length` > 1. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting each RNN rollout from a 0.0-state.
            replay_zero_init_states (bool): Whether the initial states in the
                buffer (if replay_sequence_length > 1) are always 0.0 or
                should be updated with the previous train_batch state outputs.
"""
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.replay_batch_size = replay_batch_size
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.replay_mode = replay_mode
self.replay_sequence_length = replay_sequence_length
self.replay_burn_in = replay_burn_in
self.replay_zero_init_states = replay_zero_init_states
if replay_sequence_length > 1:
self.replay_batch_size = int(
max(1, replay_batch_size // replay_sequence_length))
logger.info(
"Since replay_sequence_length={} and replay_batch_size={}, "
"we will replay {} sequences at a time.".format(
replay_sequence_length, replay_batch_size,
self.replay_batch_size))
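            # For example, replay_batch_size=512 with replay_sequence_length=16
            # results in 512 // 16 = 32 sequences replayed per `replay()` call.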
if replay_mode not in ["lockstep", "independent"]:
raise ValueError("Unsupported replay mode: {}".format(replay_mode))
def gen_replay():
while True:
yield self.replay()
ParallelIteratorWorker.__init__(self, gen_replay, False)
def new_buffer():
return PrioritizedReplayBuffer(
self.buffer_size, alpha=prioritized_replay_alpha)
self.replay_buffers = collections.defaultdict(new_buffer)
# Metrics
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()
self.num_added = 0
# Make externally accessible for testing.
global _local_replay_buffer
_local_replay_buffer = self
# If set, return this instead of the usual data for testing.
self._fake_batch = None
@staticmethod
def get_instance_for_testing():
global _local_replay_buffer
return _local_replay_buffer
def get_host(self) -> str:
return platform.node()
def add_batch(self, batch: SampleBatchType) -> None:
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
# Handle everything as if multiagent
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
with self.add_batch_timer:
# Lockstep mode: Store under _ALL_POLICIES key (we will always
# only sample from all policies at the same time).
if self.replay_mode == "lockstep":
# Note that prioritization is not supported in this mode.
for s in batch.timeslices(self.replay_sequence_length):
self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
else:
for policy_id, sample_batch in batch.policy_batches.items():
if self.replay_sequence_length == 1:
timeslices = sample_batch.timeslices(1)
else:
timeslices = timeslice_along_seq_lens_with_overlap(
sample_batch=sample_batch,
zero_pad_max_seq_len=self.replay_sequence_length,
pre_overlap=self.replay_burn_in,
zero_init_states=self.replay_zero_init_states,
)
for time_slice in timeslices:
# If SampleBatch has prio-replay weights, average
# over these to use as a weight for the entire
# sequence.
if "weights" in time_slice:
weight = np.mean(time_slice["weights"])
else:
weight = None
self.replay_buffers[policy_id].add(
time_slice, weight=weight)
self.num_added += batch.count
def replay(self) -> SampleBatchType:
if self._fake_batch:
fake_batch = SampleBatch(self._fake_batch)
return MultiAgentBatch({
DEFAULT_POLICY_ID: fake_batch
}, fake_batch.count)
if self.num_added < self.replay_starts:
return None
with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
if self.replay_mode == "lockstep":
return self.replay_buffers[_ALL_POLICIES].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
else:
samples = {}
for policy_id, replay_buffer in self.replay_buffers.items():
samples[policy_id] = replay_buffer.sample(
self.replay_batch_size,
beta=self.prioritized_replay_beta)
return MultiAgentBatch(samples, self.replay_batch_size)
def update_priorities(self, prio_dict: Dict) -> None:
with self.update_priorities_timer:
for policy_id, (batch_indexes, td_errors) in prio_dict.items():
new_priorities = (
np.abs(td_errors) + self.prioritized_replay_eps)
self.replay_buffers[policy_id].update_priorities(
batch_indexes, new_priorities)
def stats(self, debug: bool = False) -> Dict:
stat = {
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
"update_priorities_time_ms": round(
1000 * self.update_priorities_timer.mean, 3),
}
for policy_id, replay_buffer in self.replay_buffers.items():
stat.update({
"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
})
return stat
ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
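
# The block below is a minimal, illustrative usage sketch (not part of the
# RLlib API): it builds a small PrioritizedReplayBuffer, adds a few made-up
# one-timestep SampleBatches, samples with importance weights, and feeds back
# new priorities. Field values below are placeholders for demonstration only.
if __name__ == "__main__":
    buffer = PrioritizedReplayBuffer(size=100, alpha=0.6)
    for i in range(10):
        batch = SampleBatch({
            "obs": np.array([[float(i)]]),
            "actions": np.array([0]),
            "rewards": np.array([float(i)]),
            "dones": np.array([False]),
        })
        # `weight=None` falls back to the buffer's current max priority.
        buffer.add(batch, weight=None)
    # Sample 4 items; the returned batch carries the "weights" and
    # "batch_indexes" columns added by PrioritizedReplayBuffer.sample().
    sample = buffer.sample(num_items=4, beta=0.4)
    print(sample["weights"], sample["batch_indexes"])
    # Pretend these are TD errors from a learner and update the priorities.
    fake_td_errors = np.random.rand(len(sample["batch_indexes"]))
    buffer.update_priorities(sample["batch_indexes"],
                             np.abs(fake_td_errors) + 1e-6)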