
Commit

update slm-lab part
sungjinl committed Jun 1, 2019
1 parent b0b1edd commit ee2cdb6
Showing 43 changed files with 3,448 additions and 4,848 deletions.
500 changes: 225 additions & 275 deletions convlab/agent/__init__.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion convlab/agent/algorithm/__init__.py
@@ -10,7 +10,6 @@
# expose all the classes
from .actor_critic import *
from .dqn import *
from .hydra_dqn import *
from .ppo import *
from .random import *
from .reinforce import *
311 changes: 133 additions & 178 deletions convlab/agent/algorithm/actor_critic.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion convlab/agent/algorithm/base.py
@@ -120,7 +120,8 @@ def load(self):
for k, v in vars(self).items():
if k.endswith('_scheduler'):
var_name = k.replace('_scheduler', '')
setattr(self.body, var_name, v.end_val)
if hasattr(v, 'end_val'):
setattr(self.body, var_name, v.end_val)

# NOTE optional extension for multi-agent-env

192 changes: 68 additions & 124 deletions convlab/agent/algorithm/dqn.py
@@ -1,11 +1,8 @@
# Modified by Microsoft Corporation.
# Licensed under the MIT license.

from convlab.agent import net
from convlab.agent.algorithm import policy_util
from convlab.agent.algorithm.sarsa import SARSA
from convlab.agent.net import net_util
from convlab.lib import logger, util
from convlab.lib import logger, math_util, util
from convlab.lib.decorator import lab_api
import numpy as np
import pydash as ps
@@ -46,11 +43,10 @@ class VanillaDQN(SARSA):
"end_step": 1000,
},
"gamma": 0.99,
"training_batch_epoch": 8,
"training_epoch": 4,
"training_batch_iter": 8,
"training_iter": 4,
"training_frequency": 10,
"training_start_step": 10,
"normalize_state": true
}
'''

@@ -66,63 +62,63 @@ def init_algorithm_params(self):
'action_pdtype',
'action_policy',
'rule_guide_max_epi',
"rule_guide_frequency",
'rule_guide_frequency',
# explore_var is epsilon, tau or etc. depending on the action policy
# these control the trade-off between exploration and exploitation
'explore_var_spec',
'gamma', # the discount factor
'training_batch_epoch', # how many gradient updates per batch
'training_epoch', # how many batches to train each time
'training_batch_iter', # how many gradient updates per batch
'training_iter', # how many batches to train each time
'training_frequency', # how often to train (once a few timesteps)
'training_start_step', # how long before starting training
'normalize_state',
])
super(VanillaDQN, self).init_algorithm_params()
super().init_algorithm_params()
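
The explore_var_spec listed above controls the epsilon (or tau) annealing between exploration and exploitation. Below is a minimal sketch of a linear decay schedule; the start_val/end_val/start_step/end_step names are assumptions mirroring the spec fragment in the docstring, not ConvLab's actual policy_util scheduler.

# Minimal sketch of a linear explore-variable (epsilon) decay; illustration only.
def linear_decay(start_val, end_val, start_step, end_step, step):
    '''Linearly anneal from start_val to end_val between start_step and end_step.'''
    if step < start_step:
        return start_val
    if step >= end_step:
        return end_val
    frac = (step - start_step) / (end_step - start_step)
    return start_val + frac * (end_val - start_val)

epsilon = linear_decay(1.0, 0.1, 0, 1000, 500)  # -> 0.55 halfway through the schedule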

@lab_api
def init_nets(self, global_nets=None):
'''Initialize the neural network used to learn the Q function from the spec'''
if self.algorithm_spec['name'] == 'VanillaDQN':
assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for VanillaDQN; use DQN.'
if global_nets is None:
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.net_names = ['net']
else:
util.set_attr(self, global_nets)
self.net_names = list(global_nets.keys())
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.net_names = ['net']
# init net optimizer and its lr scheduler
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()

def calc_q_loss(self, batch):
'''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
q_preds = self.net.wrap_eval(batch['states'])
states = batch['states']
next_states = batch['next_states']
q_preds = self.net(states)
with torch.no_grad():
next_q_preds = self.net(next_states)
act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
next_q_preds = self.net.wrap_eval(batch['next_states'])
# Bellman equation: compute max_q_targets using reward and max estimated Q values (0 if no next_state)
max_next_q_preds, _ = next_q_preds.max(dim=-1, keepdim=True)
max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds
max_q_targets = max_q_targets.detach()
logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
q_loss = self.net.loss_fn(act_q_preds, max_q_targets)

# TODO use the same loss_fn but do not reduce yet
if 'Prioritized' in util.get_class_name(self.body.memory): # PER
errors = torch.abs(max_q_targets - act_q_preds.detach())
errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
self.body.memory.update_priorities(errors)
return q_loss
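
For reference, a self-contained sketch of the Bellman target that calc_q_loss computes above, written against plain torch tensors rather than ConvLab's Net wrapper (illustration only):

import torch

def vanilla_dqn_targets(q_net, batch, gamma=0.99):
    # Q(s, .) for the current states, and the Q-value of the action actually taken
    q_preds = q_net(batch['states'])
    act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # max_a' Q(s', a'), with the bootstrap term zeroed on terminal transitions
        next_q_preds = q_net(batch['next_states'])
        max_next_q_preds, _ = next_q_preds.max(dim=-1)
        max_q_targets = batch['rewards'] + gamma * (1 - batch['dones']) * max_next_q_preds
    return act_q_preds, max_q_targets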

@lab_api
def act(self, state):
'''Selects and returns a discrete action for body using the action policy'''
return super(VanillaDQN, self).act(state)
return super().act(state)

@lab_api
def sample(self):
'''Samples a batch from memory of size self.memory_spec['batch_size']'''
batch = self.body.memory.sample()
if self.normalize_state:
batch = policy_util.normalize_states_and_next_states(self.body, batch)
batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
return batch
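
The sampled batch is a dict of arrays keyed by 'states', 'actions', 'rewards', 'next_states' and 'dones'. A rough stand-in for util.to_torch_batch, shown only to illustrate that layout; the real helper also handles episodic memories and device placement.

import numpy as np
import torch

def to_torch_batch_sketch(batch, device='cpu'):
    # convert each field of the batch dict to a float tensor on the target device
    return {k: torch.from_numpy(np.asarray(v, dtype=np.float32)).to(device) for k, v in batch.items()}

batch = to_torch_batch_sketch({
    'states': np.random.rand(32, 4),
    'actions': np.random.randint(0, 2, size=32),
    'rewards': np.random.rand(32),
    'next_states': np.random.rand(32, 4),
    'dones': np.zeros(32),
})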

@@ -136,32 +132,34 @@ def train(self):
Otherwise this function does nothing.
'''
if util.in_eval_lab_modes():
self.body.flush()
return np.nan
clock = self.body.env.clock
tick = clock.get(clock.max_tick_unit)
self.to_train = (tick > self.training_start_step and tick % self.training_frequency == 0)
if self.to_train == 1:
total_loss = torch.tensor(0.0, device=self.net.device)
for _ in range(self.training_epoch):
batch = self.sample()
for _ in range(self.training_batch_epoch):
total_loss = torch.tensor(0.0)
# for _ in range(self.training_iter):
# batch = self.sample()
# clock.set_batch_size(len(batch))
# for _ in range(self.training_batch_iter):
num_batches = int(self.body.memory.size / self.body.memory.batch_size)
for _ in range(self.training_iter):
# clock.set_batch_size(len(batch))
for _ in range(min(self.training_batch_iter, num_batches)):
batch = self.sample()
loss = self.calc_q_loss(batch)
self.net.training_step(loss=loss, lr_clock=clock)
self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
total_loss += loss
loss = total_loss / (self.training_epoch * self.training_batch_epoch)
loss = total_loss / (self.training_iter * self.training_batch_iter)
# reset
self.to_train = 0
self.body.flush()
logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:g}')
logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
return loss.item()
else:
return np.nan
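
The gate at the top of train() reduces to a simple predicate on the environment clock: train only after training_start_step frames, and then once every training_frequency frames. A standalone sketch of that condition:

def should_train(frame, training_start_step, training_frequency):
    # mirrors the to_train condition in train() above
    return frame > training_start_step and frame % training_frequency == 0

assert should_train(20, 10, 10)        # past warm-up, on a training_frequency boundary
assert not should_train(5, 10, 10)     # still in warm-up
assert not should_train(25, 10, 10)    # between training steps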

@lab_api
def update(self):
'''Update the agent after training'''
return super(VanillaDQN, self).update()
return super().update()


class DQNBase(VanillaDQN):
@@ -185,100 +183,57 @@ def init_nets(self, global_nets=None):
'''Initialize networks'''
if self.algorithm_spec['name'] == 'DQNBase':
assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for DQNBase; use DQN.'
if global_nets is None:
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.target_net = NetClass(self.net_spec, in_dim, out_dim)
self.net_names = ['net', 'target_net']
else:
util.set_attr(self, global_nets)
self.net_names = list(global_nets.keys())
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.target_net = NetClass(self.net_spec, in_dim, out_dim)
self.net_names = ['net', 'target_net']
# init net optimizer and its lr scheduler
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
self.online_net = self.target_net
self.eval_net = self.target_net

def calc_q_loss(self, batch):
'''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
q_preds = self.net.wrap_eval(batch['states'])
# q_preds = self.net(batch['states'])
states = batch['states']
next_states = batch['next_states']
q_preds = self.net(states)
with torch.no_grad():
# Use online_net to select actions in next state
online_next_q_preds = self.online_net(next_states)
# Use eval_net to calculate next_q_preds for actions chosen by online_net
next_q_preds = self.eval_net(next_states)
act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
# Use online_net to select actions in next state
online_next_q_preds = self.online_net.wrap_eval(batch['next_states'])
# Use eval_net to calculate next_q_preds for actions chosen by online_net
next_q_preds = self.eval_net.wrap_eval(batch['next_states'])
max_next_q_preds = next_q_preds.gather(-1, online_next_q_preds.argmax(dim=-1, keepdim=True)).squeeze(-1)
online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
max_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * max_next_q_preds
max_q_targets = max_q_targets.detach()

# print(action_list[int(batch['actions'][0].item())])
# print(batch['actions'][0].item())
# print('{} vs {}'.format(act_q_preds.item(), max_q_targets.item()))

logger.debug(f'act_q_preds: {act_q_preds}\nmax_q_targets: {max_q_targets}')
q_loss = self.net.loss_fn(act_q_preds, max_q_targets)


# TODO use the same loss_fn but do not reduce yet
if 'Prioritized' in util.get_class_name(self.body.memory): # PER
errors = torch.abs(max_q_targets - act_q_preds.detach())
errors = (max_q_targets - act_q_preds.detach()).abs().cpu().numpy()
self.body.memory.update_priorities(errors)
return q_loss
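
The target above is the double-DQN estimator: the online net picks the greedy action in the next state, and the eval (target) net supplies that action's value. A plain-torch sketch of the same computation (illustration only, not the Net API):

import torch

def double_dqn_targets(online_net, eval_net, batch, gamma=0.99):
    with torch.no_grad():
        online_next_q_preds = online_net(batch['next_states'])   # used only for the argmax
        next_q_preds = eval_net(batch['next_states'])            # used for the value
        online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
        max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
        return batch['rewards'] + gamma * (1 - batch['dones']) * max_next_q_preds

In DQNBase and DQN both helpers point at target_net, so this reduces to ordinary target-network DQN; DoubleDQN rebinds online_net to the online net, which is what makes the estimator double Q-learning.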

@lab_api
def train(self):
'''
Completes one training step for the agent if it is time to train.
i.e. the environment timestep is greater than the minimum training timestep and a multiple of the training_frequency.
Each training step consists of sampling n batches from the agent's memory.
For each of the batches, the target Q values (q_targets) are computed and a single training step is taken k times.
Otherwise this function does nothing.
'''
if util.in_eval_lab_modes():
self.body.flush()
return np.nan
clock = self.body.env.clock
tick = clock.get(clock.max_tick_unit)
self.to_train = (tick > self.training_start_step and tick % self.training_frequency == 0)
if self.to_train == 1:
total_loss = torch.tensor(0.0, device=self.net.device)
for epoch in range(self.training_epoch):
num_batches = int(self.body.memory.true_size / self.body.memory.batch_size)
for _ in range(num_batches):
batch = self.sample()
loss = self.calc_q_loss(batch)
self.net.training_step(loss=loss, lr_clock=clock)
total_loss += loss
loss = total_loss / (self.training_epoch * num_batches)
# reset
self.to_train = 0
self.body.flush()
logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:g}')
return loss.item()
else:
return np.nan

def update_nets(self):
total_t = self.body.env.clock.total_t
if total_t % self.net.update_frequency == 0:
if util.frame_mod(self.body.env.clock.frame, self.net.update_frequency, self.body.env.num_envs):
if self.net.update_type == 'replace':
logger.debug('Updating target_net by replacing')
net_util.copy(self.net, self.target_net)
self.online_net = self.target_net
self.eval_net = self.target_net
elif self.net.update_type == 'polyak':
logger.debug('Updating net by averaging')
net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef)
self.online_net = self.target_net
self.eval_net = self.target_net
else:
raise ValueError('Unknown net.update_type. Should be "replace" or "polyak". Exiting.')
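
For reference, illustrative versions of the two target-net updates named above; these are standalone sketches, not net_util.copy / net_util.polyak_update themselves, and the exact interpolation convention for polyak_coef may differ.

import torch

def hard_copy(src_net, tgt_net):
    # 'replace' update: overwrite the target net with the online net's parameters
    tgt_net.load_state_dict(src_net.state_dict())

def polyak_update_sketch(src_net, tgt_net, polyak_coef=0.9):
    # 'polyak' update: exponential moving average of the online net's parameters
    with torch.no_grad():
        for src_p, tgt_p in zip(src_net.parameters(), tgt_net.parameters()):
            tgt_p.mul_(polyak_coef).add_(src_p, alpha=1.0 - polyak_coef)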

@lab_api
def update(self):
'''Updates self.target_net and the explore variables'''
self.update_nets()
return super(DQNBase, self).update()
return super().update()


class DQN(DQNBase):
@@ -298,15 +253,15 @@ class DQN(DQNBase):
"end_step": 1000,
},
"gamma": 0.99,
"training_batch_epoch": 8,
"training_epoch": 4,
"training_batch_iter": 8,
"training_iter": 4,
"training_frequency": 10,
"training_start_step": 10
}
'''
@lab_api
def init_nets(self, global_nets=None):
super(DQN, self).init_nets(global_nets)
super().init_nets(global_nets)


class DoubleDQN(DQN):
@@ -326,25 +281,14 @@ class DoubleDQN(DQN):
"end_step": 1000,
},
"gamma": 0.99,
"training_batch_epoch": 8,
"training_epoch": 4,
"training_batch_iter": 8,
"training_iter": 4,
"training_frequency": 10,
"training_start_step": 10
}
'''
@lab_api
def init_nets(self, global_nets=None):
super(DoubleDQN, self).init_nets(global_nets)
super().init_nets(global_nets)
self.online_net = self.net
self.eval_net = self.target_net

def update_nets(self):
res = super(DoubleDQN, self).update_nets()
total_t = self.body.env.clock.total_t
if self.net.update_type == 'replace':
if total_t % self.net.update_frequency == 0:
self.online_net = self.net
self.eval_net = self.target_net
elif self.net.update_type == 'polyak':
self.online_net = self.net
self.eval_net = self.target_net
