Support gradient accumulation using a forward fn decorator with lax.scan #614

Merged: 1 commit merged into apple:main from cemkoc/grad-accum-scan on Jul 31, 2024

Conversation

cemkoc (Contributor) commented on Jul 30, 2024

This PR adds a new forward-function decorator option to the learner config, which allows users to define their own decorators that transform the forward function's behaviour. One such use case is gradient accumulation, which is what this PR implements.

Specifically, this PR enables gradient accumulation via a new forward function decorator called with_minibatch_steps, which internally runs a lax.scan loop over minibatches obtained by dynamically slicing the input batch. The decorator wraps the forward function and accumulates gradients with jax.lax.scan for the user-specified number of gradient accumulation steps. To use it, the learner config is extended with an optional forward_fn_decorator parameter, which the user sets as follows to enable gradient accumulation:

Example:

gradient_accumulation_steps = 4

learner.forward_fn_decorator = config.config_for_function(with_minibatch_steps).set(
    steps=gradient_accumulation_steps,
    metric_accumulator=MetricAccumulator.default_config(),
)

Since we scan over the forward function (instead of the value_and_grad function), we need a way to compute the gradients in a minibatched manner. We therefore wrap the forward function with a custom_vjp implementation that computes and accumulates the gradients during the forward phase. Because the gradients are already computed and accumulated in the forward phase, we simply pass them to the backward phase and use them as is. This lets us compute and accumulate the gradients in a memory-efficient (minibatched) manner; a rough sketch of the idea follows below.
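The following is a minimal sketch of this pattern, not the actual implementation in this PR: a jax.custom_vjp-wrapped function whose forward pass scans over dynamically sliced minibatches and accumulates per-minibatch gradients, and whose backward pass simply returns the accumulated gradients. The decorator name with_minibatch_grad_accum and the forward_fn(params, batch) -> loss signature are assumptions made for illustration.

import jax
import jax.numpy as jnp


def with_minibatch_grad_accum(forward_fn, steps: int):
    """Wraps forward_fn(params, batch) -> loss so gradients are computed per
    minibatch and accumulated during the forward pass (sketch only)."""

    def _forward_and_grads(params, batch):
        batch_size = jax.tree_util.tree_leaves(batch)[0].shape[0]
        if batch_size % steps != 0:
            raise ValueError(f"Batch size {batch_size} is not divisible by steps={steps}.")
        minibatch_size = batch_size // steps

        def scan_body(carry, step):
            loss_sum, grad_sum = carry
            # Dynamically slice the step-th minibatch out of the full input batch.
            minibatch = jax.tree_util.tree_map(
                lambda x: jax.lax.dynamic_slice_in_dim(
                    x, step * minibatch_size, minibatch_size, axis=0
                ),
                batch,
            )
            loss, grads = jax.value_and_grad(forward_fn)(params, minibatch)
            grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
            return (loss_sum + loss, grad_sum), None

        init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
        (loss_sum, grad_sum), _ = jax.lax.scan(scan_body, init, jnp.arange(steps))
        # Average over minibatches (assumes the loss is a mean over examples).
        return loss_sum / steps, jax.tree_util.tree_map(lambda g: g / steps, grad_sum)

    @jax.custom_vjp
    def wrapped(params, batch):
        loss, _ = _forward_and_grads(params, batch)
        return loss

    def fwd(params, batch):
        loss, grads = _forward_and_grads(params, batch)
        # Save the already-accumulated gradients (and zero cotangents for the batch)
        # as residuals for the backward pass.
        return loss, (grads, jax.tree_util.tree_map(jnp.zeros_like, batch))

    def bwd(residuals, g):
        grads, batch_zeros = residuals
        # Scale the saved gradients by the incoming cotangent; the batch gets zeros.
        return jax.tree_util.tree_map(lambda x: x * g, grads), batch_zeros

    wrapped.defvjp(fwd, bwd)
    return wrapped

Differentiating the wrapped function then yields the gradients that were accumulated inside the scan, e.g. jax.grad(with_minibatch_grad_accum(forward_fn, steps=4))(params, batch).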

A similar effort to implement gradient accumulation using a scan was mentioned by @apoorvtintin.

An integer representing the minibatch size.

Raises:
    ValueError: If the input batch is not divisible by steps.
Contributor commented:

Sorry for missing this earlier, but we should document the other raise conditions too.

cemkoc (Contributor Author) commented:

sounds good, let me update it now

cemkoc (Contributor Author) commented:

done.

cemkoc added this pull request to the merge queue on Jul 31, 2024
Merged via the queue into apple:main with commit 5ef4825 on Jul 31, 2024
4 checks passed
cemkoc deleted the cemkoc/grad-accum-scan branch on July 31, 2024 17:51