
Commit

update ppo_impala
wisnunugroho21 committed Nov 4, 2020
1 parent 925e305 commit 6faa2ab
Showing 1 changed file with 20 additions and 16 deletions.
PPO_Impala/pytorch/ppo_impala.py: 36 changes (20 additions & 16 deletions)
@@ -57,7 +57,7 @@ class Memory(Dataset):
     def __init__(self):
         self.states = []
         self.actions = []
-        self.action_means = []
+        self.action_means = []
         self.rewards = []
         self.dones = []
         self.next_states = []
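
The hunk above only touches the rollout buffer's fields. For orientation, a minimal sketch of how a Memory(Dataset) over these parallel lists is usually exposed to a DataLoader follows; the actual __len__ and __getitem__ bodies are outside this diff, so the sketch is an assumption, not the repository's implementation.

import numpy as np
from torch.utils.data import Dataset

class MemorySketch(Dataset):
    # Assumed shape of the buffer implied by the fields in the hunk above.
    def __init__(self):
        self.states = []
        self.actions = []
        self.action_means = []
        self.rewards = []
        self.dones = []
        self.next_states = []

    def __len__(self):
        # one entry per stored transition
        return len(self.dones)

    def __getitem__(self, idx):
        # return a single transition as float arrays so a DataLoader can batch it
        return (np.array(self.states[idx], dtype=np.float32),
                np.array(self.actions[idx], dtype=np.float32),
                np.array(self.action_means[idx], dtype=np.float32),
                np.array([self.rewards[idx]], dtype=np.float32),
                np.array([self.dones[idx]], dtype=np.float32),
                np.array(self.next_states[idx], dtype=np.float32))
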
@@ -292,16 +292,19 @@ def update_ppo(self):
     def get_weights(self):
         return self.actor.state_dict()
 
+    def save_weights(self):
+        torch.save(self.actor.state_dict(), 'weights/agent.pth')
+
 class Agent:
     def __init__(self, state_dim, action_dim, is_training_mode):
         self.std = torch.ones([1, action_dim]).float()
-        self.is_training_mode = is_training_mode
+        self.is_training_mode = is_training_mode
 
-        self.memory = Memory()
-        self.distributions = Distributions()
-        self.actor = Actor_Model(state_dim, action_dim, torch.device('cpu'))
-        self.device = torch.device('cpu')
+        self.device = torch.device('cpu')
+        self.memory = Memory()
+        self.distributions = Distributions(self.device)
+        self.actor = Actor_Model(state_dim, action_dim, torch.device('cpu'))
 
         if is_training_mode:
             self.actor.train()
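
The two additions above work together: Learner.save_weights checkpoints the actor's state_dict to weights/agent.pth, and Agent now sets self.device before building Distributions(self.device). A small sketch of the file-based weight hand-off this enables follows; save_actor and load_actor are illustrative names rather than repository code, and it assumes the weights/ directory exists (torch.save does not create missing folders, hence the makedirs guard).

import os
import torch
import torch.nn as nn

def save_actor(actor: nn.Module, path: str = 'weights/agent.pth') -> None:
    # torch.save writes the file but will not create the parent directory
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(actor.state_dict(), path)

def load_actor(actor: nn.Module, device: torch.device, path: str = 'weights/agent.pth') -> nn.Module:
    # map_location keeps the load valid no matter which device wrote the checkpoint
    actor.load_state_dict(torch.load(path, map_location=device))
    return actor

# Example round trip with a stand-in for Actor_Model:
# actor = nn.Linear(8, 2); save_actor(actor); load_actor(nn.Linear(8, 2), torch.device('cpu'))
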
@@ -326,10 +329,13 @@ def act(self, state):
         else:
             action = action_mean
 
-        return action.squeeze(0).numpy(), action_mean.squeeze(0).numpy()
+        return action.squeeze(0).numpy(), action_mean.squeeze(0).detach().numpy()
 
     def set_weights(self, weights):
-        self.actor.load_state_dict(weights.to(self.device))
+        self.actor.load_state_dict(weights)
 
+    def load_weights(self):
+        self.actor.load_state_dict(torch.load('weights/agent.pth', map_location = self.device))
+
 def plot(datas):
     print('----------')
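
Two small fixes sit in this hunk. In act(), .detach() is now called before .numpy() on action_mean: the actor's output still tracks gradients, and PyTorch raises a RuntimeError when .numpy() is called on a tensor that requires grad, which is presumably what this addresses. In set_weights(), the old body called .to(self.device) on a state_dict, but a state_dict is a plain OrderedDict with no .to method; the new body loads the dict as-is (and the workers below switch to load_weights anyway). A self-contained illustration of the detach point, using a toy module rather than repository code:

import torch

net = torch.nn.Linear(4, 2)        # stand-in for the actor network
state = torch.rand(1, 4)

action_mean = net(state)           # requires grad: part of the autograd graph
# action_mean.squeeze(0).numpy()   # RuntimeError: can't call numpy() on a tensor that requires grad
action = action_mean.squeeze(0).detach().numpy()   # detach first, as the commit now does
print(action.shape)                # (2,)
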
@@ -345,11 +351,10 @@ def plot(datas):
     print('Avg :', np.mean(datas))
 
 @ray.remote
-def run_episode(env_name, state_dim, action_dim, is_training_mode, weights, render, training_mode, t_updates, n_update):
-    print('Testing')
+def run_episode(env_name, state_dim, action_dim, is_training_mode, render, training_mode, t_updates, n_update):
     env = gym.make(env_name)
     agent = Agent(state_dim, action_dim, is_training_mode)
-    agent.set_weights(weights)
+    agent.load_weights()
     ############################################
     state = env.reset()
     done = False
@@ -358,7 +363,6 @@ def run_episode(env_name, state_dim, action_dim, is_training_mode, weights, render, training_mode, t_updates, n_update):
     ############################################
 
     while not done:
-        print('test: ', eps_time)
         action, action_mean = agent.act(state)
         next_state, reward, done, _ = env.step(action)
 
@@ -405,10 +409,10 @@ def main():
     learner = Learner(state_dim, action_dim, training_mode, policy_kl_range, policy_params, value_clip, entropy_coef, vf_loss_coef,
                       minibatch, PPO_epochs, gamma, lam, learning_rate)
     #############################################
-    weights = learner.get_weights()
+    learner.save_weights()
    t_updates = 0
 
-    episode_ids = [run_episode.remote(env_name, state_dim, action_dim, training_mode, weights, render, training_mode, t_updates, n_update)
+    episode_ids = [run_episode.remote(env_name, state_dim, action_dim, training_mode, render, training_mode, t_updates, n_update)
         for i in range(4)]
 
     for i_episode in range(1, n_episode + 1):
@@ -420,10 +424,10 @@
         states, actions, action_means, rewards, dones, next_states = trajectory
         learner.save_all(states, actions, action_means, rewards, dones, next_states)
         learner.update_ppo()
-        weights = learner.get_weights()
+        learner.save_weights()
 
         episode_ids = not_ready
-        episode_ids.append(run_episode.remote(env_name, state_dim, action_dim, training_mode, weights, render, training_mode, t_updates, n_update))
+        episode_ids.append(run_episode.remote(env_name, state_dim, action_dim, training_mode, render, training_mode, t_updates, n_update))
 
         print('Episode {} \t t_reward: {} \t time: {} \t '.format(i_episode, total_reward, time))
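
Taken together, the main() hunks replace shipping a state_dict into each Ray task with a checkpoint file: the learner saves to weights/agent.pth, each worker builds its own Agent and calls load_weights(), and after every PPO update the checkpoint is refreshed before the finished worker is relaunched. The ray.wait call itself sits in collapsed context, but the not_ready handling above implies the usual wait-for-first, process, relaunch loop. A self-contained toy of that dispatch pattern follows; slow_task and the loop bounds are illustrative, not repository code.

import random
import time
import ray

ray.init()

@ray.remote
def slow_task(worker_id):
    time.sleep(random.uniform(0.1, 0.5))   # stands in for one environment episode
    return worker_id, random.random()      # stands in for (trajectory, total_reward)

pending = [slow_task.remote(i) for i in range(4)]   # four parallel workers, as in main()

for round_idx in range(8):
    ready, pending = ray.wait(pending)     # block until the first worker finishes
    worker_id, result = ray.get(ready)[0]  # fetch its result
    print('round', round_idx, 'worker', worker_id, 'returned', round(result, 3))
    pending.append(slow_task.remote(worker_id))      # relaunch that worker immediately

ray.shutdown()
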

