Dian xt ms #29

Open
wants to merge 27 commits into base: master
Changes from 1 commit
ppo_ms modify
AmiyaSX committed Mar 9, 2023
commit c3e46543b212755985adb925a4ba9a23227617e3
13 changes: 10 additions & 3 deletions xt/model/ms_compat.py
@@ -40,12 +40,19 @@ def import_ms_compact():
from mindspore.nn import Adam
from mindspore.nn import Conv2d, Dense, Flatten, ReLU
from mindspore.nn import MSELoss
from mindspore.nn import WithLossCell,TrainOneStepCell, SoftmaxCrossEntropyWithLogits
from mindspore.nn import Cell, WithLossCell, DynamicLossScaleUpdateCell
from mindspore.train import Model
from mindspore.nn import WithLossCell, TrainOneStepCell, SoftmaxCrossEntropyWithLogits, SequentialCell
from mindspore.nn import Cell, WithLossCell, DynamicLossScaleUpdateCell, get_activation
from mindspore import Model, Tensor
from mindspore.ops import Cast, MultitypeFuncGraph, ReduceSum, ReduceMax, ReduceMean
from mindspore.ops import Cast, MultitypeFuncGraph, ReduceSum, ReduceMax, ReduceMin, ReduceMean, Reciprocal
from mindspore import History

def loss_to_val(loss):
"""Make keras instance into value."""
if isinstance(loss, History):
loss = loss.history.get("loss")[0]
return loss


DTYPE_MAP = {
"float32": ms.float32,
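Note for reviewers (not part of the diff): a minimal usage sketch of the new loss_to_val helper. The History logging format in the comment is an assumption based on the history.get("loss")[0] access above.

# Minimal sketch, assuming the repo is importable; loss_to_val passes plain
# numbers through unchanged and unwraps a History-style object.
from xt.model.ms_compat import loss_to_val

print(loss_to_val(0.42))  # -> 0.42, non-History values are returned as-is
# When training returns a mindspore History callback whose history dict is,
# e.g., {"loss": [0.42, 0.40]}, loss_to_val(history) returns 0.42.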
95 changes: 67 additions & 28 deletions xt/model/ms_dist.py
@@ -3,11 +3,52 @@
from xt.model.ms_compat import ms, Cast, ReduceSum, ReduceMax, SoftmaxCrossEntropyWithLogits, Tensor
from mindspore import ops
import mindspore.nn.probability.distribution as msd
from mindspore import Parameter
from mindspore import Parameter, ms_function, ms_class
from mindspore.common.initializer import initializer

@ms_class
class ActionDist:
"""Build base action distribution."""
def init_by_param(self, param):
raise NotImplementedError

def flatparam(self):
raise NotImplementedError

def sample(self, repeat):
"""Sample action from this distribution."""
raise NotImplementedError

def sample_dtype(self):
raise NotImplementedError

def get_shape(self):
return self.flatparam().shape.as_list()

@property
def shape(self):
return self.get_shape()

def __getitem__(self, idx):
return self.flatparam()[idx]

def neglog_prob(self, x, logits):
raise NotImplementedError

class DiagGaussianDist(msd.Normal):
def log_prob(self, x, logits):
"""Calculate the log-likelihood."""
return -self.neglog_prob(x, logits)

def mode(self):
raise NotImplementedError

def entropy(self):
raise NotImplementedError

def kl(self, other):
raise NotImplementedError

class DiagGaussianDist(ActionDist):
"""Build Diagonal Gaussian distribution, each vector represented one distribution."""

def __init__(self, size):
@@ -29,17 +70,17 @@ def flatparam(self):
def sample_dtype(self):
return ms.float32

def _log_prob(self, x, mean, sd):
def neglog_prob(self, x, mean, sd):
log_sd = self.log(sd)
neglog_prob = 0.5 * self.log(2.0 * np.pi) * Cast()((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1) + \
self.reduce_sum(log_sd, axis=-1)
return -neglog_prob
neglog_prob= 0.5 * self.log(2.0 * np.pi) * Cast()((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1) + \
self.reduce_sum(log_sd, axis=-1)
return neglog_prob

def mode(self):
return self.mean

def _entropy(self, sd):
def entropy(self, sd):
log_sd = self.log(sd)
return self.reduce_sum(log_sd + 0.5 * (self.log(2.0 * np.pi) + 1.0), axis=-1)

@@ -51,45 +92,42 @@ def kl(self, other):
other.log_std - self.log_std - 0.5,
axis=-1)

def _sample(self, mean, sd):
def sample(self, mean, sd):
return mean + sd * self.normal(self.shape(mean), dtype=ms.float32)


class CategoricalDist(msd.Categorical):
class CategoricalDist(ActionDist):

def __init__(self, size):
super(CategoricalDist, self).__init__()
self.size = size
self.OneHot = ops.OneHot()
self.softmax_cross = ops.SoftmaxCrossEntropyWithLogits()
self.reduce_max = ReduceMax(keep_dims=True)
self.reduce_sum = ReduceSum(keep_dims=True)
self.exp = ops.Exp()
self.log = ops.Log()
self.squeeze = ops.Squeeze()
self.random_categorical = ops.RandomCategorical(dtype=ms.int32)
self.expand_dims = ops.ExpandDims()

def init_by_param(self, logits):
self.logits = logits
return

def flatparam(self):
return self.logits

def sample_dtype(self):
return ms.int32

def _log_prob(self, x, logits):
on_value, off_value = Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)
x = self.OneHot(x, self.size, on_value, off_value)
def neglog_prob(self, x, logits):
on_value, off_value = Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)
x = self.OneHot(x, self.size, on_value, off_value)
loss, dlogits = self.softmax_cross(logits, x)
return -self.expand_dims(loss, -1)
return self.expand_dims(loss, -1)

def _entropy(self, logits):

def entropy(self, logits):

rescaled_logits = logits - self.reduce_max(logits, -1)
exp_logits = self.exp(rescaled_logits)

z = self.reduce_sum(exp_logits, -1)
p = exp_logits / z
return self.reduce_sum(p * (self.log(z) - rescaled_logits), -1)
@@ -98,20 +136,21 @@ def kl(self, other):
assert isinstance(other, CategoricalDist), 'Distribution type not match.'
reduce_max = ReduceMax(keep_dims=True)
reduce_sum = ReduceSum(keep_dims=True)
rescaled_logits_self = self.logits - reduce_max(self.logits, -1)
rescaled_logits_other = other.logits - reduce_max(other.logits, -1)
rescaled_logits_self = self.logits - reduce_max(self.logits, axis=-1)
rescaled_logits_other = other.logits - reduce_max(other.logits, axis=-1)
exp_logits_self = ops.exp(rescaled_logits_self)
exp_logits_other = ops.exp(rescaled_logits_other)
z_self = reduce_sum(exp_logits_self, -1)
z_other = reduce_sum(exp_logits_other, -1)
z_self = reduce_sum(exp_logits_self, axis=-1)
z_other = reduce_sum(exp_logits_other, axis=-1)
p = exp_logits_self / z_self
return reduce_sum(p * (rescaled_logits_self - ops.log(z_self) - rescaled_logits_other + ops.log(z_other)),
-1)
axis=-1)

def _sample(self, logits):
def sample(self, logits):
# u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
# return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1, output_type=tf.int32)
return logits.random_categorical(1, dtype=ms.int32).squeeze(-1)
action = ops.squeeze(ops.random_categorical(logits, 1, dtype=ms.int32), -1)
return action


def make_dist(ac_type, ac_dim):
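Note for reviewers (not part of the diff): the PPO model below calls these distributions statelessly, passing logits (and mean/std for the Gaussian case) into every call. A short sketch of that calling convention for the categorical case; the mapping done by make_dist (its body is collapsed in this diff) is assumed to route 'Categorical' to CategoricalDist.

# Sketch only: exercises the stateless API used by ppo_ms.py below.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from xt.model.ms_dist import make_dist

dist = make_dist('Categorical', 4)                 # assumed to build CategoricalDist(4)
logits = Tensor(np.random.randn(2, 4), ms.float32)

action = dist.sample(logits)                       # shape (2,), int32 sampled actions
logp = dist.log_prob(action, logits)               # shape (2, 1), log-probabilities
ent = dist.entropy(logits)                         # shape (2, 1), per-sample entropy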
Empty file.
Empty file.
Empty file.
Empty file.
1 change: 0 additions & 1 deletion xt/model/ppo/ppo_mlp_ms.py
@@ -1,7 +1,6 @@
from xt.model.model_utils_ms import ACTIVATION_MAP_MS, get_mlp_backbone_ms, get_mlp_default_settings_ms
from xt.model.ppo.default_config import MLP_SHARE_LAYERS
from xt.model.ppo.ppo_ms import PPOMS
from xt.model.tf_compat import tf
from zeus.common.util.register import Registers
from xt.model.ms_utils import MSVariables

107 changes: 54 additions & 53 deletions xt/model/ppo/ppo_ms.py
@@ -13,12 +13,12 @@
import mindspore.nn.probability.distribution as msd
from mindspore import set_context
import time


import os
import psutil
@Registers.model
class PPOMS(XTModel_MS):
"""Build PPO MLP network."""

def __init__(self, model_info):
model_config = model_info.get('model_config')
import_config(globals(), model_config)
@@ -27,7 +27,6 @@ def __init__(self, model_info):
self.state_dim = model_info['state_dim']
self.action_dim = model_info['action_dim']
self.input_dtype = model_info.get('input_dtype', 'float32')

self.action_type = model_config.get('action_type')
self._lr = model_config.get('LR', LR)
self._batch_size = model_config.get('BATCH_SIZE', BATCH_SIZE)
@@ -38,41 +37,42 @@ def __init__(self, model_info):
self.num_sgd_iter = model_config.get('NUM_SGD_ITER', NUM_SGD_ITER)
self.verbose = model_config.get('SUMMARY', SUMMARY)
self.vf_clip = Tensor(model_config.get('VF_CLIP', VF_CLIP))
self.net = self.create_model(model_info)

self.dist = make_dist(self.action_type, self.action_dim)
'''Create the optimizer, loss function and training network'''
adam = Adam(params=self.net.trainable_params(), learning_rate=0.0005)
loss_fn = WithLossCell(self.critic_loss_coef, self.clip_ratio, self.ent_coef, self.vf_clip)
forward_fn = NetWithLoss(self.net, loss_fn, self.dist, self.action_type)

super().__init__(model_info)

'''Build the training network'''
adam = Adam(params=self.model.trainable_params(), learning_rate=0.0005, use_amsgrad=True)
loss_fn = WithLossCell(self.critic_loss_coef, self.clip_ratio, self.ent_coef, self.vf_clip)
forward_fn = NetWithLoss(self.model, loss_fn, self.dist, self.action_type)
self.train_net = MyTrainOneStepCell(forward_fn, optimizer=adam, max_grad_norm=self._max_grad_norm)
self.train_net.set_train()
self.depend = ops.Depend()
self.exp = ops.Exp()
self.log = ops.Log()
self.reduce_sum = ops.ReduceSum(keep_dims=True)
super().__init__(model_info)

def predict(self, state):
"""Predict state."""
self.model.set_train(False)

state = Tensor(state, ms.uint8)
pi_latent, v_out = self.net(state)
pi_latent, v_out = self.model(state)

if self.action_type == 'DiagGaussian':
log_std = ms.common.initializer('zeros', [1, self.action_dim], ms.float32)
dist_param = ops.concat([pi_latent, pi_latent * 0.0 + log_std], axis=-1)
mean, log_std = ops.split(dist_param, axis=-1, output_num=2)
sd = ops.exp(log_std)
action = self.dist.sample(mean, sd)
logp = self.dist.log_prob(action, mean, sd)
std = ms.common.initializer('ones', [pi_latent.shape[0], self.action_dim], ms.float32)
self.action = self.dist.sample(pi_latent, std)
self.logp = self.dist.log_prob(self.action, pi_latent, std)
elif self.action_type == 'Categorical':
logits = pi_latent
action = self.dist.sample(logits)
logp = self.dist.log_prob(action, logits)
action = action.asnumpy()
logp = logp.asnumpy()

self.action = self.dist.sample(pi_latent)
self.logp = self.dist.log_prob(self.action, pi_latent)

self.action = self.action.asnumpy()
self.logp = self.logp.asnumpy()
v_out = v_out.asnumpy()
return action, logp, v_out
return self.action, self.logp, v_out

def train(self, state, label):

self.model.set_train(True)
nbatch = state[0].shape[0]
inds = np.arange(nbatch)
loss_val = []
@@ -85,33 +85,35 @@ def train(self, state, label):
mbinds = inds[start:end]
state_ph = Tensor(state[0][mbinds])
behavior_action_ph = Tensor(label[0][mbinds])
old_logp_ph = Tensor(label[1][mbinds], ms.float32)
adv_ph = Tensor(label[2][mbinds], ms.float32)
old_v_ph = Tensor(label[3][mbinds], ms.float32)
target_v_ph = Tensor(label[4][mbinds], ms.float32)
loss = self.train_net(state_ph, adv_ph, old_logp_ph, behavior_action_ph, target_v_ph,
old_v_ph).asnumpy()
old_logp_ph = Tensor(label[1][mbinds], ms.float32)
adv_ph = Tensor(label[2][mbinds], ms.float32)
old_v_ph = Tensor(label[3][mbinds], ms.float32)
target_v_ph = Tensor(label[4][mbinds], ms.float32)

loss = self.train_net(state_ph, adv_ph, old_logp_ph, behavior_action_ph, target_v_ph, old_v_ph).asnumpy()
loss_val.append(np.mean(loss))
self.actor_var = MSVariables(self.net)
return np.mean(loss_val)

self.actor_var = MSVariables(self.model)

return np.mean(loss_val)

class MyTrainOneStepCell(nn.TrainOneStepCell):
def __init__(self, network, optimizer, max_grad_norm, sens=1.0):
super(MyTrainOneStepCell, self).__init__(network, optimizer, sens)
self.sens = sens
self.max_grad_norm = max_grad_norm

def construct(self, *inputs):
loss, grads = ops.value_and_grad(self.network, grad_position=None, weights=self.weights)(*inputs)

def construct(self, *inputs):
weights = self.weights
loss, grads = ops.value_and_grad(self.network, grad_position=None, weights=weights)(*inputs)
grads = ops.clip_by_global_norm(grads, self.max_grad_norm)
grads = self.grad_reducer(grads)
loss = ops.depend(loss, self.optimizer(grads))
return loss


class NetWithLoss(nn.Cell):
def __init__(self, net, loss_fn, dist, action_type):
def __init__(self, net, loss_fn, dist, action_type):
super(NetWithLoss, self).__init__(auto_prefix=False)
self.net = net
self._loss_fn = loss_fn
@@ -122,26 +124,24 @@ def __init__(self, net, loss_fn, dist, action_type):
self.concat = ops.Concat()
self.exp = ops.Exp()
self.log = ops.Log()

def construct(self, state_ph, adv_ph, old_logp_ph, behavior_action, target_v, old_v_ph):
def construct(self, state_ph, adv_ph, old_logp_ph, behavior_action, target_v, old_v_ph):
pi_latent, v_out = self.net(state_ph)
if self.action_type == 'DiagGaussian':
log_std = ms.common.initializer('zeros', [1, self.action_dim], ms.float32)
dist_param = self.concat([pi_latent, pi_latent * 0.0 + log_std], axis=-1)
mean, log_std = self.split(dist_param, axis=-1, output_num=2)
sd = self.exp(log_std)
ent = self.dist.entropy(sd)
action_log_prob = self.dist.log_prob(behavior_action, mean, sd)
action_log_prob = self.dist.log_prob(behavior_action, mean, sd)
else:
logits = pi_latent
ent = self.dist.entropy(logits)
action_log_prob = self.dist.log_prob(behavior_action, logits)
loss = self._loss_fn(action_log_prob, ent, adv_ph, old_logp_ph, target_v, v_out, old_v_ph)
ent = self.dist.entropy(pi_latent)
action_log_prob = self.dist.log_prob(behavior_action, pi_latent)
loss = self._loss_fn(action_log_prob, ent, adv_ph, old_logp_ph, target_v, v_out, old_v_ph)
return loss


class WithLossCell(nn.LossBase):
def __init__(self, critic_loss_coef, clip_ratio, ent_coef, val_clip):
def __init__(self, critic_loss_coef, clip_ratio, ent_coef, val_clip):
super(WithLossCell, self).__init__()
self.reduce_mean = ReduceMean(keep_dims=True)
self.critic_loss_coef = critic_loss_coef
@@ -155,8 +155,8 @@ def __init__(self, critic_loss_coef, clip_ratio, ent_coef, val_clip):
self.exp = ops.Exp()
self.square = ops.Square()
self.squeeze = ops.Squeeze()

def construct(self, action_log_prob, ent, adv, old_log_p, target_v, out_v, old_v):
def construct(self, action_log_prob, ent, adv, old_log_p, target_v, out_v, old_v):
ratio = self.exp(action_log_prob - old_log_p)

surr_loss_1 = ratio * adv
@@ -170,6 +170,7 @@ def construct(self, action_log_prob, ent, adv, old_log_p, target_v, out_v, old_v):
val_pred_clipped = old_v + ops.clip_by_value(out_v - old_v, -self.val_clip, self.val_clip)
vf_losses2 = self.square(val_pred_clipped - target_v)

critic_loss = 0.5 * self.reduce_mean(self.maximum(vf_losses1, vf_losses2))
critic_loss = 0.5 * self.reduce_mean(self.maximum(vf_losses1, vf_losses2))
loss = actor_loss + self.critic_loss_coef * critic_loss
return loss
return loss
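
Note for reviewers (not part of the diff): WithLossCell.construct implements the standard PPO objective — a clipped policy surrogate with an entropy bonus plus a clipped value loss. The middle of the method is collapsed above, so the NumPy sketch below reconstructs it from the usual PPO formulation and the visible lines; the coefficient defaults are illustrative, not the repo's configured values.

# NumPy sketch of the loss WithLossCell is expected to compute (illustrative only;
# the collapsed surr_loss_2 / actor_loss / vf_losses1 lines are assumed to follow
# the standard PPO recipe).
import numpy as np

def ppo_loss(logp, old_logp, adv, v_out, old_v, target_v, ent,
             clip_ratio=0.2, val_clip=10.0, ent_coef=0.01, critic_loss_coef=0.5):
    ratio = np.exp(logp - old_logp)                                   # importance ratio
    surr_loss_1 = ratio * adv
    surr_loss_2 = np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
    actor_loss = -np.mean(np.minimum(surr_loss_1, surr_loss_2)) - ent_coef * np.mean(ent)
    vf_losses1 = np.square(v_out - target_v)                          # unclipped value error
    val_pred_clipped = old_v + np.clip(v_out - old_v, -val_clip, val_clip)
    vf_losses2 = np.square(val_pred_clipped - target_v)               # clipped value error
    critic_loss = 0.5 * np.mean(np.maximum(vf_losses1, vf_losses2))
    return actor_loss + critic_loss_coef * critic_loss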