Skip to content

Commit

Permalink
[RLlib] Fix memory leak in APEX_DQN (ray-project#26691)
Browse files Browse the repository at this point in the history
Signed-off-by: Rohan138 <[email protected]>
  • Loading branch information
avnishn authored and Rohan138 committed Jul 28, 2022
1 parent 82c0cb4 commit 1aa6c73
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 26 deletions.
36 changes: 13 additions & 23 deletions rllib/algorithms/apex_dqn/apex_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
""" # noqa: E501
import copy
import platform
import queue
import random
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type
Expand Down Expand Up @@ -413,7 +412,6 @@ def setup(self, config: PartialAlgorithmConfigDict):
weights = self.workers.local_worker().get_weights()
self.curr_learner_weights = ray.put(weights)
self.curr_num_samples_collected = 0
self.replay_sample_batches = []
self._num_ts_trained_since_last_target_update = 0

@classmethod
Expand Down Expand Up @@ -563,14 +561,15 @@ def wait_on_replay_actors() -> None:
If the timeout is None, then block on the actors indefinitely.
"""
_replay_samples_ready = self._replay_actor_manager.get_ready()

replay_sample_batches = []
for _replay_actor, _sample_batches in _replay_samples_ready.items():
for _sample_batch in _sample_batches:
self.replay_sample_batches.append((_replay_actor, _sample_batch))
replay_sample_batches.append((_replay_actor, _sample_batch))
return replay_sample_batches

num_samples_collected = sum(num_samples_collected.values())
self.curr_num_samples_collected += num_samples_collected
wait_on_replay_actors()
replay_sample_batches = wait_on_replay_actors()
if self.curr_num_samples_collected >= self.config["train_batch_size"]:
training_intensity = int(self.config["training_intensity"] or 1)
num_requests_to_launch = (
Expand All @@ -583,26 +582,17 @@ def wait_on_replay_actors() -> None:
lambda actor, num_items: actor.sample(num_items),
fn_args=[self.config["train_batch_size"]],
)
wait_on_replay_actors()
replay_sample_batches.extend(wait_on_replay_actors())

# add the sample batches to the learner queue
while self.replay_sample_batches:
try:
item = self.replay_sample_batches[0]
# the replay buffer returns none if it has not been filled to
# the minimum threshold yet.
if item:
# Setting block = True prevents the learner thread,
# the main thread, and the gpu loader threads from
# thrashing when there are more samples than the
# learner can reasonable process.
# see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674 # noqa
self.learner_thread.inqueue.put(
self.replay_sample_batches[0], block=True
)
self.replay_sample_batches.pop(0)
except queue.Full:
break
for item in replay_sample_batches:
# Setting block = True prevents the learner thread,
# the main thread, and the gpu loader threads from
# thrashing when there are more samples than the
# learner can reasonable process.
# see https://github.com/ray-project/ray/pull/26581#issuecomment-1187877674 # noqa
self.learner_thread.inqueue.put(item, block=True)
del replay_sample_batches

def update_replay_sample_priority(self) -> None:
"""Update the priorities of the sample batches with new priorities that are
Expand Down
9 changes: 6 additions & 3 deletions rllib/algorithms/dqn/learner_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def step(self):
self.outqueue.put(
(replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps())
)
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
self.overall_timer.push_units_processed(ma_batch and ma_batch.count or 0)
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
self.overall_timer.push_units_processed(
ma_batch and ma_batch.count or 0
)
del ma_batch

0 comments on commit 1aa6c73

Please sign in to comment.