Commit ac9eecf

For fixing caching related instability and support for all seq2seq models

rajcscw committed Oct 17, 2022
1 parent d2a8f4f commit ac9eecf

Showing 8 changed files with 268 additions and 129 deletions.
17 changes: 16 additions & 1 deletion rl4lms/algorithms/nlpo/nlpo.py
@@ -132,6 +132,7 @@ def __init__(
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
self._tracker = tracker

if _init_setup_model:
self._setup_model()
@@ -333,7 +334,7 @@ def train(self) -> None:
for epoch in range(self.n_epochs):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
for batch_ix, rollout_data in enumerate(list(self.rollout_buffer.get(self.batch_size))):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
@@ -354,6 +355,11 @@ def train(self) -> None:

# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
if batch_ix == 0 and epoch == 0:
assert th.allclose(th.mean(ratio), th.tensor(
1.0), atol=1e-3), f"Ratio is {th.mean(ratio)}"

assert th.allclose(values, rollout_data.old_values, atol=1e-3)

# clipped surrogate loss
policy_loss_1 = advantages * ratio
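A minimal standalone sketch (not part of this diff) of the invariant the new assert encodes: on the first minibatch of the first epoch the policy has not been updated yet, so the freshly evaluated log-probs equal those stored at rollout time and the importance ratio is 1 for every action. A caching bug that alters the evaluated log-probs breaks exactly this equality, which is what the assert is meant to catch.

import torch as th

old_log_prob = th.randn(8)               # log-probs recorded during rollout collection
log_prob = old_log_prob.clone()          # no gradient step has happened yet
ratio = th.exp(log_prob - old_log_prob)  # == 1.0 everywhere
assert th.allclose(th.mean(ratio), th.tensor(1.0), atol=1e-3)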
@@ -437,6 +443,15 @@ def train(self) -> None:
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)

train_info = {
"ppo/entropy_loss": np.mean(entropy_losses).item(),
"ppo/policy_gradient_loss": np.mean(pg_losses).item(),
"ppo/value_loss": np.mean(value_losses).item(),
"ppo/approx_kl": np.mean(approx_kl_divs).item(),
}

self._tracker.log_training_infos(train_info)
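A minimal stand-in (hypothetical, not the project's Tracker class) that satisfies the log_training_infos call above by simply printing the aggregated metrics; useful when exercising the algorithm outside the full RL4LMs training pipeline.

class PrintTracker:
    def log_training_infos(self, info: dict) -> None:
        # print per-update metrics such as ppo/value_loss and ppo/approx_kl
        for key, value in info.items():
            print(f"{key}: {value:.4f}")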

def learn(
self,
total_timesteps: int,
18 changes: 5 additions & 13 deletions rl4lms/algorithms/ppo/ppo.py
@@ -175,16 +175,6 @@ def _setup_model(self) -> None:

self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

# def verify_rollout_data(self, rollout_data):
# for rollout_data in self.rollout_buffer.get(self.batch_size):
# actions = rollout_data.actions.long().flatten()
# values, log_prob, _ = self.policy.evaluate_actions(
# rollout_data.observations, actions)

# assert th.allclose(
# values.flatten(), rollout_data.old_values.flatten(), atol=1e-4)
# assert th.allclose(log_prob, rollout_data.old_log_prob, atol=1e-4)

def train(self) -> None:
"""
Update policy using the currently gathered rollout buffer.
@@ -232,9 +222,11 @@ def train(self) -> None:

# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# if batch_ix == 0 and epoch == 0:
# assert th.allclose(th.mean(ratio), th.tensor(
# 1.0)), f"Ratio is {th.mean(ratio)}"
if batch_ix == 0 and epoch == 0:
assert th.allclose(th.mean(ratio), th.tensor(
1.0), atol=1e-3), f"Ratio is {th.mean(ratio)}"

assert th.allclose(values, rollout_data.old_values, atol=1e-3)

# clipped surrogate loss
policy_loss_1 = advantages * ratio
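The commit message points at caching-related instability, and the asserts above fail when cached and uncached forward passes disagree. A minimal standalone sketch of that kind of consistency check (plain Hugging Face GPT-2, not RL4LMs code): logits from one full pass should match logits from an incremental pass that reuses the key/value cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
ids = tok("caching must not change the math", return_tensors="pt").input_ids

with torch.no_grad():
    logits_full = model(ids).logits                    # one full pass, no cache reuse
    past, chunks = None, []
    for t in range(ids.shape[1]):                      # incremental pass with KV cache
        out = model(ids[:, t:t + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
        chunks.append(out.logits)
    logits_step = torch.cat(chunks, dim=1)

assert torch.allclose(logits_full, logits_step, atol=1e-3)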
58 changes: 19 additions & 39 deletions rl4lms/envs/text_generation/alg_wrappers.py
@@ -17,6 +17,7 @@
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.vec_env import VecEnv
from transformers import PreTrainedTokenizer
from rl4lms.envs.text_generation.policy import PolicyType


def unpack_observations(obs_tensor, n_envs: int):
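A minimal sketch of the PolicyType flag the wrapper branches on below. Only SEQ2SEQ appears explicitly in this diff; the CAUSAL member name is assumed from the "causal policy" comment in the else branch.

from enum import Enum

class PolicyType(Enum):
    CAUSAL = 0    # decoder-only policies (GPT-2 style)
    SEQ2SEQ = 1   # encoder-decoder policies (T5 style)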
@@ -139,36 +140,27 @@ def generate_batch(self,
with torch.no_grad():
obs_tensor = obs_as_tensor(current_obs, self.device)

# # get log probs from policy
# _, cache_log_prob, _, _, policy_past_state = self.policy.forward_policy(
# obs_tensor, actions_tensor, policy_past_state)
# for seq2seq policy
if self.policy.get_policy_type() == PolicyType.SEQ2SEQ:
# override log probs without caching
_, log_probs, _, _, _, _ = self.policy.forward_policy(
obs_tensor, actions_tensor, action_mask)

# _, without_cache_log_prob, _, _, policy_past_state = self.policy.forward_policy(
# obs_tensor, actions_tensor, None)
# # sanity check 0 - rollout probs and policy probs must match
# assert torch.allclose(cache_log_prob, log_probs, atol=1e-3)
# get values without caching
values, value_past_state = self.policy.forward_value(obs_tensor)

# # sanity check 1 - log probs with and without cache must match
# assert torch.allclose(
# cache_log_prob, without_cache_log_prob, atol=1e-3)
# get reference log probs
ref_log_probs, ref_past_state = self.policy.get_log_probs_ref_model(obs_tensor,
actions_tensor)
else: # causal policy
# get values
values, value_past_state = self.policy.forward_value(obs_tensor,
value_past_state)

# get values
values, value_past_state = self.policy.forward_value(obs_tensor,
value_past_state)

# get reference log probs
ref_log_probs, ref_past_state = self.policy.get_log_probs_ref_model(obs_tensor,
actions_tensor,
ref_past_state)

# sanity check 2 (this is without caching - must match with values from generate which is with caching)
# eval_values, eval_log_probs, _ = self.policy.evaluate_actions(
# obs_tensor, actions_tensor)

# assert torch.allclose(
# eval_log_probs, without_cache_log_prob, atol=1e-3)
# assert torch.allclose(
# eval_values, values, atol=1e-3)
# get reference log probs
ref_log_probs, ref_past_state = self.policy.get_log_probs_ref_model(obs_tensor,
actions_tensor,
ref_past_state)

# compute KL rewards
kl_div = log_probs - ref_log_probs
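The line above is the start of the KL-penalty computation: the gap between policy and reference log-probs is scaled by the KL coefficient (adapted by the KL controller further down) and subtracted from the task reward. A minimal standalone sketch of that bookkeeping, with the tensor shapes and kl_coef value assumed for illustration only:

import torch

def kl_penalized_rewards(task_rewards, log_probs, ref_log_probs, kl_coef=0.2):
    kl_div = log_probs - ref_log_probs          # per-token KL estimate
    rewards = task_rewards - kl_coef * kl_div   # penalize drifting from the reference model
    return rewards, kl_div

rewards, kl = kl_penalized_rewards(
    torch.zeros(2, 5), torch.randn(2, 5), torch.randn(2, 5))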
@@ -321,18 +313,6 @@ def collect_rollouts(
# adapt the KL coeff
self._kl_controller.step(torch.tensor(
aggregated_rollout_info["rollout_info/kl_div_mean"]))

# sanity check 3: now, loop over the buffer
# and check the log_probs and values match
# for rollout_data in self.rollout_buffer.get(self.batch_size):
# actions = rollout_data.actions.long().flatten()
# values, log_prob, entropy = self.policy.evaluate_actions(
# rollout_data.observations, actions)

# assert torch.allclose(
# values.flatten(), rollout_data.old_values.flatten(), atol=1e-4)
# assert torch.allclose(
# log_prob, rollout_data.old_log_prob, atol=1e-4)
return True

# instantiate the wrapped alg
2 changes: 1 addition & 1 deletion rl4lms/envs/text_generation/env.py
@@ -46,7 +46,7 @@ def __init__(self, tokenizer: AutoTokenizer,
super().__init__()

# set the observation and action space here
self._vocab_size = len(tokenizer.vocab)
self._vocab_size = tokenizer.vocab_size
self.observation_space = DictSpace({
# we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited
# while creating rollout buffers, observations are concatenated for each key
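The one-line change above swaps len(tokenizer.vocab) for tokenizer.vocab_size. A minimal standalone sketch (plain Hugging Face, not RL4LMs code) of how the two can diverge: vocab_size reports the base vocabulary only, while len(tokenizer) -- and, for fast tokenizers, the vocab dict behind it -- also counts tokens added after loading.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.vocab_size, len(tok))   # 50257 50257
tok.add_special_tokens({"pad_token": "<PAD>"})
print(tok.vocab_size, len(tok))   # 50257 50258 -- base vocabulary size unchanged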
19 changes: 13 additions & 6 deletions rl4lms/envs/text_generation/observation.py
@@ -136,18 +136,25 @@ def init_from_sample(cls, sample: Sample,
return_attention_mask=True,
truncation=True)
tokenizer.truncation_side = prev_truncation_side


# encode the context text
context_outputs = tokenizer("",
# for seq2seq models, context should be initialized to start token if provided
if context_start_token is not None:
context_outputs = tokenizer("",
padding="max_length",
max_length=max_context_length,
return_tensors="pt",
return_attention_mask=True)

# for seq2seq models, context should be initialized to start token if provided
if context_start_token is not None:
context_outputs.input_ids = torch.ones(1, max_context_length, dtype=torch.int32) * tokenizer.pad_token_id
context_outputs.input_ids[:, -1] = context_start_token
context_outputs.attention_mask[:, -1] = 1
context_outputs.attention_mask = torch.zeros(1, max_context_length, dtype=torch.int32)
context_outputs.attention_mask[:,-1] = 1
else:
context_outputs = tokenizer("",
padding="max_length",
max_length=max_context_length,
return_tensors="pt",
return_attention_mask=True)

# concatenate
input_encoded_pt, input_attention_mask_pt = Observation._concat(
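The new branch above builds the initial decoder context for seq2seq models by hand: pad ids everywhere, the provided start token in the last slot, and an attention mask that attends only to that slot. A minimal standalone sketch of the same construction, assuming a T5-style tokenizer whose pad token serves as the decoder start token:

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
max_context_length = 8
context_start_token = tok.pad_token_id   # T5's decoder_start_token_id is <pad>

input_ids = torch.ones(1, max_context_length, dtype=torch.int32) * tok.pad_token_id
input_ids[:, -1] = context_start_token
attention_mask = torch.zeros(1, max_context_length, dtype=torch.int32)
attention_mask[:, -1] = 1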