forked from ZiyuanMa/MATD3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
matd3.py
207 lines (179 loc) · 8.25 KB
/
matd3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from parl.utils.utils import check_model_method
from copy import deepcopy
__all__ = ['MADDPG']
from parl.core.paddle.policy_distribution import DiagGaussianDistribution
from parl.core.paddle.policy_distribution import SoftCategoricalDistribution
from parl.core.paddle.policy_distribution import SoftMultiCategoricalDistribution
def SoftPDistribution(logits, act_space):
    """Select the policy distribution that matches ``act_space``.

    The kind of action space is detected by probing attributes, in this
    order: ``n`` (discrete), ``num_discrete_space`` (multi-discrete), then
    ``high`` (box). The multi-discrete probe has to run before the ``high``
    probe because that branch itself reads ``act_space.high``.

    Args:
        logits (paddle tensor): the output of policy model
        act_space: action space, must be gym.spaces.Box or gym.spaces.Discrete
            or multiagent.multi_discrete.MultiDiscrete

    Returns:
        instance of DiagGaussianDistribution or SoftCategoricalDistribution
        or SoftMultiCategoricalDistribution

    Raises:
        AssertionError: if ``act_space`` exposes none of the probed attributes.
    """
    # is instance of gym.spaces.Discrete
    if hasattr(act_space, 'n'):
        return SoftCategoricalDistribution(logits)
    # is instance of multiagent.multi_discrete.MultiDiscrete
    if hasattr(act_space, 'num_discrete_space'):
        return SoftMultiCategoricalDistribution(logits, act_space.low,
                                                act_space.high)
    # is instance of gym.spaces.Box
    if hasattr(act_space, 'high'):
        return DiagGaussianDistribution(logits)
    # Adjacent literals (not a backslash continuation inside the string) keep
    # the message independent of how this line is indented.
    raise AssertionError(
        "act_space must be instance of gym.spaces.Box or "
        "gym.spaces.Discrete or multiagent.multi_discrete.MultiDiscrete")
class MADDPG(parl.Algorithm):
    """MADDPG-style algorithm with twin critics and delayed actor updates.

    The critic is trained on every call to ``learn``; the actor is trained
    only every ``policy_freq``-th call. A deep copy of the model is kept as
    the target network and refreshed through ``sync_target``.
    """

    def __init__(self,
                 model,
                 agent_index=None,
                 act_space=None,
                 gamma=None,
                 tau=None,
                 actor_lr=None,
                 critic_lr=None,
                 policy_freq=2):
        """ MADDPG algorithm

        Args:
            model (parl.Model): forward network of actor and critic.
                The functions value(), policy(), get_actor_params() and
                get_critic_params() of model should be implemented.
            agent_index (int): index of agent, in multiagent env
            act_space (list): action_space, gym space
            gamma (float): discounted factor for reward computation.
            tau (float): decay coefficient when updating the weights of self.target_model with self.model
            actor_lr (float): learning rate of the actor model
            critic_lr (float): learning rate of the critic model
            policy_freq (int): the actor is updated once every `policy_freq`
                calls to learn(); must be >= 1.

        Raises:
            AssertionError: if an argument has the wrong type or
                `policy_freq` is smaller than 1.
        """
        # checks
        check_model_method(model, 'value', self.__class__.__name__)
        check_model_method(model, 'policy', self.__class__.__name__)
        check_model_method(model, 'get_actor_params', self.__class__.__name__)
        check_model_method(model, 'get_critic_params', self.__class__.__name__)
        assert isinstance(agent_index, int)
        assert isinstance(act_space, list)
        assert isinstance(gamma, float)
        assert isinstance(tau, float)
        assert isinstance(actor_lr, float)
        assert isinstance(critic_lr, float)
        # policy_freq is used as a modulus in learn(): 0 would raise
        # ZeroDivisionError there and a negative value would silently
        # disable actor training, so reject both up front.
        assert policy_freq >= 1, "policy_freq must be >= 1"

        # A space that exposes `high` but not `num_discrete_space` is treated
        # as continuous.
        # NOTE(review): only act_space[0] is inspected, while predict() builds
        # the distribution from act_space[agent_index] -- confirm that all
        # agents are expected to share the same kind of action space.
        self.continuous_actions = (
            len(act_space) > 0
            and hasattr(act_space[0], 'high')
            and not hasattr(act_space[0], 'num_discrete_space'))

        self.agent_index = agent_index
        self.act_space = act_space
        self.gamma = gamma
        self.tau = tau
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.policy_freq = policy_freq

        self.model = model
        self.target_model = deepcopy(model)
        # decay=0 makes the target an exact copy of the training network.
        self.sync_target(0)

        self.actor_optimizer = paddle.optimizer.Adam(
            learning_rate=self.actor_lr,
            parameters=self.model.get_actor_params(),
            grad_clip=nn.ClipGradByNorm(clip_norm=0.5))
        self.critic_optimizer = paddle.optimizer.Adam(
            learning_rate=self.critic_lr,
            parameters=self.model.get_critic_params(),
            grad_clip=nn.ClipGradByNorm(clip_norm=0.5))

        # Number of learn() calls so far; drives the delayed actor update.
        self.training_steps = 0

    def predict(self, obs, use_target_model=False):
        """ use the policy model to predict actions

        Args:
            obs (paddle tensor): observation, shape([B] + shape of obs_n[agent_index])
            use_target_model (bool): use target_model or not

        Returns:
            act (paddle tensor): action sampled from the policy distribution
                of this agent's action space; in the continuous case the
                sample is additionally passed through tanh.
        """
        if use_target_model:
            policy = self.target_model.policy(obs)
        else:
            policy = self.model.policy(obs)
        action = SoftPDistribution(
            logits=policy,
            act_space=self.act_space[self.agent_index]).sample()
        if self.continuous_actions:
            action = paddle.tanh(action)
        return action

    def Q(self, obs_n, act_n, use_target_model=False, with_q2=True):
        """ use the value model to predict Q values

        Args:
            obs_n (list of paddle tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n)
            act_n (list of paddle tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n)
            use_target_model (bool): use target_model or not
            with_q2 (bool): output q2 or not; forwarded unchanged to
                model.value()

        Returns:
            whatever model.value() returns. _critic_learn unpacks the
            with_q2=True result as a pair (q1, q2), while _actor_learn uses
            the with_q2=False result as a single tensor.
        """
        if use_target_model:
            return self.target_model.value(obs_n, act_n, with_q2)
        return self.model.value(obs_n, act_n, with_q2)

    def learn(self, obs_n, act_n, target_q):
        """ update actor and critic model with MADDPG algorithm

        The critic is updated on every call. The actor is updated on calls
        1, 1 + policy_freq, 1 + 2 * policy_freq, ...

        Args:
            obs_n (list of paddle tensor): all agents' observation
            act_n (list of paddle tensor): all agents' action
            target_q (paddle tensor): regression target for the critics

        Returns:
            critic_cost (paddle tensor): loss of the critic update
        """
        self.training_steps += 1
        critic_cost = self._critic_learn(obs_n, act_n, target_q)
        # `(steps - 1) % freq == 0` gives the same schedule as the former
        # `steps % freq == 1` for freq >= 2, but also fires when freq == 1
        # (where `steps % 1 == 1` is never true and the actor never trained).
        if (self.training_steps - 1) % self.policy_freq == 0:
            self._actor_learn(obs_n, act_n)
            # NOTE(review): the target sync is placed inside the delayed
            # branch (target networks refreshed together with the actor);
            # confirm this matches the intended update schedule.
            self.sync_target()
        return critic_cost

    def _actor_learn(self, obs_n, act_n):
        """ run one optimisation step on the actor of this agent

        Args:
            obs_n (list of paddle tensor): all agents' observation
            act_n (list of paddle tensor): all agents' action; not modified

        Returns:
            cost (paddle tensor): actor loss, i.e. the mean of -Q plus a
                small penalty on the squared policy output
        """
        i = self.agent_index
        this_policy = self.model.policy(obs_n[i])
        sample_this_action = SoftPDistribution(
            logits=this_policy,
            act_space=self.act_space[self.agent_index]).sample()
        if self.continuous_actions:
            sample_this_action = paddle.tanh(sample_this_action)

        # Shallow copy so that replacing this agent's action does not mutate
        # the caller's sequence.
        action_input_n = list(act_n)
        action_input_n[i] = sample_this_action
        eval_q = self.Q(obs_n, action_input_n, with_q2=False)
        act_cost = paddle.mean(-1.0 * eval_q)

        # when continuous, 'this_policy' will be a tuple with two element: (mean, std)
        if self.continuous_actions:
            this_policy = paddle.concat(this_policy, axis=-1)
        act_reg = paddle.mean(paddle.square(this_policy))

        cost = act_cost + act_reg * 1e-3
        self.actor_optimizer.clear_grad()
        cost.backward()
        self.actor_optimizer.step()
        return cost

    def _critic_learn(self, obs_n, act_n, target_q):
        """ run one optimisation step on both critics

        Args:
            obs_n (list of paddle tensor): all agents' observation
            act_n (list of paddle tensor): all agents' action
            target_q (paddle tensor): regression target shared by q1 and q2

        Returns:
            cost (paddle tensor): sum of the two MSE losses
        """
        pred_q1, pred_q2 = self.Q(obs_n, act_n)
        cost = F.mse_loss(pred_q1, target_q) + F.mse_loss(pred_q2, target_q)
        self.critic_optimizer.clear_grad()
        cost.backward()
        self.critic_optimizer.step()
        return cost

    def sync_target(self, decay=None):
        """ update the target network with the training network

        Args:
            decay(float): the decaying factor while updating the target network with the training network.
                0 represents the **assignment**. None represents updating the target network slowly that depends on the hyperparameter `tau`.
        """
        if decay is None:
            decay = 1.0 - self.tau
        self.model.sync_weights_to(self.target_model, decay=decay)