Merge branch 'main' into mor--kv-cache-layout-reformat-output
yeandy committed Jun 17, 2024
2 parents 4b9c8a3 + 5c9e569 commit 7fd12c6
Showing 2 changed files with 42 additions and 1 deletion.
34 changes: 33 additions & 1 deletion MaxText/configs/llama2_70b_gpu.yml
@@ -15,4 +15,36 @@ async_checkpointing: False
logits_dot_in_fp32: False

per_device_batch_size: 6
max_target_length: 4096

mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
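
For context on how entries like these behave: each rule maps a logical axis name used in model code to one or more mesh axes, and rules are tried in order, which is why the same logical name (e.g. 'activation_vocab') can appear several times as fallbacks. The sketch below is not part of this commit; it assumes Flax's logical-partitioning helpers (which MaxText builds on) and an illustrative single-slice mesh that puts every local device on the 'fsdp' axis.

import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec

mesh_axes = ('stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive')
# Illustrative mesh: all local devices placed on the 'fsdp' axis.
devices = np.array(jax.devices()).reshape(1, 1, -1, 1, 1, 1, 1)
mesh = Mesh(devices, mesh_axes)

rules = (
    ('activation_batch', ('data', 'fsdp', 'fsdp_transpose')),
    ('activation_length', 'sequence'),
    ('activation_embed', 'tensor'),
)

# A [batch, length, embed] activation annotated with these logical names
# resolves to PartitionSpec(('data', 'fsdp', 'fsdp_transpose'), 'sequence', 'tensor').
sharding = nn.logical_to_mesh_sharding(
    PartitionSpec('activation_batch', 'activation_length', 'activation_embed'),
    mesh, rules)
print(sharding.spec)
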
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
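
A note on the DCN comment above: on multislice topologies, MaxText-style code typically builds the device mesh with JAX's hybrid mesh helper, which takes the fast within-slice (ICI) shape and the slow cross-slice (DCN) shape separately. The slice counts below are hypothetical, for illustration only.

from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh_axes = ('stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive')
# Hypothetical topology: 2 slices over DCN parallelized on 'data',
# 8 chips per slice over ICI parallelized on 'fsdp'.
ici_shape = (1, 1, 8, 1, 1, 1, 1)   # within-slice (fast network)
dcn_shape = (1, 2, 1, 1, 1, 1, 1)   # across-slice (slow network)
devices = mesh_utils.create_hybrid_device_mesh(ici_shape, dcn_shape)
mesh = Mesh(devices, mesh_axes)
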
9 changes: 9 additions & 0 deletions MaxText/configs/llama2_7b_gpu.yml
@@ -18,8 +18,13 @@ logits_dot_in_fp32: False
per_device_batch_size: 4
max_target_length: 4096

mesh_axes: ['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
@@ -28,15 +33,19 @@ logical_axis_rules: [
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['stage', 'data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
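
The activation_* rules in both files only take effect where the model annotates tensors with those logical names. A minimal sketch of that pattern, assuming Flax's with_logical_constraint (the mechanism MaxText's layers use for such annotations); the layer itself is hypothetical:

import flax.linen as nn

class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):  # x: [batch, length, embed]
        x = nn.Dense(self.features, name='proj')(x)
        # Under the rules above this resolves to
        # (('data', 'fsdp', 'fsdp_transpose'), 'sequence', 'tensor').
        return nn.with_logical_constraint(
            x, ('activation_batch', 'activation_length', 'activation_embed'))
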
