
hotfix/retain_hidden_state (#78)
* let the hidden state be retained

* rl2_refactoring

* .

Co-authored-by: dongminlee94 <[email protected]>
Clyde21c and dongminlee94 committed Jun 20, 2022
1 parent 20ebcd5 commit a0820f7
Showing 11 changed files with 38 additions and 27 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -153,6 +153,7 @@ disable=print-statement,
protected-access,
used-before-assignment,
line-too-long,
+ too-few-public-methods,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
8 changes: 4 additions & 4 deletions src/meta_rl/maml/algorithm/meta_learner.py
@@ -52,7 +52,7 @@ def __init__(
self.test_interval = test_interval

self.num_iterations = config["num_iterations"]
- self.num_sample_tasks = config["num_sample_tasks"]
+ self.meta_batch_size = config["meta_batch_size"]
self.num_samples = config["num_samples"]
self.max_steps = config["max_steps"]

@@ -73,7 +73,7 @@ def __init__(
observ_dim=observ_dim,
action_dim=action_dim,
agent=agent,
- num_tasks=max(self.num_sample_tasks, len(self.test_tasks)),
+ num_tasks=max(self.meta_batch_size, len(self.test_tasks)),
num_episodes=(self.num_adapt_epochs + 1), # [num of adaptation steps for train] + [validation]
max_size=self.num_samples,
device=device,
@@ -159,7 +159,7 @@ def meta_surrogate_loss(self, set_grad: bool) -> Tuple[torch.Tensor, ...]:
backup_params = dict(self.agent.policy.named_parameters())

# Compute loss for each sampled task
- for cur_task in range(self.num_sample_tasks):
+ for cur_task in range(self.meta_batch_size):
# Inner loop
# Adapt the policy to each task through a few gradient steps
for cur_adapt in range(self.num_adapt_epochs):
@@ -253,7 +253,7 @@ def meta_train(self) -> None:
print(f"\n=============== Iteration {iteration} ===============")
# Sample a batch of tasks randomly from the train task distribution and
# obtain adaptation data for each sampled task
- indices = np.random.randint(len(self.train_tasks), size=self.num_sample_tasks)
+ indices = np.random.randint(len(self.train_tasks), size=self.meta_batch_size)
self.collect_train_data(indices)

# Meta update
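On the MAML side this commit is a pure rename: num_sample_tasks becomes meta_batch_size, the name PEARL and RL^2 already use, so behavior is unchanged. The one line worth a second look is the buffer sizing, which reserves a task slot for whichever phase touches more tasks, the training meta-batch or the test-task set. A small illustrative check (meta_batch_size comes from the MAML configs below; the test-task count is an assumed placeholder, since it is not shown in this diff):

```python
# Illustrative only -- not the repository's code.
meta_batch_size = 40  # maml_params.meta_batch_size in both MAML configs
num_test_tasks = 30   # assumed placeholder; the MAML test-task count is not in this diff

# The multi-task buffer must hold episodes for whichever phase uses more task slots,
# so meta-training (meta_batch_size tasks per iteration) and meta-testing can share it.
num_tasks = max(meta_batch_size, num_test_tasks)
print(num_tasks)  # 40
```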
2 changes: 1 addition & 1 deletion src/meta_rl/maml/configs/dir_target_config.yaml
@@ -20,7 +20,7 @@ maml_params:
# Number of training iterations
num_iterations: 1000
# Number of task samples for training
- num_sample_tasks: 40
+ meta_batch_size: 40
# Number of samples per task to train
num_samples: 2000
# Maximum steps for the environment
2 changes: 1 addition & 1 deletion src/meta_rl/maml/configs/vel_target_config.yaml
@@ -20,7 +20,7 @@ maml_params:
# Number of training iterations
num_iterations: 1000
# Number of task samples for training
- num_sample_tasks: 40
+ meta_batch_size: 40
# Number of samples per task to train
num_samples: 4000
# Maximum steps for the environment
2 changes: 1 addition & 1 deletion src/meta_rl/pearl/configs/dir_target_config.yaml
@@ -30,7 +30,7 @@ pearl_params:
num_posterior_samples: 1000
# Number of meta-gradient steps taken per iteration
num_meta_grads: 1500
- # Number of tasks to average the gradient across
+ # Number of task samples for training
meta_batch_size: 4
# Number of samples in the context batch
batch_size: 256
2 changes: 1 addition & 1 deletion src/meta_rl/pearl/configs/vel_target_config.yaml
@@ -30,7 +30,7 @@ pearl_params:
num_posterior_samples: 600
# Number of meta-gradient steps taken per iteration
num_meta_grads: 1500
- # Number of tasks to average the gradient across
+ # Number of task samples for training
meta_batch_size: 16
# Number of samples in the context batch
batch_size: 100
8 changes: 5 additions & 3 deletions src/meta_rl/rl2/algorithm/meta_learner.py
@@ -47,9 +47,10 @@ def __init__(
self.test_tasks = test_tasks

self.num_iterations: int = config["num_iterations"]
+ self.meta_batch_size: int = config["meta_batch_size"]
self.num_samples: int = config["num_samples"]

- self.batch_size: int = len(train_tasks) * config["num_samples"]
+ self.batch_size: int = self.meta_batch_size * config["num_samples"]
self.max_step: int = config["max_step"]

self.sampler = Sampler(
@@ -100,11 +101,12 @@ def meta_train(self) -> None:

print(f"=============== Iteration {iteration} ===============")
# Sample data randomly from train tasks.
- for index in range(len(self.train_tasks)):
+ indices = np.random.randint(len(self.train_tasks), size=self.meta_batch_size)
+ for i, index in enumerate(indices):
self.env.reset_task(index)
self.agent.policy.is_deterministic = False

print(f"[{index + 1}/{len(self.train_tasks)}] collecting samples")
print(f"[{i + 1}/{self.meta_batch_size}] collecting samples")
trajs: List[Dict[str, np.ndarray]] = self.sampler.obtain_samples(
max_samples=self.num_samples,
)
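Two things change in the RL^2 meta-learner: the PPO batch size is now derived from meta_batch_size rather than from the full train-task count, and each iteration draws meta_batch_size task indices at random (with replacement) instead of sweeping every train task once. A minimal sketch of the new sampling loop, with the real env/agent/sampler reduced to comments; the values come from dir_target_config.yaml, where a meta-batch of 4 over only 2 train tasks means some tasks are sampled more than once per iteration:

```python
import numpy as np

# Sketch only; the real loop lives in rl2/algorithm/meta_learner.py.
num_train_tasks = 2   # dir_target_config.yaml: train_tasks
meta_batch_size = 4   # dir_target_config.yaml: rl2_params.meta_batch_size

# Old behavior: `for index in range(num_train_tasks)` visited every task exactly once.
# New behavior: sample a meta-batch of task indices at random, with replacement.
indices = np.random.randint(num_train_tasks, size=meta_batch_size)
for i, index in enumerate(indices):
    # env.reset_task(index); agent.policy.is_deterministic = False
    # trajs = sampler.obtain_samples(max_samples=num_samples)
    print(f"[{i + 1}/{meta_batch_size}] collecting samples from task {index}")
```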
20 changes: 12 additions & 8 deletions src/meta_rl/rl2/algorithm/sampler.py
@@ -28,13 +28,19 @@ def __init__(
self.hidden_dim = hidden_dim
self.max_step = max_step
self.cur_samples = 0
+ self.pi_hidden = None
+ self.v_hidden = None

def obtain_samples(self, max_samples: int) -> List[Dict[str, np.ndarray]]:
"""Obtain samples up to the number of maximum samples"""
+ self.pi_hidden = np.zeros((1, self.hidden_dim))
+ self.v_hidden = np.zeros((1, self.hidden_dim))

trajs = []
while not self.cur_samples == max_samples:
traj = self.rollout(max_samples)
trajs.append(traj)

self.cur_samples = 0
return trajs

@@ -55,22 +61,20 @@ def rollout(self, max_samples: int) -> Dict[str, np.ndarray]:
action = np.zeros(self.action_dim)
reward = np.zeros(1)
done = np.zeros(1)
- pi_hidden = np.zeros((1, self.hidden_dim))
- v_hidden = np.zeros((1, self.hidden_dim))

while not (done or cur_step == self.max_step or self.cur_samples == max_samples):
tran = np.concatenate((obs, action, reward, done), axis=-1).reshape(1, -1)
- action, log_prob, next_pi_hidden = self.agent.get_action(tran, pi_hidden)
- value, next_v_hidden = self.agent.get_value(tran, v_hidden)
+ action, log_prob, next_pi_hidden = self.agent.get_action(tran, self.pi_hidden)
+ value, next_v_hidden = self.agent.get_value(tran, self.v_hidden)

next_obs, reward, done, info = self.env.step(action)
reward = np.array(reward).reshape(-1)
done = np.array(int(done)).reshape(-1)

# Flatten out the samples needed to train and add them to each list
trans.append(tran.reshape(-1))
- pi_hiddens.append(pi_hidden.reshape(-1))
- v_hiddens.append(v_hidden.reshape(-1))
+ pi_hiddens.append(self.pi_hidden.reshape(-1))
+ v_hiddens.append(self.v_hidden.reshape(-1))
actions.append(action)
rewards.append(reward)
dones.append(done)
Expand All @@ -79,8 +83,8 @@ def rollout(self, max_samples: int) -> Dict[str, np.ndarray]:
log_probs.append(log_prob.reshape(-1))

obs = next_obs.reshape(-1)
- pi_hidden = next_pi_hidden[0]
- v_hidden = next_v_hidden[0]
+ self.pi_hidden = next_pi_hidden[0]
+ self.v_hidden = next_v_hidden[0]
cur_step += 1
self.cur_samples += 1
return dict(
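This sampler change is the "retain hidden state" part of the hotfix: the policy and value hidden states become sampler attributes, are zeroed once per obtain_samples() call, and are carried across episode boundaries within that call, which is what the RL^2 formulation expects (the recurrent state should persist over a trial of several episodes in the same task). A minimal, self-contained sketch of the pattern; ToyAgent and ToySampler are stand-ins, not the repository's classes:

```python
import numpy as np


class ToyAgent:
    """Stand-in recurrent agent whose 'hidden state' simply counts steps."""

    def get_action(self, tran: np.ndarray, hidden: np.ndarray):
        return np.zeros(1), hidden + 1.0  # (action, next_hidden)


class ToySampler:
    def __init__(self, agent: ToyAgent, hidden_dim: int):
        self.agent = agent
        self.hidden_dim = hidden_dim
        self.pi_hidden = None  # retained across rollouts within one obtain_samples()

    def obtain_samples(self, num_rollouts: int) -> np.ndarray:
        # Zero the hidden state once per call, NOT once per rollout (the old behavior).
        self.pi_hidden = np.zeros((1, self.hidden_dim))
        for _ in range(num_rollouts):
            self.rollout()
        return self.pi_hidden

    def rollout(self) -> None:
        # One-step "episode": the hidden state produced here feeds the next rollout.
        _, next_hidden = self.agent.get_action(np.zeros((1, 4)), self.pi_hidden)
        self.pi_hidden = next_hidden


sampler = ToySampler(ToyAgent(), hidden_dim=2)
print(sampler.obtain_samples(num_rollouts=3))  # [[3. 3.]] -> the state survived 3 episodes
```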
8 changes: 5 additions & 3 deletions src/meta_rl/rl2/configs/dir_target_config.yaml
@@ -9,7 +9,7 @@ train_tasks: 2
test_tasks: 2

# Number of hidden units in neural networks
- hidden_dim: 256
+ hidden_dim: 64

# RL^2 setup
# ----------
@@ -20,6 +20,8 @@ rl2_params:
num_samples: 800
# Maximum step for the environment
max_step: 200
+ # Number of task samples for training
+ meta_batch_size: 4
# Number of early stopping conditions
num_stop_conditions: 3
# Goal value used for the early stopping condition
@@ -31,9 +33,9 @@ ppo_params:
# Discount factor
gamma: 0.99
# Number of epochs per iteration
- num_epochs: 5
+ num_epochs: 10
# Number of minibatches within each epoch
- mini_batch_size: 128
+ mini_batch_size: 32
# PPO clip parameter
clip_param: 0.3
# Learning rate of PPO losses
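With the new rl2_params in place, a quick sanity check of the per-iteration data volume for the dir-target run, combining values from this config with the refactored batch-size formula in rl2/algorithm/meta_learner.py (a sketch, not repository code; how PPO's epochs and minibatches slice this batch is governed by the ppo_params block above):

```python
# Values visible in dir_target_config.yaml after this commit.
meta_batch_size = 4
num_samples = 800   # samples collected per sampled task
max_step = 200      # environment horizon

batch_size = meta_batch_size * num_samples   # transitions handed to PPO per iteration
episodes_per_task = num_samples // max_step  # full episodes per sampled task
print(batch_size, episodes_per_task)         # 3200 4
```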
2 changes: 1 addition & 1 deletion src/meta_rl/rl2/configs/experiment_config.yaml
@@ -1,7 +1,7 @@
# RL^2 Experiment configs

# Select either dir or vel (str)
env_name: "vel"
env_name: "dir"

# Set an experiment name to save (str)
save_exp_name: "exp_1"
10 changes: 6 additions & 4 deletions src/meta_rl/rl2/configs/vel_target_config.yaml
@@ -3,13 +3,13 @@
# General setup
# -------------
# Number of tasks for meta-train
- train_tasks: 200
+ train_tasks: 300

# Number of tasks for meta-test
test_tasks: 15

# Number of hidden units in neural networks
- hidden_dim: 256
+ hidden_dim: 64

# RL^2 setup
# ----------
@@ -20,6 +20,8 @@ rl2_params:
num_samples: 1200
# Maximum step for the environment
max_step: 200
+ # Number of task samples for training
+ meta_batch_size: 10
# Number of early stopping conditions
num_stop_conditions: 3
# Goal value used for the early stopping condition
@@ -31,9 +33,9 @@ ppo_params:
# Discount factor
gamma: 0.99
# Number of epochs per iteration
- num_epochs: 5
+ num_epochs: 10
# Number of minibatches within each epoch
- mini_batch_size: 128
+ mini_batch_size: 32
# PPO clip parameter
clip_param: 0.3
# Learning rate of PPO losses
