forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
async_samples_optimizer.py
189 lines (161 loc) · 7.22 KB
/
async_samples_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Implements the IMPALA asynchronous sampling architecture.
https://arxiv.org/abs/1802.01561"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
from ray.rllib.optimizers.aso_aggregator import SimpleAggregator
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.optimizers.aso_learner import LearnerThread
from ray.rllib.optimizers.aso_multi_gpu_learner import TFMultiGPULearner
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the IMPALA architecture.

    This class coordinates the data transfers between the learner thread
    and remote workers (IMPALA actors). Train batches produced by the
    aggregator are pushed into the learner's input queue. New weights are
    broadcast back through the aggregator when the learner has updated
    them. Sample/train throughput and timing statistics are collected
    along the way.
    """

    def __init__(self,
                 workers,
                 train_batch_size=500,
                 sample_batch_size=50,
                 num_envs_per_worker=1,
                 num_gpus=0,
                 lr=0.0005,
                 replay_buffer_num_slots=0,
                 replay_proportion=0.0,
                 num_data_loader_buffers=1,
                 max_sample_requests_in_flight_per_worker=2,
                 broadcast_interval=1,
                 num_sgd_iter=1,
                 minibatch_buffer_size=1,
                 learner_queue_size=16,
                 learner_queue_timeout=300,
                 num_aggregation_workers=0,
                 _fake_gpus=False):
        """Create and start the learner, then create the sample aggregator.

        Args:
            workers: Worker set. Its local worker is given to the learner,
                and the set itself is given to the aggregator.
            num_gpus, num_data_loader_buffers: If ``num_gpus > 1`` or
                ``num_data_loader_buffers > 1``, a TFMultiGPULearner is used.
                Otherwise a LearnerThread is used.
            num_aggregation_workers: If > 0, a TreeAggregator with that many
                aggregation workers is used. Otherwise a SimpleAggregator
                is used.
            num_envs_per_worker: Accepted for interface compatibility but not
                used by this optimizer.
            Remaining arguments: Forwarded unchanged to the chosen learner
                and/or aggregator; see those classes for their meaning.

        Raises:
            ValueError: In multi-GPU mode, if ``num_data_loader_buffers`` is
                smaller than ``minibatch_buffer_size``.
        """
        PolicyOptimizer.__init__(self, workers)

        if num_gpus > 1 or num_data_loader_buffers > 1:
            logger.info(
                "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
                    num_gpus, num_data_loader_buffers))
            if num_data_loader_buffers < minibatch_buffer_size:
                raise ValueError(
                    "In multi-gpu mode you must have at least as many "
                    "parallel data loader buffers as minibatch buffers: "
                    "{} vs {}".format(num_data_loader_buffers,
                                      minibatch_buffer_size))
            self.learner = TFMultiGPULearner(
                self.workers.local_worker(),
                lr=lr,
                num_gpus=num_gpus,
                train_batch_size=train_batch_size,
                num_data_loader_buffers=num_data_loader_buffers,
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                learner_queue_size=learner_queue_size,
                learner_queue_timeout=learner_queue_timeout,
                _fake_gpus=_fake_gpus)
        else:
            self.learner = LearnerThread(
                self.workers.local_worker(),
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                learner_queue_size=learner_queue_size,
                learner_queue_timeout=learner_queue_timeout)
        self.learner.start()

        # Stats. Initialised exactly once, after the learner has started, so
        # throughput windows begin when training can actually begin.
        # Nothing reads these before this point.
        self._optimizer_step_timer = TimerStat()
        self._stats_start_time = time.time()
        self._last_stats_time = {}  # stat key -> start of its current window
        self._last_stats_sum = {}  # stat key -> total accumulated in window

        if num_aggregation_workers > 0:
            self.aggregator = TreeAggregator(
                workers,
                num_aggregation_workers,
                replay_proportion=replay_proportion,
                max_sample_requests_in_flight_per_worker=(
                    max_sample_requests_in_flight_per_worker),
                replay_buffer_num_slots=replay_buffer_num_slots,
                train_batch_size=train_batch_size,
                sample_batch_size=sample_batch_size,
                broadcast_interval=broadcast_interval)
        else:
            self.aggregator = SimpleAggregator(
                workers,
                replay_proportion=replay_proportion,
                max_sample_requests_in_flight_per_worker=(
                    max_sample_requests_in_flight_per_worker),
                replay_buffer_num_slots=replay_buffer_num_slots,
                train_batch_size=train_batch_size,
                sample_batch_size=sample_batch_size,
                broadcast_interval=broadcast_interval)

    def add_stat_val(self, key, val):
        """Add ``val`` to the running total for stat ``key``.

        The first time a key is seen, its window is taken to have started
        when the optimizer's stats were initialised.
        """
        if key not in self._last_stats_sum:
            self._last_stats_sum[key] = 0
            self._last_stats_time[key] = self._stats_start_time
        self._last_stats_sum[key] += val

    def get_mean_stats_and_reset(self):
        """Return per-second rates of the accumulated stats and reset them.

        Each rate is the key's accumulated total divided by the seconds
        elapsed since its window began, rounded to three decimals. All
        totals are then reset to zero. Every window restarts at the same
        timestamp the rates were measured against, so consecutive windows
        are contiguous.

        Returns:
            dict mapping each stat key to its rate. A key whose window has
            no positive elapsed time reports 0.0 instead of raising
            ZeroDivisionError.
        """
        now = time.time()
        mean_stats = {}
        for key, total in self._last_stats_sum.items():
            elapsed = now - self._last_stats_time[key]
            # time.time() may return the same value on back-to-back calls
            # (coarse clock resolution) and can step backwards if the
            # system clock is adjusted; never divide by a non-positive span.
            mean_stats[key] = round(total / elapsed, 3) if elapsed > 0 else 0.0

        for key in self._last_stats_sum:
            self._last_stats_sum[key] = 0
            self._last_stats_time[key] = now

        return mean_stats

    @override(PolicyOptimizer)
    def step(self):
        """Run one optimizer step and update throughput and step counters.

        Raises:
            ValueError: If there are no remote workers to sample from.
        """
        if len(self.workers.remote_workers()) == 0:
            raise ValueError("Config num_workers=0 means training will hang!")
        assert self.learner.is_alive(), "The learner thread is not running."
        with self._optimizer_step_timer:
            sample_timesteps, train_timesteps = self._step()

        # A stat key is only created (and so reported by stats()) once it
        # has received a nonzero value.
        if sample_timesteps > 0:
            self.add_stat_val("sample_throughput", sample_timesteps)
        if train_timesteps > 0:
            self.add_stat_val("train_throughput", train_timesteps)

        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Ask the learner thread to stop by setting its ``stopped`` flag.

        Presumably the learner polls this flag in its run loop (not visible
        here). This call does not join the thread.
        """
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        """Replace the remote workers in both the worker set and aggregator."""
        self.workers.reset(remote_workers)
        self.aggregator.reset(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        """Return aggregator, throughput, timing and learner statistics.

        Calling this resets the throughput windows (see
        ``get_mean_stats_and_reset``). Keys from this optimizer override
        same-named keys from ``PolicyOptimizer.stats``.
        """

        def timer_to_ms(timer):
            # The timer's mean is scaled by 1000 for the "_ms" fields
            # (presumably TimerStat.mean is in seconds).
            return round(1000 * timer.mean, 3)

        # Copy so we never mutate a dict the aggregator might keep around.
        stats = dict(self.aggregator.stats())
        stats.update(self.get_mean_stats_and_reset())
        stats["timing_breakdown"] = {
            "optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer),
            "learner_grad_time_ms": timer_to_ms(self.learner.grad_timer),
            "learner_load_time_ms": timer_to_ms(self.learner.load_timer),
            "learner_load_wait_time_ms": timer_to_ms(
                self.learner.load_wait_timer),
            "learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer),
        }
        stats["learner_queue"] = self.learner.learner_queue_size.stats()
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    def _step(self):
        """Feed ready train batches to the learner and collect its results.

        Every train batch the aggregator yields is put on the learner's input
        queue. After each one, new weights are broadcast if the learner
        reports updated weights and the aggregator agrees a broadcast is due.
        Any counts waiting on the learner's output queue are then drained.

        Returns:
            (sample_timesteps, train_timesteps): the summed ``count`` of the
            batches handed to the learner, and the summed counts read back
            from the learner's output queue.
        """
        sample_timesteps, train_timesteps = 0, 0

        for train_batch in self.aggregator.iter_train_batches():
            sample_timesteps += train_batch.count
            self.learner.inqueue.put(train_batch)
            if (self.learner.weights_updated
                    and self.aggregator.should_broadcast()):
                self.aggregator.broadcast_new_weights()

        while not self.learner.outqueue.empty():
            count = self.learner.outqueue.get()
            train_timesteps += count

        return sample_timesteps, train_timesteps