Dian xt ms #29

Open
wants to merge 27 commits into base: master
Changes from 1 commit
modify ppo model
AmiyaSX committed Mar 25, 2023
commit f4012921b8e636f8959437e1ebc4df451e20d6f2
6 changes: 4 additions & 2 deletions xt/model/model_utils_ms.py
@@ -2,11 +2,9 @@

import numpy as np
from xt.model.tf_compat import K, Conv2D, Input, Lambda, Flatten, Model, Dense, Concatenate, tf
from xt.model.tf_utils import gelu, norm_initializer

from xt.model.ms_compat import ms, SequentialCell, Dense, Conv2d, Flatten, get_activation, Cell
from mindspore._checkparam import twice
from mindspore import nn, ops

ACTIVATION_MAP_MS = {
'sigmoid': 'sigmoid',
@@ -39,6 +37,8 @@ def __init__(self, state_dim, act_dim, hidden_sizes, activation):
self.dense_out = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")

def construct(self, x):
if x.dtype == ms.float64:
x = x.astype(ms.float32)
pi_latent = self.dense_layer_pi(x)
pi_latent = self.dense_pi(pi_latent)
out_value = self.dense_layer_v(x)
@@ -57,6 +57,8 @@ def __init__(self, state_dim, act_dim, hidden_sizes, activation):
self.dense_out = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")

def construct(self, x):
if x.dtype == ms.float64:
x = x.astype(ms.float32)
share = self.dense_layer_share(x)
pi_latent = self.dense_pi(share)
out_value = self.dense_out(share)
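The change to both construct methods is a dtype guard: float64 inputs are cast down to float32 before the dense layers, since NumPy arrays default to float64 while the MindSpore layers here are built in float32. A minimal sketch of the same guard in isolation, with an illustrative observation batch (the shape and values are not from this PR):

import numpy as np
import mindspore as ms
from mindspore import Tensor

# NumPy creates float64 by default, so the wrapped Tensor comes out as ms.float64.
obs = Tensor(np.random.randn(4, 17))
# Same guard the commit adds at the top of construct(): cast down to float32.
if obs.dtype == ms.float64:
    obs = obs.astype(ms.float32)
print(obs.dtype)  # Float32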
77 changes: 44 additions & 33 deletions xt/model/ms_dist.py
@@ -2,13 +2,13 @@
import numpy as np
from xt.model.ms_compat import ms, Cast, ReduceSum, ReduceMax, SoftmaxCrossEntropyWithLogits, Tensor
from mindspore import ops
import mindspore.nn.probability.distribution as msd
from mindspore import Parameter,ms_function,ms_class
from mindspore.common.initializer import initializer
from mindspore import Parameter, ms_function, ms_class


@ms_class
class ActionDist:
"""Build base action distribution."""

def init_by_param(self, param):
raise NotImplementedError

@@ -32,12 +32,12 @@ def shape(self):
def __getitem__(self, idx):
return self.flatparam()[idx]

def neglog_prob(self, x,logits):
def neglog_prob(self, x, logits):
raise NotImplementedError

def log_prob(self, x,logits):
def log_prob(self, x, logits):
"""Calculate the log-likelihood."""
return -self.neglog_prob(x,logits)
return -self.neglog_prob(x, logits)

def mode(self):
raise NotImplementedError
@@ -48,6 +48,7 @@ def entropy(self):
def kl(self, other):
raise NotImplementedError


class DiagGaussianDist(ActionDist):
"""Build Diagonal Gaussian distribution, each vector represented one distribution."""

@@ -57,8 +58,9 @@ def __init__(self, size):
self.log = ops.Log()
self.shape = ops.Shape()
self.square = ops.Square()
self.Normal = ops.StandardNormal()
self.normal = ops.StandardNormal()
self.cast = Cast()

def init_by_param(self, param):
self.param = param
self.mean, self.log_std = ops.split(self.param, axis=-1, output_num=2)
@@ -70,44 +72,55 @@ def flatparam(self):
def sample_dtype(self):
return ms.float32

def neglog_prob(self, x,mean,sd):
log_sd = self.log(sd)
neglog_prob= 0.5 * self.log(2.0 * np.pi) * self.cast()((self.shape(x)[-1]), ms.float32) + \
def log_prob(self, x, mean, sd = None):
if sd is not None:
log_sd = self.log(sd)
neglog_prob = 0.5 * self.log(2.0 * np.pi) * self.cast((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1) + \
self.reduce_sum(log_sd, axis=-1)
return neglog_prob
else:
neglog_prob = 0.5 * self.log(2.0 * np.pi) * self.cast((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1)
return -neglog_prob

def mode(self):
return self.mean

def entropy(self,sd):
log_sd = self.log(sd)
return self.reduce_sum(log_sd + 0.5 * (self.log(2.0 * np.pi) + 1.0), axis=-1)
def entropy(self, mean, sd = None):
if sd is not None:
log_sd = self.log(sd)
return self.reduce_sum(log_sd + 0.5 * (self.log(2.0 * np.pi) + 1.0), axis=-1)
return 0.5 * (self.log(2.0 * np.pi) + 1.0)

def kl(self, other):
assert isinstance(other, DiagGaussianDist), 'Distribution type not match.'
assert isinstance(
other, DiagGaussianDist), 'Distribution type not match.'
reduce_sum = ReduceSum(keep_dims=True)
return reduce_sum(
(ops.square(self.std) + ops.square(self.mean - other.mean)) / (2.0 * ops.square(other.std)) +
other.log_std - self.log_std - 0.5,
axis=-1)

def sample(self, mean,sd):
return mean + sd * self.normal(self.shape(mean), dtype=ms.float32)
def sample(self, mean, sd = None):
if sd is not None:
return mean + sd * self.normal(self.shape(mean), dtype=ms.float32)
return mean + self.normal(self.shape(mean), dtype=ms.float32)

class CategoricalDist(ActionDist):

def __init__(self, size):
self.size = size
self.OneHot = ops.OneHot()
self.oneHot = ops.OneHot()
self.softmax_cross = ops.SoftmaxCrossEntropyWithLogits()
self.reduce_max = ReduceMax(keep_dims=True)
self.reduce_sum = ReduceSum(keep_dims=True)
self.exp = ops.Exp()
self.log = ops.Log()
self.expand_dims = ops.ExpandDims()
self.random_categorical = ops.RandomCategorical(dtype=ms.int32)
self.on_value, self.off_value = Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)
self.on_value, self.off_value = Tensor(
1.0, ms.float32), Tensor(0.0, ms.float32)

def init_by_param(self, logits):
self.logits = logits

@@ -117,14 +130,12 @@ def flatparam(self):
def sample_dtype(self):
return ms.int32

def neglog_prob(self, x,logits):
x = self.OneHot(x , self.size, self.on_value, self.off_value)
loss, dlogits = self.softmax_cross(logits, x)
return self.expand_dims(loss, -1)


def entropy(self,logits):
def log_prob(self, x, logits):
x = self.oneHot(x, self.size, self.on_value, self.off_value)
loss, _ = self.softmax_cross(logits, x)
return -self.expand_dims(loss, -1)

def entropy(self, logits):
rescaled_logits = logits - self.reduce_max(logits, -1)
exp_logits = self.exp(rescaled_logits)

@@ -133,23 +144,23 @@ def entropy(self,logits):
return self.reduce_sum(p * (self.log(z) - rescaled_logits), -1)

def kl(self, other):
assert isinstance(other, CategoricalDist), 'Distribution type not match.'
assert isinstance(
other, CategoricalDist), 'Distribution type not match.'
reduce_max = ReduceMax(keep_dims=True)
reduce_sum = ReduceSum(keep_dims=True)
rescaled_logits_self = self.logits - reduce_max(self.logits, axis=-1)
rescaled_logits_other = other.logits - reduce_max(other.logits, axis=-1)
rescaled_logits_other = other.logits - \
reduce_max(other.logits, axis=-1)
exp_logits_self = ops.exp(rescaled_logits_self)
exp_logits_other = ops.exp(rescaled_logits_other)
z_self = reduce_sum(exp_logits_self, axis=-1)
z_other = reduce_sum(exp_logits_other, axis=-1)
p = exp_logits_self / z_self
return reduce_sum(p * (rescaled_logits_self - ops.log(z_self) - rescaled_logits_other + ops.log(z_other)),
axis=-1)
axis=-1)

def sample(self,logits):
# u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
# return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1, output_type=tf.int32)
return self.random_categorical(logits,1,0).squeeze(-1)
def sample(self, logits):
return self.random_categorical(logits, 1, 0).squeeze(-1)


def make_dist(ac_type, ac_dim):
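The ms_dist.py refactor moves the distribution parameters out of init_by_param and into explicit call arguments: sample, log_prob, and entropy now take the mean/standard deviation (or logits) directly. A hedged usage sketch under those new signatures; the batch size, action dimension, and values are illustrative only, and it assumes the module path xt.model.ms_dist from this PR:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from xt.model.ms_dist import DiagGaussianDist, CategoricalDist

act_dim = 3
gauss = DiagGaussianDist(act_dim)

# Policy-head outputs for a batch of two states (illustrative values).
mean = Tensor(np.zeros((2, act_dim), dtype=np.float32))
sd = Tensor(np.ones((2, act_dim), dtype=np.float32))

action = gauss.sample(mean, sd)           # mean + sd * standard-normal noise
logp = gauss.log_prob(action, mean, sd)   # log-likelihood of the sampled action
ent = gauss.entropy(mean, sd)             # entropy of the diagonal Gaussian

# Discrete case: logits are passed explicitly in the same style.
cat = CategoricalDist(4)
logits = Tensor(np.zeros((2, 4), dtype=np.float32))
act = cat.sample(logits)                  # RandomCategorical draw, shape (2,)
logp_c = cat.log_prob(act, logits)        # negative softmax cross-entropy, shape (2, 1)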