[ISSUE-187] Fixed misspelling in 'deterministic' (Denys88#188)
* fixed misspelling

* fixed player

Co-authored-by: Denys Makoviichuk <[email protected]>
Denys88 and DenSumy authored Jul 4, 2022
1 parent 4f1f4e8 commit 758ac4f
Showing 5 changed files with 28 additions and 22 deletions.
20 changes: 10 additions & 10 deletions rl_games/algos_torch/players.py
@@ -41,7 +41,7 @@ def __init__(self, params):
         self.model.eval()
         self.is_rnn = self.model.is_rnn()
 
-    def get_action(self, obs, is_determenistic = False):
+    def get_action(self, obs, is_deterministic = False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
         obs = self._preproc_obs(obs)
@@ -56,7 +56,7 @@ def get_action(self, obs, is_determenistic = False):
         mu = res_dict['mus']
         action = res_dict['actions']
         self.states = res_dict['rnn_states']
-        if is_determenistic:
+        if is_deterministic:
             current_action = mu
         else:
             current_action = action
@@ -106,7 +106,7 @@ def __init__(self, params):
         self.model.eval()
         self.is_rnn = self.model.is_rnn()
 
-    def get_masked_action(self, obs, action_masks, is_determenistic = True):
+    def get_masked_action(self, obs, action_masks, is_deterministic = True):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
         obs = self._preproc_obs(obs)
@@ -126,18 +126,18 @@ def get_masked_action(self, obs, action_masks, is_determenistic = True):
         action = res_dict['actions']
         self.states = res_dict['rnn_states']
         if self.is_multi_discrete:
-            if is_determenistic:
+            if is_deterministic:
                 action = [torch.argmax(logit.detach(), axis=-1).squeeze() for logit in logits]
                 return torch.stack(action,dim=-1)
             else:
                 return action.squeeze().detach()
         else:
-            if is_determenistic:
+            if is_deterministic:
                 return torch.argmax(logits.detach(), axis=-1).squeeze()
             else:
                 return action.squeeze().detach()
 
-    def get_action(self, obs, is_determenistic = False):
+    def get_action(self, obs, is_deterministic = False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
         obs = self._preproc_obs(obs)
@@ -155,13 +155,13 @@ def get_action(self, obs, is_determenistic = False):
         action = res_dict['actions']
         self.states = res_dict['rnn_states']
         if self.is_multi_discrete:
-            if is_determenistic:
+            if is_deterministic:
                 action = [torch.argmax(logit.detach(), axis=1).squeeze() for logit in logits]
                 return torch.stack(action,dim=-1)
             else:
                 return action.squeeze().detach()
         else:
-            if is_determenistic:
+            if is_deterministic:
                 return torch.argmax(logits.detach(), axis=-1).squeeze()
             else:
                 return action.squeeze().detach()
@@ -210,11 +210,11 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
-    def get_action(self, obs, is_determenistic=False):
+    def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
         dist = self.model.actor(obs)
-        actions = dist.sample() if is_determenistic else dist.mean
+        actions = dist.sample() if is_deterministic else dist.mean
         actions = actions.clamp(*self.action_range).to(self.device)
         if self.has_batch_dimension == False:
             actions = torch.squeeze(actions.detach())
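Across the players.py hunks above, the renamed flag only selects between two tensors the model already produces: the distribution mean (mu) when deterministic, a sampled action otherwise. Below is a self-contained sketch of that rule; it is not rl_games code, and the Normal policy with made-up parameters merely stands in for the model's outputs.

import torch

# Stand-ins for the policy head outputs; in players.py these come from res_dict.
mu = torch.zeros(3)
sigma = torch.ones(3)
dist = torch.distributions.Normal(mu, sigma)

def select_action(is_deterministic: bool) -> torch.Tensor:
    # Deterministic evaluation returns the mean; otherwise sample from the policy.
    return mu if is_deterministic else dist.sample()

print(select_action(True))   # always the mean
print(select_action(False))  # a fresh sample on each call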
15 changes: 9 additions & 6 deletions rl_games/common/player.py
@@ -42,7 +42,10 @@ def __init__(self, params):
         self.device_name = self.config.get('device_name', 'cuda')
         self.render_env = self.player_config.get('render', False)
         self.games_num = self.player_config.get('games_num', 2000)
-        self.is_determenistic = self.player_config.get('determenistic', True)
+        if 'deterministic' in self.player_config:
+            self.is_deterministic = self.player_config['deterministic']
+        else:
+            self.is_deterministic = self.player_config.get('determenistic', True)
         self.n_game_life = self.player_config.get('n_game_life', 1)
         self.print_stats = self.player_config.get('print_stats', True)
         self.render_sleep = self.player_config.get('render_sleep', 0.002)
@@ -143,10 +146,10 @@ def set_weights(self, weights):
     def create_env(self):
         return env_configurations.configurations[self.env_name]['env_creator'](**self.env_config)
 
-    def get_action(self, obs, is_determenistic=False):
+    def get_action(self, obs, is_deterministic=False):
         raise NotImplementedError('step')
 
-    def get_masked_action(self, obs, mask, is_determenistic=False):
+    def get_masked_action(self, obs, mask, is_deterministic=False):
         raise NotImplementedError('step')
 
     def reset(self):
@@ -162,7 +165,7 @@ def run(self):
         n_games = self.games_num
         render = self.render_env
         n_game_life = self.n_game_life
-        is_determenistic = self.is_determenistic
+        is_deterministic = self.is_deterministic
         sum_rewards = 0
         sum_steps = 0
         sum_game_res = 0
@@ -203,9 +206,9 @@ def run(self):
                 if has_masks:
                     masks = self.env.get_action_mask()
                     action = self.get_masked_action(
-                        obses, masks, is_determenistic)
+                        obses, masks, is_deterministic)
                 else:
-                    action = self.get_action(obses, is_determenistic)
+                    action = self.get_action(obses, is_deterministic)
 
                 obses, r, done, info = self.env_step(self.env, action)
                 cr += r
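The first player.py hunk above keeps old configs working: the corrected 'deterministic' key is read first, and the legacy 'determenistic' key is still honored as a fallback. The following is a minimal standalone sketch of that lookup, using made-up example configs rather than anything from the repository.

def resolve_deterministic(player_config):
    # Prefer the corrected key, fall back to the legacy misspelling, default to True.
    if 'deterministic' in player_config:
        return player_config['deterministic']
    return player_config.get('determenistic', True)

print(resolve_deterministic({'deterministic': False}))  # False: new key wins
print(resolve_deterministic({'determenistic': False}))  # False: legacy key still read
print(resolve_deterministic({}))                        # True: default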
5 changes: 4 additions & 1 deletion rl_games/configs/ppo_pendulum_torch.yaml
@@ -30,7 +30,7 @@ params:
         #scale: 0.001
 
   config:
-    env_name: Pendulum-v0
+    env_name: openai_gym
     reward_shaper:
       scale_value: 0.01
     normalize_advantage: True
@@ -56,3 +56,6 @@ params:
 
     normalize_input: False
     bounds_loss_coef: 0
+
+    env_config:
+      name: Pendulum-v1
6 changes: 3 additions & 3 deletions rl_games/envs/connect4_selfplay.py
@@ -10,7 +10,7 @@ class ConnectFourSelfPlay(gym.Env):
     def __init__(self, name="connect_four_v0", **kwargs):
         gym.Env.__init__(self)
         self.name = name
-        self.is_determenistic = kwargs.pop('is_determenistic', False)
+        self.is_determenistic = kwargs.pop('is_deterministic', False)
         self.is_human = kwargs.pop('is_human', False)
         self.random_agent = kwargs.pop('random_agent', False)
         self.config_path = kwargs.pop('config_path')
@@ -62,7 +62,7 @@ def reset(self):
             if self.random_agent:
                 opponent_action = np.random.choice(ids, 1)[0]
             else:
-                opponent_action = self.agent.get_masked_action(op_obs, mask, self.is_determenistic).item()
+                opponent_action = self.agent.get_masked_action(op_obs, mask, self.is_deterministic).item()
 
 
             obs, _, _, _ = self.env_step(opponent_action)
@@ -107,7 +107,7 @@ def step(self, action):
         if self.random_agent:
             opponent_action = np.random.choice(ids, 1)[0]
         else:
-            opponent_action = self.agent.get_masked_action(op_obs, mask, self.is_determenistic).item()
+            opponent_action = self.agent.get_masked_action(op_obs, mask, self.is_deterministic).item()
         obs, reward, done,_ = self.env_step(opponent_action)
         if done:
             if reward == -1:
4 changes: 2 additions & 2 deletions rl_games/envs/slimevolley_selfplay.py
@@ -9,7 +9,7 @@ class SlimeVolleySelfplay(gym.Env):
     def __init__(self, name="SlimeVolleyDiscrete-v0", **kwargs):
         gym.Env.__init__(self)
         self.name = name
-        self.is_determenistic = kwargs.pop('is_determenistic', False)
+        self.is_deterministic = kwargs.pop('is_deterministic', False)
         self.config_path = kwargs.pop('config_path')
         self.agent = None
         self.pos_scale = 1
@@ -45,7 +45,7 @@ def create_agent(self, config='rl_games/configs/ma/ppo_slime_self_play.yaml'):
     def step(self, action):
         op_obs = self.agent.obs_to_torch(self.opponent_obs)
 
-        opponent_action = self.agent.get_action(op_obs, self.is_determenistic).item()
+        opponent_action = self.agent.get_action(op_obs, self.is_deterministic).item()
        obs, reward, done, info = self.env.step(action, opponent_action)
         self.sum_rewards += reward
         if reward < 0:
