Default of output_layer_parallelism = "row" is broken for model-parallel training #905
Comments
Nice catch! Looks like this issue has snuck through for a while. I'm going to make column the default and add an error for row-parallelism for now, and if you're able to add shape checks that would be much appreciated!
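A shape check along those lines could look roughly like the following. This is a minimal sketch, not GPT-NeoX code: check_vocab_parallel_logits is a hypothetical helper, and it assumes the vocabulary has already been padded to a multiple of model_parallel_size, so every rank should hold exactly vocab_size // model_parallel_size logits.

```python
def check_vocab_parallel_logits(logits_last_dim: int, vocab_size: int, model_parallel_size: int) -> None:
    """Hypothetical guard to run before a vocab-parallel cross-entropy.

    Assumes vocab_size is already padded to a multiple of model_parallel_size,
    so each rank should hold exactly one vocab shard of logits.
    """
    if vocab_size % model_parallel_size != 0:
        raise ValueError(
            f"vocab_size={vocab_size} is not divisible by "
            f"model_parallel_size={model_parallel_size}"
        )
    expected = vocab_size // model_parallel_size
    if logits_last_dim != expected:
        # A row-parallel output layer leaves full-width (vocab_size) partial
        # logits on every rank, which trips this check whenever mp > 1.
        raise ValueError(
            f"expected per-rank logits of width {expected}, got {logits_last_dim}; "
            'is output_layer_parallelism set to "column"?'
        )


# Usage: a column-parallel shard passes; a full-width row-parallel output does not.
check_vocab_parallel_logits(logits_last_dim=512, vocab_size=1024, model_parallel_size=2)
try:
    check_vocab_parallel_logits(logits_last_dim=1024, vocab_size=1024, model_parallel_size=2)
except ValueError as err:
    print(err)
```

With model_parallel_size = 1 the expected width equals vocab_size, so the check correctly accepts either layout in the non-parallel case.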
Is there a reason to support row parallelism at all? I’m struggling to think of why the user might care other than compatibility with loss functions, but as noted mpu’s CE loss is column parallel. Is there another library we support that does it row parallel?
FWIW, I think getting rid of the row-parallel final linear layer as an option makes the most sense. It's not clear to me how you could write a loss function that operates on row-parallel outputs, since the output activations are split in the "reduction dimension" as opposed to along any tensor axis. Furthermore, in the simple case where you perform an immediate reduce of the parallel output (rather than have a parallel loss function), column parallel is faster since the reduction is an AllGather, not an AllReduce.
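To make the AllGather vs. AllReduce point concrete, here is a pure-Python sketch with two simulated ranks in one process (no torch.distributed); all helper names are hypothetical. Column parallelism splits the weight along the vocab axis, so the full logits are recovered by concatenating shards; row parallelism splits it along the hidden (reduction) axis, so every rank holds full-width partial sums that must be added. The traffic estimate assumes ring collectives, where an AllReduce is a ReduceScatter followed by an AllGather.

```python
def matmul(a, b):
    """Multiply matrices given as lists of rows: (n x k) @ (k x m)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def column_parallel(h, w, p):
    """Split W along the vocab (column) axis; rank r computes h @ W[:, shard_r]."""
    vp = len(w[0]) // p
    shards = [matmul(h, [row[r * vp:(r + 1) * vp] for row in w]) for r in range(p)]
    # AllGather: concatenate the per-rank vocab shards along the last axis.
    return [sum((s[i] for s in shards), []) for i in range(len(h))]


def row_parallel(h, w, p):
    """Split W (and h) along the hidden axis; each rank gets full-width partial sums."""
    hp = len(w) // p
    partials = [
        matmul([row[r * hp:(r + 1) * hp] for row in h], w[r * hp:(r + 1) * hp])
        for r in range(p)
    ]
    # AllReduce: elementwise sum of the partial logits across ranks.
    return [
        [sum(part[i][j] for part in partials) for j in range(len(w[0]))]
        for i in range(len(h))
    ]


def ring_elements_sent_per_rank(op, n, p):
    """Elements each rank sends for an n-element output, assuming ring algorithms."""
    if op == "allgather":
        return (p - 1) * n // p
    if op == "allreduce":
        # Ring AllReduce = ReduceScatter + AllGather, so twice the traffic.
        return 2 * (p - 1) * n // p
    raise ValueError(op)


h = [[1, 2, 3, 4]]                                            # one token, hidden size 4
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 1, 1, 0]]  # hidden 4 x vocab 4
assert column_parallel(h, w, 2) == matmul(h, w) == row_parallel(h, w, 2)
print(ring_elements_sent_per_rank("allgather", 4, 2), ring_elements_sent_per_rank("allreduce", 4, 2))
```

Both layouts produce the same logits once combined; the difference is which collective is needed and how much data it moves. With a parallel loss function, the column-parallel shards can be consumed in place without any gather.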
Describe the bug
By default, output_layer_parallelism = "row", which splits the output layer in the hidden dimension. But the loss function mpu.vocab_parallel_cross_entropy assumes its input is split in the vocab dimension. Remarkably, it doesn't fail on the shape mismatch; instead it runs cleanly and silently computes the wrong loss. I noticed that all the provided configs explicitly set output_layer_parallelism = "column".
To Reproduce
If you train the same model with model_parallel_size > 1 and compare row vs. column parallelism, you will get worse results with the default setting of row.
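The silent failure can be reproduced in miniature without a GPU. The sketch below is a simplified, single-process stand-in for a vocab-parallel cross-entropy, not the actual mpu.vocab_parallel_cross_entropy. It assumes each rank derives its vocab range from the width of its local logits and that the max, the sum of exponentials, and the target logit are combined across ranks. All names are hypothetical and the sizes are toy values (hidden size 4, vocab size 4, two ranks).

```python
import math


def matmul(a, b):
    """Multiply matrices given as lists of rows: (n x k) @ (k x m)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def reference_ce(logits, target):
    """Ordinary cross-entropy for one token on the full (unsplit) logits."""
    m = max(logits)
    return math.log(sum(math.exp(x - m) for x in logits)) - (logits[target] - m)


def vocab_parallel_ce_sketch(shards, target):
    """Hypothetical single-process stand-in for a vocab-parallel CE (one token).

    Assumption: rank r owns vocab ids [r * width, (r + 1) * width), with width
    inferred from its own shard, so nothing checks that shards tile the vocab.
    """
    global_max = max(max(s) for s in shards)  # stands in for all-reduce(MAX)
    sum_exp = 0.0
    target_logit = 0.0
    for rank, shard in enumerate(shards):
        width = len(shard)
        start = rank * width
        sum_exp += sum(math.exp(x - global_max) for x in shard)  # all-reduce(SUM)
        if start <= target < start + width:  # only the owning rank contributes
            target_logit += shard[target - start] - global_max
    return math.log(sum_exp) - target_logit


h = [[1, 2, 3, 4]]                                            # hidden size 4
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 1, 1, 0]]  # hidden 4 x vocab 4
target, p = 1, 2
vp, hp = len(w[0]) // p, len(w) // p

# Column-parallel: each rank holds the logits for its slice of the vocab.
col_shards = [matmul(h, [row[r * vp:(r + 1) * vp] for row in w])[0] for r in range(p)]
# Row-parallel: each rank holds full-vocab partial sums over its slice of the hidden dim.
row_shards = [matmul([h[0][r * hp:(r + 1) * hp]], w[r * hp:(r + 1) * hp])[0] for r in range(p)]

expected = reference_ce(matmul(h, w)[0], target)
print("column:", vocab_parallel_ce_sketch(col_shards, target), "expected:", expected)
print("row:   ", vocab_parallel_ce_sketch(row_shards, target))  # runs, but wrong
```

The column-parallel shards reproduce the reference loss, while the row-parallel partial sums are accepted without any error and yield a different value, which is the silent failure mode this issue reports.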
Expected behavior
See above.
Proposed solution
I'm not sure why row-parallelism on the output is supported, but I think any of the following are reasonable: