Dian xt ms #29

Open · wants to merge 27 commits into base: master

Changes from 1 commit
MuZero Model: port to MindSpore
AmiyaSX committed May 4, 2023
commit 1deed68708962fd6ee4a394d49bd002a4df52b41
2 changes: 1 addition & 1 deletion examples/breakout_ppo_ms.yaml
@@ -34,7 +34,7 @@ model_para:
hidden_sizes: [256]

env_num: 10
speedup: False


benchmark:
log_interval_to_train: 10
2 changes: 1 addition & 1 deletion examples/cartpole_ppo_ms.yaml
@@ -39,7 +39,7 @@ model_para:
hidden_sizes: [64, 64]

env_num: 10
speedup: False


benchmark:
log_interval_to_train: 20
1 change: 1 addition & 0 deletions examples/muzero/muzero_breakout_ms.yaml
@@ -34,3 +34,4 @@ model_para:
}

env_num: 50
speedup: False
7 changes: 4 additions & 3 deletions xt/model/model_ms.py
@@ -23,9 +23,7 @@
import glob
import mindspore as ms
from xt.model.model import XTModel
from xt.model.ms_utils import MSVariables

os.environ["KERAS_BACKEND"] = "mindspore"


class XTModel_MS(XTModel):
@@ -85,7 +83,10 @@ def load_model(self, model_name):

def check_keep_model(model_path, keep_num):
"""Check model saved count under path."""
target_file = glob.glob(os.path.join(model_path, "actor*".format(model_path)))
target_file = glob.glob(
os.path.join(
model_path,
"actor*".format(model_path)))
if len(target_file) > keep_num:
to_rm_model = sorted(target_file, reverse=True)[keep_num:]
for item in to_rm_model:
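A minimal standalone sketch of this pruning helper (my own reconstruction: the loop body is truncated in the diff, and the trailing `.format(model_path)` on `"actor*"` is a no-op since the pattern has no placeholder):

```python
import glob
import os

def prune_checkpoints(model_path, keep_num):
    """Keep only the keep_num newest actor checkpoints, by name sort."""
    files = glob.glob(os.path.join(model_path, "actor*"))
    if len(files) > keep_num:
        # reverse name sort puts the newest checkpoints first;
        # the truncated loop is assumed to delete the remainder
        for stale in sorted(files, reverse=True)[keep_num:]:
            os.remove(stale)
```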
100 changes: 45 additions & 55 deletions xt/model/model_utils_ms.py
@@ -1,9 +1,9 @@
"""Retain model utils."""

import numpy as np
from xt.model.tf_compat import K, Conv2D, Input, Lambda, Flatten, Model, Dense, Concatenate, tf

from xt.model.ms_compat import ms, SequentialCell, Dense, Conv2d, Flatten, get_activation, Cell
from xt.model.ms_compat import ms, SequentialCell, Dense, Conv2d, Flatten,\
get_activation, Cell
from mindspore._checkparam import twice

ACTIVATION_MAP_MS = {
@@ -15,29 +15,33 @@
'leakyrelu': 'leakyrelu',
'elu': 'elu',
'selu': 'seLU',
# 'swish': tf.nn.swish,
'hswish': 'hswish',  # FIXME: MindSpore has no swish, only h-swish
'gelu': 'gelu'
}


def cal_shape(input_shape, kernel_size, stride):
kernel_size = twice(kernel_size)
stride = twice(stride)
return tuple(
(v - kernel_size[i]) // stride[i] + 1 for i, v in enumerate(input_shape)
)
(v - kernel_size[i]) // stride[i] + 1 for i,
v in enumerate(input_shape))
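A quick sanity check (my own example, not part of the commit): cal_shape implements the valid-convolution size rule (v - k) // s + 1 per axis.

```python
# An 84x84 feature map through an 8x8 kernel at stride 4:
# (84 - 8) // 4 + 1 = 20 on each axis.
assert cal_shape((84, 84), 8, 4) == (20, 20)
```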


class MlpBackbone(Cell):
def __init__(self, state_dim, act_dim, hidden_sizes, activation):
super().__init__()
self.dense_layer_pi = bulid_mlp_layers_ms(state_dim[-1], hidden_sizes, activation)
self.dense_pi = Dense(hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_layer_v = bulid_mlp_layers_ms(state_dim[-1], hidden_sizes, activation)
self.dense_out = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")
self.dense_layer_pi = bulid_mlp_layers_ms(
state_dim[-1], hidden_sizes, activation)
self.dense_pi = Dense(
hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_layer_v = bulid_mlp_layers_ms(
state_dim[-1], hidden_sizes, activation)
self.dense_out = Dense(
hidden_sizes[-1], 1, weight_init="XavierUniform")

def construct(self, x):
if(x.dtype==ms.float64):
if x.dtype == ms.float64:
x = x.astype(ms.float32)
pi_latent = self.dense_layer_pi(x)
pi_latent = self.dense_pi(pi_latent)
@@ -53,11 +57,13 @@ def __init__(self, state_dim, act_dim, hidden_sizes, activation):
self.dense_layer_share = bulid_mlp_layers_ms(
state_dim[-1], hidden_sizes, activation
)
self.dense_pi = Dense(hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_out = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")
self.dense_pi = Dense(
hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_out = Dense(
hidden_sizes[-1], 1, weight_init="XavierUniform")

def construct(self, x):
if(x.dtype==ms.float64):
if x.dtype == ms.float64:
x = x.astype(ms.float32)
share = self.dense_layer_share(x)
pi_latent = self.dense_pi(share)
@@ -78,16 +84,20 @@ def __init__(
):
super().__init__()
self.dtype = dtype
self.conv_layer_pi = build_conv_layers_ms(state_dim[-1], filter_arches, activation)
self.conv_layer_pi = build_conv_layers_ms(
state_dim[-1], filter_arches, activation)
self.flatten_layer = Flatten()
height, width = state_dim[-3], state_dim[-2]
filters = 1
for filters, kernel_size, strides in filter_arches:
height, width = cal_shape((height, width), kernel_size, strides)
dim = height * width * filters
self.dense_layer_pi = bulid_mlp_layers_ms(dim, hidden_sizes, activation)
self.dense_pi = Dense(hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.conv_layer_v = build_conv_layers_ms(state_dim[-1], filter_arches, activation)
self.dense_layer_pi = bulid_mlp_layers_ms(
dim, hidden_sizes, activation)
self.dense_pi = Dense(
hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.conv_layer_v = build_conv_layers_ms(
state_dim[-1], filter_arches, activation)
self.dense_layer_v = bulid_mlp_layers_ms(dim, hidden_sizes, activation)
self.dense_v = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")

@@ -128,9 +138,12 @@ def __init__(
for filters, kernel_size, strides in filter_arches:
height, width = cal_shape((height, width), kernel_size, strides)
dim = height * width * filters
self.dense_layer_share = bulid_mlp_layers_ms(dim, hidden_sizes, activation)
self.dense_pi = Dense(hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_layer_share = bulid_mlp_layers_ms(
dim, hidden_sizes, activation)
self.dense_pi = Dense(
hidden_sizes[-1], act_dim, weight_init="XavierUniform")
self.dense_v = Dense(hidden_sizes[-1], 1, weight_init="XavierUniform")

def construct(self, x):
x = x.transpose((0, 3, 1, 2))
if self.dtype == "uint8":
@@ -178,7 +191,8 @@ def get_cnn_backbone_ms(
"""Get CNN backbone."""
if dtype != "uint8" and dtype != "float32":
raise ValueError(
'dtype: {} not supported automatically, please implement it yourself'.format(
'dtype: {} not supported automatically, '
'please implement it yourself'.format(
dtype
)
)
@@ -259,11 +273,15 @@ def get_default_filters_ms(shape):
"""Get default model set for atari environments."""
shape = list(shape)
if len(shape) != 3:
raise ValueError('Without default architecture for obs shape {}'.format(shape))
raise ValueError(
'Without default architecture for obs shape {}'.format(shape))
# (out_size, kernel, stride)
filters_84x84 = [[32, (8, 8), (4, 4)], [32, (4, 4), (2, 2)], [64, (3, 3), (1, 1)]]
filters_42x42 = [[32, (4, 4), (2, 2)], [32, (4, 4), (2, 2)], [64, (3, 3), (1, 1)]]
filters_15x15 = [[32, (5, 5), (1, 1)], [64, (3, 3), (1, 1)], [64, (3, 3), (1, 1)]]
filters_84x84 = [[32, (8, 8), (4, 4)], [32, (4, 4), (2, 2)], [
64, (3, 3), (1, 1)]]
filters_42x42 = [[32, (4, 4), (2, 2)], [32, (4, 4), (2, 2)], [
64, (3, 3), (1, 1)]]
filters_15x15 = [[32, (5, 5), (1, 1)], [64, (3, 3), (1, 1)], [
64, (3, 3), (1, 1)]]
if shape[:2] == [84, 84]:
return filters_84x84
elif shape[:2] == [42, 42]:
@@ -282,7 +300,8 @@ def get_default_filters_ms(shape):
filter_h, stride_h, flat_flag_h = _infer_stride_and_kernel_ms(
input_h, flat_flag_h
)
filters.append((num_filters, (filter_w, filter_h), (stride_w, stride_h)))
filters.append(
(num_filters, (filter_w, filter_h), (stride_w, stride_h)))
num_filters *= 2
input_w = input_w // stride_w
input_h = input_h // stride_h
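Worked example (my own check): chaining cal_shape over the 84x84 defaults above shrinks the map 84 → 20 → 9 → 7, so the flattened feature dimension fed to the dense layers is 7 * 7 * 64 = 3136.

```python
h = w = 84
for _filters, k, s in [[32, (8, 8), (4, 4)], [32, (4, 4), (2, 2)],
                       [64, (3, 3), (1, 1)]]:
    h, w = cal_shape((h, w), k, s)
assert (h, w) == (7, 7)  # 7 * 7 * 64 = 3136 flattened features
```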
@@ -303,36 +322,7 @@ def _infer_stride_and_kernel_ms(size, flat_flag):
return 2 * stride + 1, stride, False


def _infer_same_padding_size_ms(old_size, stride):
new_size = old_size // stride
if new_size * stride == old_size:
return new_size
else:
return new_size + 1


def layer_function_ms(x):
"""Normalize data."""
return x.astype(ms.float32) / 255.0


def state_transform_ms(x, mean=1e-5, std=255., input_dtype="float32"):
"""Normalize data."""
if input_dtype in ("float32", "float", "float64"):
return x

# only cast non-float32 state
if np.abs(mean) < 1e-4:
return tf.cast(x, dtype='float32') / std
else:
return (tf.cast(x, dtype="float32") - mean) / std


def custom_norm_initializer_ms(std=0.5):
"""Perform Customize norm initializer for op."""
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)

return _initializer
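The deleted helpers above still called tf.cast and tf.constant, which is presumably why they were dropped from the MindSpore port. If an equivalent is ever needed, a sketch under the assumption that x is a mindspore.Tensor might look like:

```python
import numpy as np
import mindspore as ms

def state_transform_ms(x, mean=1e-5, std=255., input_dtype="float32"):
    """Normalize non-float observations into float32 (hypothetical port)."""
    if input_dtype in ("float32", "float", "float64"):
        return x
    # only cast non-float32 state, mirroring the deleted TF version
    if np.abs(mean) < 1e-4:
        return x.astype(ms.float32) / std
    return (x.astype(ms.float32) - mean) / std
```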
5 changes: 4 additions & 1 deletion xt/model/ms_compat.py
@@ -20,6 +20,7 @@

import sys


def import_ms_compact():
"""Import mindspore with compact behavior."""
if "mindspore" not in sys.modules:
@@ -32,6 +33,7 @@ def import_ms_compact():
else:
return sys.modules["mindspore"]


ms = import_ms_compact()


@@ -47,6 +49,7 @@ def import_ms_compact():
from mindspore.ops import Cast, MultitypeFuncGraph, ReduceSum, ReduceMax, ReduceMin, ReduceMean, Reciprocal
from mindspore import History


def loss_to_val(loss):
"""Make keras instance into value."""
if isinstance(loss, History):
@@ -57,4 +60,4 @@ def loss_to_val(loss):
DTYPE_MAP = {
"float32": ms.float32,
"float16": ms.float16,
}
}
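A usage illustration (my own, not part of the commit): DTYPE_MAP translates the dtype strings carried in YAML configs into mindspore dtypes.

```python
from xt.model.ms_compat import DTYPE_MAP, ms

t = ms.Tensor([1.0, 2.0], dtype=DTYPE_MAP["float16"])
assert t.dtype == ms.float16  # unsupported strings raise KeyError
```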
41 changes: 23 additions & 18 deletions xt/model/ms_dist.py
@@ -1,8 +1,8 @@
"""Action distribution with mindspore"""
import numpy as np
from xt.model.ms_compat import ms, Cast, ReduceSum, ReduceMax, SoftmaxCrossEntropyWithLogits, Tensor
from xt.model.ms_compat import ms, Cast, ReduceSum, ReduceMax, Tensor
from mindspore import ops
from mindspore import Parameter, ms_function, ms_class
from mindspore import ms_class


@ms_class
@@ -72,40 +72,42 @@ def flatparam(self):
def sample_dtype(self):
return ms.float32

def log_prob(self, x, mean, sd = None):
def log_prob(self, x, mean, sd=None):
if sd is not None:
log_sd = self.log(sd)
neglog_prob = 0.5 * self.log(2.0 * np.pi) * self.cast((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1) + \
self.reduce_sum(log_sd, axis=-1)
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1) + \
self.reduce_sum(log_sd, axis=-1)
else:
neglog_prob = 0.5 * self.log(2.0 * np.pi) * self.cast((self.shape(x)[-1]), ms.float32) + \
0.5 * self.reduce_sum(self.square((x - mean) / sd), axis=-1)
neglog_prob = 0.5 * self.log(2.0 * np.pi) * self.cast(
    (self.shape(x)[-1]), ms.float32) + \
    0.5 * self.reduce_sum(self.square(x - mean), axis=-1)
return -neglog_prob
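A NumPy cross-check of the diagonal-Gaussian log-probability above (my own verification; note the else branch is the unit-variance case once the stray division by the absent sd is dropped):

```python
import numpy as np

def np_log_prob(x, mean, sd):
    d = x.shape[-1]
    return -(0.5 * np.log(2.0 * np.pi) * d
             + 0.5 * np.sum(((x - mean) / sd) ** 2, axis=-1)
             + np.sum(np.log(sd), axis=-1))

x = np.array([[0.5, -0.2]])
mean, sd = np.zeros((1, 2)), np.ones((1, 2))
# For a standard normal this reduces to -d/2 * log(2*pi) - ||x||^2 / 2.
expected = -np.log(2.0 * np.pi) - 0.5 * (0.5 ** 2 + 0.2 ** 2)
assert np.allclose(np_log_prob(x, mean, sd), expected)
```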

def mode(self):
return self.mean

def entropy(self, mean, sd = None):
def entropy(self, mean, sd=None):
if sd is not None:
log_sd = self.log(sd)
return self.reduce_sum(log_sd + 0.5 * (self.log(2.0 * np.pi) + 1.0), axis=-1)
return self.reduce_sum(
log_sd + 0.5 * (self.log(2.0 * np.pi) + 1.0), axis=-1)
return 0.5 * (self.log(2.0 * np.pi) + 1.0)

def kl(self, other):
assert isinstance(
other, DiagGaussianDist), 'Distribution type not match.'
reduce_sum = ReduceSum(keep_dims=True)
return reduce_sum(
(ops.square(self.std) + ops.square(self.mean - other.mean)) / (2.0 * ops.square(other.std)) +
other.log_std - self.log_std - 0.5,
axis=-1)
return reduce_sum((self.square(self.std) +
self.square(self.mean - other.mean)) /
(2.0 * self.square(other.std)) +
other.log_std - self.log_std - 0.5, axis=-1)
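For reference, the closed form this implements (a standard result for diagonal Gaussians; other.log_std - self.log_std supplies the log of the sigma ratio term):

$$
D_{\mathrm{KL}}(p \,\|\, q) = \sum_i \left[ \frac{\sigma_{p,i}^2 + (\mu_{p,i} - \mu_{q,i})^2}{2\sigma_{q,i}^2} + \log \sigma_{q,i} - \log \sigma_{p,i} - \frac{1}{2} \right]
$$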

def sample(self, mean, sd = None):
def sample(self, mean, sd=None):
if sd is not None:
return mean + sd * self.normal(self.shape(mean), dtype=ms.float32)
return mean + self.normal(self.shape(mean), dtype=ms.float32)


class CategoricalDist(ActionDist):

def __init__(self, size):
@@ -151,13 +153,16 @@ def kl(self, other):
rescaled_logits_self = self.logits - reduce_max(self.logits, axis=-1)
rescaled_logits_other = other.logits - \
reduce_max(other.logits, axis=-1)
exp_logits_self = ops.exp(rescaled_logits_self)
exp_logits_other = ops.exp(rescaled_logits_other)
exp_logits_self = self.exp(rescaled_logits_self)
exp_logits_other = self.exp(rescaled_logits_other)
z_self = reduce_sum(exp_logits_self, axis=-1)
z_other = reduce_sum(exp_logits_other, axis=-1)
p = exp_logits_self / z_self
return reduce_sum(p * (rescaled_logits_self - ops.log(z_self) - rescaled_logits_other + ops.log(z_other)),
axis=-1)
return reduce_sum(p *
(rescaled_logits_self -
self.log(z_self) -
rescaled_logits_other +
self.log(z_other)), axis=-1)
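In the same spirit, the categorical branch computes the textbook discrete KL, with the max-subtracted ("rescaled") logits guarding exp against overflow; rescaled_logits - log(z) is exactly log softmax:

$$
D_{\mathrm{KL}}(p \,\|\, q) = \sum_i p_i \left( \log p_i - \log q_i \right), \qquad p = \mathrm{softmax}(\text{logits}_{\text{self}})
$$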

def sample(self, logits):
return self.random_categorical(logits, 1, 0).squeeze(-1)
18 changes: 12 additions & 6 deletions xt/model/ms_utils.py
@@ -17,7 +17,8 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""Create tf utils for assign weights between learner and actor and model utils for universal usage."""
"""Create tf utils for assign weights between learner and actor\
and model utils for universal usage."""

import numpy as np
from mindspore import nn
@@ -31,19 +32,24 @@ def __init__(self, net: nn.Cell) -> None:
self.net = net

def get_weights(self) -> OrderedDict:
_weights = OrderedDict((par_name, par.data.asnumpy()) for par_name, par in self.net.parameters_and_names())
_weights = OrderedDict((par_name, par.data.asnumpy())
for par_name, par in
self.net.parameters_and_names())
return _weights

def save_weights(self, save_name: str):
_weights = OrderedDict((par_name, par.data.asnumpy()) for par_name, par in self.net.parameters_and_names())
_weights = OrderedDict((par_name, par.data.asnumpy())
for par_name, par in
self.net.parameters_and_names())
np.savez(save_name, **_weights)

def set_weights(self, to_weights):
for _, param in self.net.parameters_and_names():
if param.name in to_weights:
new_param_data = ms.Tensor(copy.deepcopy(to_weights[param.name]))
new_param_data = ms.Tensor(
copy.deepcopy(to_weights[param.name]))
param.set_data(new_param_data, param.sliced)

return

def read_weights(weight_file: str):
"""Read weights with numpy.npz"""
np_file = np.load(weight_file)
@@ -59,4 +65,4 @@ def save_weight_with_checkpoint(self, filename: str):

def load_weight_with_checkpoint(self, filename: str):
param_dict = ms.load_checkpoint(filename, self.net)
param_not_load = ms.load_param_into_net(self.net, param_dict)
param_not_load = ms.load_param_into_net(self.net, param_dict)
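A hypothetical round-trip with MSVariables (learner_net and actor_net are assumed nn.Cell instances; only names shown in the diff above are used):

```python
from xt.model.ms_utils import MSVariables

learner = MSVariables(learner_net)
actor = MSVariables(actor_net)

weights = learner.get_weights()            # OrderedDict of name -> np.ndarray
actor.set_weights(weights)                 # copies parameters with matching names
learner.save_weights("learner_ckpt.npz")   # persisted via np.savez
```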