[Example] RLHF end to end example #1324

Closed · wants to merge 30 commits

Conversation

apbard (Contributor) commented on Jun 27, 2023

Merge after #1309, #1319, #1316, #1315, plus a rebase.

Adds a complete end-to-end RLHF pipeline.

facebook-github-bot added the CLA Signed label on Jun 27, 2023
Comment on lines 57 to 60
"""Returns adaptively updated KL coefficient, βₜ₊₁.
Arguments:
current: The current KL value between the newest policy and the initial policy.
"""
Contributor:
wrong formatting


For debugging purposes, we also generate responses to a fixed prompt so that the
quality of the model can be visually assessed during training.
"""
Contributor:
Missing args and example
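
For illustration, a docstring with the missing `Args` and `Example` sections might look roughly like this; the function name, signature, and example values below are assumptions for the sketch, not taken from the PR:

```python
def generate_debug_completions(model, tokenizer, prompt, max_new_tokens=50):
    """Generate responses to a fixed prompt so model quality can be eyeballed during training.

    Args:
        model: the policy language model currently being fine-tuned.
        tokenizer: the tokenizer matching ``model``.
        prompt (str): the fixed debugging prompt.
        max_new_tokens (int, optional): number of tokens to generate. Defaults to 50.

    Examples:
        >>> text = generate_debug_completions(model, tokenizer, "Once upon a time")
        >>> print(text)  # inspect the completion by eye
    """
```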

batch = next(dataloader)
# NOTE: disable kl for evaluation
td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0)
rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item()
Contributor:
why item?

Contributor (author):
to get a scalar instead of a scalar-tensor
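
For context, a minimal illustration of the difference (plain PyTorch, not code from this PR):

```python
import torch

r = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
m = r.sum(dim=1).mean()  # 0-dim tensor: tensor(5.)
f = m.item()             # plain Python float: 5.0, handy for logging and printing
```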

vmoens added the enhancement label on Jun 28, 2023
model,
ref_model,
reward_model,
kl_controller,
Contributor:
This is missing from the docs.

I'm not so sure about this kl_controller being passed to the module; I feel it should be handled separately. It's like passing the lr_scheduler to the optimizer: we don't do that because it mixes responsibilities between modules. It gives the impression that one module has multiple responsibilities, and it is less clear than doing things explicitly in the main code.
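
For reference, the analogy in plain PyTorch terms (a self-contained toy, unrelated to this PR's code): the scheduler is driven from the main loop, and the optimizer knows nothing about scheduling.

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 1)
opt = optim.SGD(model.parameters(), lr=0.1)
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()    # the optimizer updates parameters
sched.step()  # the scheduler adjusts the LR from outside, in the main loop
```

The question is whether the KL coefficient should be wired the same way, i.e. updated next to the rollout call in the training script rather than from inside the module.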

Contributor (author):
Are you suggesting we go back to passing just the KL coefficient?

Contributor:
I think the KL controller should be another class, but then we have two options: either the KL controller changes the KL coefficient of the other class (like the LR scheduler changes the LR of the optimizer, or the target-param updaters in torchrl update the target params of the loss), or we explicitly pass the KL coefficient.
I think the first option is more "pytorch"-style.
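
A minimal sketch of the first option, assuming the wrapped object exposes a mutable `kl_coef` attribute (the class and attribute names here are illustrative, not the PR's API):

```python
import numpy as np


class DummyRollout:
    """Stand-in for the module that consumes the KL coefficient."""

    def __init__(self, kl_coef=0.1):
        self.kl_coef = kl_coef


class AdaptiveKLControllerSketch:
    """Updates the coefficient *of the wrapped module*, scheduler-style."""

    def __init__(self, model, init_kl_coef, target, horizon):
        self.model = model
        self.model.kl_coef = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, kl_value, n_steps):
        eps = np.clip(kl_value / self.target - 1, -0.2, 0.2)
        self.model.kl_coef *= 1 + eps * n_steps / self.horizon


rollout = DummyRollout()
controller = AdaptiveKLControllerSketch(rollout, init_kl_coef=0.1, target=6, horizon=10000)
controller.update(kl_value=9.0, n_steps=64)  # nudges rollout.kl_coef upward
```

The second option would keep the controller free of any model reference and simply pass `controller.coef` (or a raw float) to the rollout call, as the `kl_coef=0.0` evaluation snippet above does.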

Contributor (author):
> I think the KL controller should be another class

It actually is another class.

> the KL controller changes the KL coefficient of the other class

Isn't this what we are currently doing?

Contributor (author):
fixed

Contributor:
Should we not remove it then?

)

rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
kl_controller = AdaptiveKLController(rollout_from_model, 0.1, 6, 10000)
Contributor:
I don't think AdaptiveKLController takes rollout_from_model as input, does it?
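
Based on the constructor quoted just below (`init_kl_coef`, `target`, `horizon`), the call would presumably be something like:

```python
# AdaptiveKLController as defined in the snippet quoted below (lines 21 to 92)
kl_controller = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
```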

Comment on lines 21 to 92
class KLControllerBase(abc.ABC):
"""Base class for KL controllers.

Each controller must implement an update method that takes the current KL value and
the number of steps and updates the self.coef attribute, which will multiply
the KL during calculation of the reward.
"""

@abc.abstractmethod
def update(self, kl_value: float, n_steps: int):
pass


class ConstantKLController(KLControllerBase):
"""Constant KL Controller.

This controller maintains a fixed coefficient no matter what values it is updated
with.

Arguments:
coefficient (float): The coefficient to multiply KL with when calculating the
reward.
"""

def __init__(self, coefficient):
self.coef = coefficient

def update(self, kl_value: float, n_steps: int):
pass


class AdaptiveKLController(KLControllerBase):
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".

Arguments:
init_kl_coef (float): The starting value of the coefficient.
target (float): The target KL value. When the observed KL is smaller, the
coefficient is decreased, thereby relaxing the KL penalty in the training
objective and allowing the model to stray further from the reference model.
When the observed KL is greater than the target, the KL coefficient is
increased, thereby pulling the model back towards the reference model.
horizon (int): Scaling factor to control how aggressively we update the
coefficient.

Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
"""

def __init__(self, init_kl_coef: float, target: float, horizon: int):
self.coef = init_kl_coef
self.target = target
self.horizon = horizon

def update(self, kl_value: float, n_steps: int):
"""Update ``self.coef`` adaptively.

Arguments:
kl_value: The current KL value between the newest policy and the initial
policy.
n_steps: The number of training steps taken since last update.
"""
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.coef *= mult # βₜ₊₁
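
For intuition, plugging illustrative numbers into the update above (the values are chosen for this example, not taken from the PR):

```python
import numpy as np

coef, target, horizon = 0.1, 6.0, 10_000
kl_value, n_steps = 9.0, 64                      # observed KL above the target
eps = np.clip(kl_value / target - 1, -0.2, 0.2)  # 0.5, clipped to 0.2
coef *= 1 + eps * n_steps / horizon              # 0.1 * 1.00128 -> 0.100128
```

When the observed KL exceeds the target, the coefficient grows slowly, strengthening the penalty; when it is below the target, `eps` is negative and the coefficient decays instead.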
Contributor:
Are these classes really part of data? They seem more related to the model to me. They act on a class that belongs to data (which maybe should be moved to the collector, to be honest), but the KL coefficient is something that has to do with the stochastic policy (the language model, in our case), not with the data.

Contributor:
New classes should be added to the doc (provided we're sure of where they belong)

model,
ref_model,
reward_model,
kl_controller,
Contributor:
Should we not remove it then?

"""Makes a step in the KL coefficient schedule."""
raise NotImplementedError
self.kl_controller.update(kl_value, n_steps)
Contributor:
ditto, maybe this function should go away?

@@ -167,7 +242,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1)
 )
 reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
 reward_raw = reward_raw * done
-reward_kl = -kl_coef * log_ratio.unsqueeze(-1)
+reward_kl = -self.kl_controller.coef * log_ratio.unsqueeze(-1)
Contributor:
same here

apbard changed the title from "[Example, NOMERGE] RLHF end to end example" to "[Example] RLHF end to end example" on Jul 7, 2023
vmoens closed this on Jun 10, 2024