
Cull pp = 0 model branch #269

Merged
sdtblck merged 26 commits into main from cull-model-branch on Apr 30, 2021
Conversation

@sdtblck (Contributor) commented Apr 28, 2021

OK, this PR is a compendium of lots of changes.

  • The main change is that the model branch where pp=0 no longer exists; pipe parallel now defaults to 1. This means we only have to maintain a single model. I verified that both models achieved exactly the same loss, and in fact pp=1 (GPT2ModelPipe) is slightly faster for some reason, so there's no reason I can see to keep the other model branch around.

  • Fix the wandb group name logging (we ended up adding multiple uuids to the end of the group name)

  • Removes a lot of dead code from:
    - megatron/mpu/random.py -> handled activation checkpointing. With the pipeline model, this is all handled by deepspeed.
    - megatron/fp16 -> handled loss scaling / fp16 conversion. Again, all handled by deepspeed, so it can be safely removed.
    - megatron/mpu/grads.py -> handled gradient clipping, also now handled by deepspeed.
    - megatron/memory.py -> wasn't used anywhere.

  • Makes the apex dependency optional (it's still better to have it installed, as apex's FusedAdam is slightly faster than deepspeed's version).

  • Some mild reorganization of the layout of megatron/model to make it easier to work with / navigate

  • Renamed megatron/arguments/megatron_args.py to megatron/arguments/neox_args.py (in the future I think we should rename the whole package to neox; we're now sufficiently different lol, but I couldn't face going through the hassle right now).

  • Updated requirements (separated optional ones from mandatory ones) and updated dockerfile
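The "optional apex" bullet above is usually implemented as a guarded import with a fallback. A minimal sketch of that pattern (the function name `resolve_optimizer_name` is hypothetical; the two import paths are the real locations of apex's and DeepSpeed's FusedAdam, but either library may be absent in a given environment):

```python
def resolve_optimizer_name():
    """Report which FusedAdam implementation would be used.

    apex is preferred when installed, since its FusedAdam is slightly
    faster; otherwise we fall back to DeepSpeed's version, which the
    project already depends on.
    """
    try:
        import apex  # optional dependency; may not be installed
        return "apex.optimizers.FusedAdam"
    except ImportError:
        return "deepspeed.ops.adam.FusedAdam"
```

This keeps apex out of the mandatory requirements while still using it opportunistically when present.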

@sdtblck sdtblck requested a review from a team as a code owner April 28, 2021 14:11
This was linked to issues Apr 28, 2021
@sdtblck (Contributor, Author) commented Apr 28, 2021

Please don't merge yet, I'm making some more changes.

@sdtblck (Contributor, Author) commented Apr 28, 2021

Okay, should be ready to merge now.

We can now convert GPT2ModelPipe to a regular nn.Sequential model by calling GPT2ModelPipe.to_sequential(). If pipe parallel is set to 0, we train using this model instead. This should also let us keep using ZeRO 2 / 3 etc. if desired.
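The idea behind to_sequential() is that a pipeline model is just an ordered list of layer specs, so it can be flattened into one sequential module. A torch-free sketch of that conversion (this LayerSpec is a simplified stand-in for DeepSpeed's deepspeed.pipe.LayerSpec, and the toy layers are purely illustrative):

```python
class LayerSpec:
    """Stand-in for DeepSpeed's LayerSpec: stores a layer class and its
    constructor args so the module can be built lazily."""
    def __init__(self, cls, *args, **kwargs):
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def build(self):
        return self.cls(*self.args, **self.kwargs)


def to_sequential(specs):
    """Build every spec in order and chain them into a single callable,
    mimicking wrapping the pipeline's layer list in nn.Sequential."""
    layers = [spec.build() for spec in specs]

    def forward(x):
        for layer in layers:
            x = layer(x)  # output of each layer feeds the next
        return x

    return forward


# Toy layers standing in for transformer blocks:
class AddOne:
    def __call__(self, x):
        return x + 1

class Double:
    def __call__(self, x):
        return x * 2

model = to_sequential([LayerSpec(AddOne), LayerSpec(Double)])
# model(3) -> (3 + 1) * 2 = 8
```

Because the sequential form is an ordinary module rather than a PipelineModule, it can still be wrapped by ZeRO stage 2/3 when pipeline parallelism is disabled.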

@sdtblck (Contributor, Author) commented Apr 30, 2021

[Screenshot from 2021-04-29 16-00-05: training curves]
IMO everything here is ready to merge: training with pipe parallel on/off, as well as loading from a checkpoint; everything is stable and running as expected.

@sdtblck sdtblck requested a review from sweinbach April 30, 2021 09:38
@sdtblck sdtblck merged commit dc44965 into main Apr 30, 2021
@sdtblck sdtblck deleted the cull-model-branch branch April 30, 2021 09:43
@sdtblck sdtblck mentioned this pull request Apr 30, 2021
3 tasks

Successfully merging this pull request may close these issues.

- Get rid of codepath where pp = 0
- Timer logging inaccurate if pp=0