Skip to content

Commit

Permalink
Merge pull request #701 from google:mattdavidow-pipeline-circular
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644065470
  • Loading branch information
maxtext authors committed Jun 17, 2024
2 parents 5c9e569 + a75d9a9 commit 75b3a5e
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 22 deletions.
10 changes: 9 additions & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,16 @@ normalize_embedding_logits: True # whether to normlize pre-softmax logits if lo
logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability

# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages and num_layers_per_pipeline_stage (does not yet support circular pipelines).
# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
# There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier
# it is to hide the pipeline communication behind the compute since there is more compute per stage, however there will be a larger bubble
# since there are fewer repeats. Similarly there is tradeoff for num_pipeline_microbatches - more microbatches leads to a smaller bubble,
# but a smaller size per microbatch which may hurt per-stage performance. Additionally note when microbatches > num_stages we have the opportunity to
# perform the circular transfer (last stage to first) asynchronously.
# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1)
num_layers_per_pipeline_stage: 1
# The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage)
num_pipeline_repeats: -1
# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
num_pipeline_microbatches: -1
Expand Down
Loading

0 comments on commit 75b3a5e

Please sign in to comment.