[ZeRO-3] Partitioned init with deepspeed.zero.Init() (#1190)
* added ds zero.Init() to get_model

* Clean up conditional with block

* pre-commit

---------

Co-authored-by: Quentin Anthony <[email protected]>
R0n12 and Quentin-Anthony committed Mar 19, 2024
1 parent 277141e commit 7267a74
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions megatron/training.py
@@ -24,6 +24,7 @@
 
 import math
 import sys
+from contextlib import nullcontext
 
 import torch
 import deepspeed
@@ -426,13 +427,15 @@ def get_model(neox_args, use_cache=False):
     # If mup isn't being used anyways, this has no effect.
     old_use_mup = neox_args.use_mup
     neox_args.use_mup = False
-    model = GPT2ModelPipe(
-        neox_args=neox_args,
-        num_tokentypes=0,
-        parallel_output=True,
-        topology=mpu.get_topology(),
-        use_cache=use_cache,
-    )
+
+    with deepspeed.zero.Init() if neox_args.zero_stage == 3 else nullcontext() as gs:
+        model = GPT2ModelPipe(
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache,
+        )
 
     ### soft prompt tuning stuff ###
     if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(

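For readers unfamiliar with the pattern, the following stand-alone sketch shows the general deepspeed.zero.Init() usage that this commit adopts for GPT2ModelPipe. It is illustrative only: the toy nn.Sequential model and the hard-coded zero_stage value are assumptions, not gpt-neox code, and running it presumes a DeepSpeed/torch.distributed launch.

# Minimal sketch (not part of the commit): partitioned model construction
# under ZeRO-3 using the same conditional context-manager pattern as above.
from contextlib import nullcontext

import deepspeed
import torch.nn as nn

zero_stage = 3  # illustrative; in gpt-neox this comes from neox_args.zero_stage

# Under ZeRO-3, building the model inside deepspeed.zero.Init() lets DeepSpeed
# allocate each parameter in partitioned form across data-parallel ranks as it
# is created, so the full model never has to be materialized on a single device.
# For other ZeRO stages, nullcontext() keeps the original eager construction.
with deepspeed.zero.Init() if zero_stage == 3 else nullcontext():
    model = nn.Sequential(  # placeholder module standing in for GPT2ModelPipe
        nn.Linear(4096, 16384),
        nn.GELU(),
        nn.Linear(16384, 4096),
    )

Note that the name gs bound in the commit's with statement is not used in the visible hunk; the context manager's role here is to intercept parameter allocation while the model is constructed.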