
Gradient Accumulation in Axlearn #465

Closed
apoorvtintin wants to merge 2 commits

Conversation

@apoorvtintin commented on May 13, 2024:

Gradient accumulation allows training with higher batch sizes without scaling out.

Adds a new learner type, configured via learner.klass: 'axlearn.common.learner.AccumulatedLearner'.

At a high level, the optimization does the following (a rough sketch follows the list):

  1. Splits the input batch into even microbatches.
  2. Creates a buffer for the gradients and metrics.
  3. Runs the forward and backward pass for each microbatch in a loop, summing the gradients and aggregating the metrics.
  4. Averages the gradients across the microbatches and normalizes the metrics.
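
A rough sketch of these steps in plain JAX (hypothetical names, not the PR's actual implementation; assumes loss_fn(params, microbatch) returns a scalar loss, and metrics are omitted for brevity):

import jax
import jax.numpy as jnp

def accumulate_gradients(loss_fn, params, batch, num_microbatches):
    # 1. Split the input batch into even microbatches along the leading axis.
    microbatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_microbatches, -1) + x.shape[1:]), batch
    )
    # 2. Create a zero-initialized buffer for the accumulated gradients.
    grad_buffer = jax.tree_util.tree_map(jnp.zeros_like, params)

    # 3. Run the forward and backward pass per microbatch sequentially,
    #    summing the gradients into the buffer via the scan carry.
    def step(acc, microbatch):
        grads = jax.grad(loss_fn)(params, microbatch)
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    grad_sum, _ = jax.lax.scan(step, grad_buffer, microbatches)

    # 4. Average the gradients across microbatches.
    return jax.tree_util.tree_map(lambda g: g / num_microbatches, grad_sum)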

Configuration changes:

  • The number of microbatches is specified during configuration through the microbatches option in the learner.
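
For illustration, the resulting learner configuration might look roughly like this (the attribute-style access on a trainer_cfg object is assumed; only the two field names come from this description):

# Illustrative only; the exact config plumbing depends on the experiment setup.
trainer_cfg.learner.klass = "axlearn.common.learner.AccumulatedLearner"
trainer_cfg.learner.microbatches = 4  # split each global batch into 4 microbatches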

@ruomingp (Contributor) left a comment:

Could you explain why this is needed? We usually find it's more efficient to use either a larger mesh or a smaller batch size.

axlearn/common/learner.py (resolved)
@@ -444,6 +444,153 @@ def _mask_tree(tree: dict, *, keep: dict) -> dict:
)


class MetricsAccumulationOp(NamedTuple):

Contributor:

Axlearn already has metric accumulation classes that are used by evalers. Could those be reused here instead of defining new classes?

@apoorvtintin (Author) replied on Jul 24, 2024:

These metric accumulation classes are stateless so that they are usable as the carry of jax.lax.scan, unlike the ones in the evaler. I can make the class structure similar, though.
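
For illustration, an accumulation op along these lines could be a NamedTuple of pure functions, with the running value itself threaded through the jax.lax.scan carry (hypothetical class and names, not the ones defined in this PR):

from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp

class AccumulationOp(NamedTuple):
    # No mutable state: both fields are pure functions, so the op can be closed
    # over by the scanned step function without breaking tracing.
    accumulate: Callable  # e.g. jnp.add, to sum a metric across microbatches
    normalize: Callable   # e.g. divide the sum by the number of microbatches

sum_then_mean = AccumulationOp(accumulate=jnp.add,
                               normalize=lambda total, n: total / n)

def reduce_metric(per_microbatch_values):
    # per_microbatch_values has one entry per microbatch; the running total is the carry.
    def step(carry, value):
        return sum_then_mean.accumulate(carry, value), None

    total, _ = jax.lax.scan(step, jnp.zeros(()), per_microbatch_values)
    return sum_then_mean.normalize(total, per_microbatch_values.shape[0])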

axlearn/common/learner.py (outdated, resolved)
# tuple of key-value pairs specifying custom aggregation and normalization
# for a specific metric
metrics_accumulation_key_ops: Sequence[Dict[str, Optional[MetricsAccumulationOp]]] = []
gradient_dtype: Optional[jnp.dtype] = jnp.bfloat16

Contributor:

Does the existing learner class use this? If not, we should try to be consistent with its API.

@apoorvtintin (Author) replied on Jul 24, 2024:

Could you please be more specific? Is the concern the naming of the members?


Returns:
ForwardBackwardOutputs: pytree containing gradients and metrics
"""

Contributor:

I wonder if, instead of having a separate learner for microbatching, it would be more flexible to have a generic way of wrapping a ForwardFn so that it uses jax.lax.map to run the microbatches. Beyond avoiding the need to add a new learner, it would also allow other microbatching uses outside of the learner, e.g. inference or second-order optimizers.
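
For reference, a rough sketch of that suggestion under assumed names (a plain loss function stands in for AXLearn's ForwardFn):

import jax
import jax.numpy as jnp

def with_microbatching(forward_fn, num_microbatches):
    # Wraps forward_fn so that it runs once per microbatch via jax.lax.map
    # and averages the per-microbatch losses.
    def wrapped(params, batch):
        microbatches = jax.tree_util.tree_map(
            lambda x: x.reshape((num_microbatches, -1) + x.shape[1:]), batch
        )
        losses = jax.lax.map(lambda mb: forward_fn(params, mb), microbatches)
        return jnp.mean(losses)

    return wrapped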

@apoorvtintin (Author) replied:

jax.lax.map gives no guarantee of sequential execution of the microbatches, which is the key property of gradient accumulation.

@apghml (Contributor) commented on Jul 3, 2024:

I’m OOO this week. I have left some preliminary comments for now.

@apghml (Contributor) commented on Jul 31, 2024:

Closing, since gradient accumulation functionality has been implemented via #614.

@apghml closed this on Jul 31, 2024.