Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dian xt ms #29

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
ppo modify
  • Loading branch information
AmiyaSX committed Apr 8, 2023
commit fd30d13310fe1b47bf3bd2f2c61dc20fb7a024f6
31 changes: 13 additions & 18 deletions xt/model/ppo/ppo_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,28 @@
from xt.model.ms_dist import make_dist
from zeus.common.util.common import import_config
from zeus.common.util.register import Registers
from xt.model.ms_compat import ReduceMean, Tensor, Adam
from xt.model.ms_compat import ReduceMean, Tensor, Adam,Model
from xt.model.model_ms import XTModel_MS
from xt.model.ms_utils import MSVariables

from mindspore.ops import Depend, value_and_grad, clip_by_global_norm, Log, ReduceSum, Minimum, Maximum, Exp, Square, clip_by_value
from mindspore.nn import Cell, TrainOneStepCell, LossBase

from mindspore import set_context
import mindspore as ms
set_context(mode=ms.GRAPH_MODE)

@Registers.model
class PPOMS(XTModel_MS):

class PPOPredictPolicy(Cell):
    """Inference-only wrapper cell around the shared policy/value network.

    Wrapping the network in a dedicated Cell avoids a memory leak that
    was observed otherwise; the exact root cause is not yet understood
    (carried over from the original author's note).
    """

    def __init__(self, net, dist):
        super().__init__()
        # Shared backbone producing (policy latent, value estimate).
        self.network = net
        # Action distribution built from the policy latent.
        self.dist = dist

    def construct(self, state):
        # Single forward pass yields both heads.
        latent, value = self.network(state)
        # Sample an action and score it under the current policy.
        sampled_action = self.dist.sample(latent)
        log_prob = self.dist.log_prob(sampled_action, latent)
        return sampled_action, log_prob, value

def __init__(self, model_info):
Expand Down Expand Up @@ -57,26 +53,25 @@ def __init__(self, model_info):
self.predict_net = self.PPOPredictPolicy(
self.model, self.dist)
adam = Adam(params=self.predict_net.trainable_params(),
learning_rate=self._lr)
learning_rate=0.0005)
loss_fn = WithLossCell(self.critic_loss_coef,
self.clip_ratio, self.ent_coef, self.vf_clip)
forward_fn = NetWithLoss(
self.model, loss_fn, self.dist, self.action_type)
self.train_net = MyTrainOneStepCell(
forward_fn, optimizer=adam, max_grad_norm=self._max_grad_norm)
self.train_net.set_train()

def predict(self, state):
    """Predict state.

    Runs the inference wrapper on a batch of observations.

    Args:
        state: numpy array of observations. Must be a numpy array
            (not an existing Tensor) because ``Tensor.from_numpy``
            performs a zero-copy wrap of the numpy buffer.

    Returns:
        Tuple of numpy arrays ``(action, logp, v_out)``: sampled
        actions, their log-probabilities, and value estimates.
    """
    # Keep only the zero-copy conversion; the duplicated
    # `state = Tensor(state)` line was stale diff residue — running
    # both would pass a Tensor into Tensor.from_numpy, which fails.
    state = Tensor.from_numpy(state)
    action, logp, v_out = self.predict_net(state)
    # Convert results back to numpy for callers outside MindSpore.
    action = action.asnumpy()
    logp = logp.asnumpy()
    v_out = v_out.asnumpy()
    return action, logp, v_out

def train(self, state, label):
self.model.set_train(True)
nbatch = state[0].shape[0]
inds = np.arange(nbatch)
loss_val = []
Expand All @@ -85,12 +80,12 @@ def train(self, state, label):
for start in range(0, nbatch, self._batch_size):
end = start + self._batch_size
mbinds = inds[start:end]
state_ph = Tensor(state[0][mbinds])
behavior_action_ph = Tensor(label[0][mbinds])
old_logp_ph = Tensor(label[1][mbinds])
adv_ph = Tensor(label[2][mbinds])
old_v_ph = Tensor(label[3][mbinds])
target_v_ph = Tensor(label[4][mbinds])
state_ph = Tensor.from_numpy(state[0][mbinds])
behavior_action_ph = Tensor.from_numpy(label[0][mbinds])
old_logp_ph = Tensor.from_numpy(label[1][mbinds])
adv_ph = Tensor.from_numpy(label[2][mbinds])
old_v_ph = Tensor.from_numpy(label[3][mbinds])
target_v_ph = Tensor.from_numpy(label[4][mbinds])
loss = self.train_net(
state_ph, adv_ph, old_logp_ph, behavior_action_ph, target_v_ph, old_v_ph).asnumpy()
loss_val.append(np.mean(loss))
Expand Down