Skip to content

Commit

Permalink
[RLlib] Fix IMPALA/APPO when using multi GPU setup and Multi-Agent Env (
Browse files Browse the repository at this point in the history
ray-project#35120)

Signed-off-by: Michael <[email protected]>
Co-authored-by: Artur Niederfahrenhorst <[email protected]>
  • Loading branch information
2 people authored and scv119 committed Jun 11, 2023
1 parent 7151cc7 commit 85d0140
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
10 changes: 10 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2713,6 +2713,16 @@ py_test(
args = ["TestTrainAndEvaluate"]
)

py_test(
name = "tests/test_supported_multi_agent_multi_gpu",
main = "tests/test_supported_multi_agent.py",
tags = ["team:rllib", "tests_dir", "multi_gpu"],
size = "medium",
srcs = ["tests/test_supported_multi_agent.py"],
args = ["TestSupportedMultiAgentMultiGPU"]
)


py_test(
name = "tests/test_supported_multi_agent_pg",
main = "tests/test_supported_multi_agent.py",
Expand Down
4 changes: 3 additions & 1 deletion rllib/execution/multi_gpu_learner_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def step(self) -> None:
default_policy_results = policy.learn_on_loaded_batch(
offset=0, buffer_index=buffer_idx
)
learner_info_builder.add_learn_on_batch_results(default_policy_results)
learner_info_builder.add_learn_on_batch_results(
default_policy_results, policy_id=pid
)
self.policy_ids_updated.append(pid)
get_num_samples_loaded_into_buffer += (
policy.get_num_samples_loaded_into_buffer(buffer_idx)
Expand Down
13 changes: 13 additions & 0 deletions rllib/tests/test_supported_multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def test_sac_multiagent(self):
)


class TestSupportedMultiAgentMultiGPU(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_impala_multiagent_multi_gpu(self):
check_support_multiagent("IMPALA", {"num_gpus": 2})


if __name__ == "__main__":
import pytest
import sys
Expand Down

0 comments on commit 85d0140

Please sign in to comment.