
change norm sharding #623

Merged 1 commit into main from lizhiyu/change_norm_sharding on Apr 26, 2024
Conversation

@ZhiyuLi-goog (Collaborator) commented Apr 26, 2024

This change switches the norm layer's scale sharding from "embed" (FSDP) to "activation_embed" (tensor parallelism), so that the scale's sharding matches the activation's sharding in the element-wise multiplication.

  • boosts multislice MFU by roughly 10% with tensor parallelism enabled, by avoiding the unexpected cross-slice DCN collective-permute that the sharding mismatch triggered in experiments
  • verified in the HLOs below
# Setup [data_dcn, fsdp_ici, tensor_ici] corresponds to [2,16,4] 

# norm weight (scale) was sharded by "embed"/fsdp, i.e. 16-way FSDP sharding

reshape.371 = bf16[12288]{0} reshape(add.49), sharding={devices=[16,8]<=[2,16,4]T(1,0,2) last_tile_dim_replicate}, metadata={op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/decoder/while/body/checkpoint/rematted_computation/layers/pre_self_attention_norm/add" source_file="/app/maxtext/MaxText/layers/gpt3.py" source_line=91}


# the activation is sharded as ((data, fsdp), None, tensor); to match it, the sharding mismatch in the broadcast
# triggers expensive all-gather and collective-permute communication: [16(fsdp), 8(None)] -> [32(data x fsdp), 1(None), 4(tp)]

broadcast.1326 = bf16[128,2048,12288]{2,1,0} broadcast(reshape.371), dimensions={2}, sharding={devices=[32,1,4]<=[128]}, metadata={op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/decoder/while/body/checkpoint/rematted_computation/layers/pre_self_attention_norm/mul" source_file="/app/maxtext/MaxText/layers/gpt3.py" source_line=91}
multiply.1327 = bf16[128,2048,12288]{2,1,0} multiply(multiply.1320, broadcast.1326), sharding={devices=[32,1,4]<=[128]}, metadata={op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/decoder/while/body/checkpoint/rematted_computation/layers/pre_self_attention_norm/mul" source_file="/app/maxtext/MaxText/layers/gpt3.py" source_line=91}
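For illustration, here is a minimal Flax sketch (not the actual MaxText code) of the kind of annotation change involved: the norm scale's logical partitioning axis is switched from "embed" (weight/FSDP sharding) to "activation_embed" so it follows the activation sharding in the element-wise multiply. The RMSNorm module below is hypothetical; only the logical axis names come from this PR.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class RMSNorm(nn.Module):
    """Hypothetical RMSNorm, used only to illustrate the sharding-annotation change."""
    features: int
    epsilon: float = 1e-6

    @nn.compact
    def __call__(self, x):
        scale = self.param(
            'scale',
            # Before: the scale followed the weight sharding,
            #   nn.with_logical_partitioning(nn.initializers.ones, ('embed',))
            # After: it follows the activation sharding, so the element-wise
            # multiply below needs no resharding of the scale.
            nn.with_logical_partitioning(nn.initializers.ones, ('activation_embed',)),
            (self.features,),
            jnp.float32,
        )
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
        y = x * jax.lax.rsqrt(var + self.epsilon)
        return (y * scale).astype(x.dtype)
```

With logical-to-mesh rules mapping "activation_embed" to the tensor-parallel mesh axis, the scale is laid out the same way as the last dimension of the activations, which is what removes the cross-slice collective-permute seen in the HLO above.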

@rwitten removed their assignment Apr 26, 2024
fix lint

Revert "fix lint"

This reverts commit d8dc450.

fix lint
@copybara-service bot merged commit 6570445 into main Apr 26, 2024
8 checks passed
@copybara-service bot deleted the lizhiyu/change_norm_sharding branch April 26, 2024 21:04