-
Notifications
You must be signed in to change notification settings - Fork 5.6k
/
test_supported_multi_agent.py
124 lines (99 loc) · 3.55 KB
/
test_supported_multi_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import unittest
import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.impala import IMPALAConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.examples.envs.classes.multi_agent import (
MultiAgentCartPole,
MultiAgentMountainCar,
)
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_train_results
from ray.tune.registry import register_env
def check_support_multiagent(alg: str, config: AlgorithmConfig):
    """Smoke-test that ``config`` can train for one iteration in a 2-agent env.

    SAC is exercised on the multi-agent mountain-car environment; every other
    algorithm runs on multi-agent cart-pole. The algorithm is built, trained
    for a single iteration, the results are sanity-checked, and it is stopped.

    Args:
        alg: Name of the algorithm (only "SAC" changes the env choice).
        config: The AlgorithmConfig to build and train.
    """
    # Register both candidate 2-agent environments under fixed names.
    register_env(
        "multi_agent_mountaincar", lambda _: MultiAgentMountainCar({"num_agents": 2})
    )
    register_env(
        "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})
    )

    # Two policies that differ only in their discount factor.
    policies = {
        "policy_0": PolicySpec(config={"gamma": 0.99}),
        "policy_1": PolicySpec(config={"gamma": 0.95}),
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Agent ids are integer indices in these example envs, so they map
        # directly onto the policy-id list.
        return policy_ids[agent_id]

    config.multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)

    # SAC (continuous-friendly) gets mountain-car; everything else cart-pole.
    env_name = "multi_agent_mountaincar" if alg == "SAC" else "multi_agent_cartpole"
    algo = config.build(env=env_name)
    results = algo.train()
    check_train_results(results)
    print(results)
    algo.stop()
class TestSupportedMultiAgentPolicyGradient(unittest.TestCase):
    """Multi-agent smoke tests for on-policy (policy-gradient) algorithms."""

    @classmethod
    def setUpClass(cls) -> None:
        # One shared Ray cluster for all tests in this class.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_impala_multiagent(self):
        config = IMPALAConfig().resources(num_gpus=0)
        check_support_multiagent("IMPALA", config)

    def test_ppo_multiagent(self):
        # Tiny batch/minibatch settings keep the single train step cheap.
        config = (
            PPOConfig()
            .env_runners(num_env_runners=1, rollout_fragment_length=10)
            .training(num_sgd_iter=1, train_batch_size=10, sgd_minibatch_size=1)
        )
        check_support_multiagent("PPO", config)
class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
    """Multi-agent smoke tests for off-policy (replay-buffer) algorithms."""

    @classmethod
    def setUpClass(cls) -> None:
        # One shared Ray cluster for all tests in this class.
        ray.init(num_cpus=6)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_dqn_multiagent(self):
        # Small replay buffer + 1-timestep reporting keep the test fast.
        config = (
            DQNConfig()
            .reporting(min_sample_timesteps_per_iteration=1)
            .training(replay_buffer_config={"capacity": 1000})
        )
        check_support_multiagent("DQN", config)

    def test_sac_multiagent(self):
        # No remote env-runners and a tiny replay buffer for speed;
        # action normalization is off for the mountain-car env.
        config = (
            SACConfig()
            .environment(normalize_actions=False)
            .env_runners(num_env_runners=0)
            .training(replay_buffer_config={"capacity": 1000})
        )
        check_support_multiagent("SAC", config)
class TestSupportedMultiAgentMultiGPU(unittest.TestCase):
    """Multi-agent smoke test requiring multiple GPUs (IMPALA with 2 GPUs)."""

    @classmethod
    def setUpClass(cls) -> None:
        # Default resources: let Ray auto-detect CPUs/GPUs on this node.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_impala_multiagent_multi_gpu(self):
        config = IMPALAConfig().resources(num_gpus=2)
        check_support_multiagent("IMPALA", config)
if __name__ == "__main__":
    import sys

    import pytest

    # An optional first CLI arg restricts the run to a single TestCase class;
    # with no arg, every TestCase in this file is collected.
    target = __file__
    if len(sys.argv) > 1:
        target += "::" + sys.argv[1]
    sys.exit(pytest.main(["-v", target]))