Merge pull request EleutherAI#59 from EleutherAI/gradient_checkpointing
implement gradient checkpointing
ConnorJL committed Jan 13, 2021
2 parents 425dbb5 + 5a353f1 commit 9fe917d
Showing 2 changed files with 27 additions and 9 deletions.
23 changes: 19 additions & 4 deletions gpt_neox/gpt_neox.py
@@ -116,7 +116,8 @@ def forward(self, x, **kwargs):


class GPTNeoX(nn.Module):
def __init__(self, *, num_tokens, dim, seq_len, depth, heads=8, dim_head=64, attn_dropout=0., ff_dropout=0., sparse_attn=False, use_fused_layernorm=False, tie_classifier_weights=False):
def __init__(self, *, num_tokens, dim, seq_len, depth, heads=8, dim_head=64, attn_dropout=0., ff_dropout=0.,
sparse_attn=False, use_fused_layernorm=False, tie_classifier_weights=False, gradient_checkpointing=True):
super().__init__()
if not use_fused_layernorm:
norm_class = nn.LayerNorm
@@ -140,23 +141,37 @@ def __init__(self, *, num_tokens, dim, seq_len, depth, heads=8, dim_head=64, att
PreNorm(dim, norm_class, Attention(dim=dim, heads=heads, seq_len=seq_len, dim_head=dim_head, dropout=attn_dropout, sparse_attn=layer_sparse_attn)),
PreNorm(dim, norm_class, FeedForward(dim=dim, dropout=ff_dropout)),
]))
self.depth = depth

self.norm = norm_class(dim)

if tie_classifier_weights:
self.to_logits = lambda t: t @ self.token_emb.weight.t()
else:
self.to_logits = nn.Linear(dim, num_tokens)

self.gradient_checkpointing = gradient_checkpointing

def forward(self, x, mask=None):
n, device = x.shape[1], x.device

x = self.token_emb(x)
x = self.pos_emb(torch.arange(n, device=device)) + x

for (attn, ff) in self.layers:
x = attn(x) + x
x = ff(x) + x
def _layer(attn, ff):
def fn(x):
x = attn(x) + x
return ff(x) + x
return fn

if self.gradient_checkpointing:
for (attn, ff) in self.layers:
layer_fn = _layer(attn, ff)
x = torch.utils.checkpoint.checkpoint(layer_fn, (x))
else:
for (attn, ff) in self.layers:
layer_fn = _layer(attn, ff)
x = layer_fn(x)

x = self.norm(x)
return self.to_logits(x)
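Below is a minimal, self-contained sketch of the pattern this commit introduces (illustrative module with placeholder sublayers, not the repository's code): each residual block is wrapped in a closure and passed through torch.utils.checkpoint.checkpoint, which skips storing the block's intermediate activations in the forward pass and recomputes them when backward reaches that block, trading extra compute for lower activation memory.

import torch
import torch.nn as nn
import torch.utils.checkpoint

class TinyCheckpointedStack(nn.Module):
    # Stand-in for a transformer stack; Linear layers replace attention and feed-forward.
    def __init__(self, dim=64, depth=4, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([nn.Linear(dim, dim), nn.Linear(dim, dim)])
            for _ in range(depth)
        ])
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x):
        def _layer(attn, ff):
            # Same closure shape as the diff above: residual "attention", then residual "feed-forward".
            def fn(x):
                x = attn(x) + x
                return ff(x) + x
            return fn

        for attn, ff in self.layers:
            layer_fn = _layer(attn, ff)
            if self.gradient_checkpointing:
                # Activations inside layer_fn are not kept; they are recomputed during backward.
                x = torch.utils.checkpoint.checkpoint(layer_fn, x)
            else:
                x = layer_fn(x)
        return x

x = torch.randn(2, 64, requires_grad=True)
TinyCheckpointedStack()(x).sum().backward()  # gradients still reach x and all parameters

Note that the committed call writes checkpoint(layer_fn, (x)); since (x) is not a tuple in Python, this is equivalent to passing x directly.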
13 changes: 8 additions & 5 deletions train.py
@@ -10,6 +10,8 @@

from gpt_neox.utils import get_args, get_params

import GPUtil

train_args = get_args()
params = get_params(train_args.model)

@@ -26,11 +28,11 @@
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"]
dim_head=params["dim_head"],
gradient_checkpointing=params.get("gradient_checkpointing", True)
)

model = AutoregressiveWrapper(model)

# prepare data
dset_params = params["dataset"]
assert dset_params is not None
@@ -67,13 +69,14 @@
ds_model_params = prepare_optimizer_parameters(model)

# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model_engine, optim, _, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)
training_data=None)

train_loader = model_engine.deepspeed_io(train_dataset, pin_memory=params.get("pin_memory", False))

print("OPTIMIZER: ", optim)
pbar = trange(params.get("train_steps", 1), mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i, data in enumerate(train_loader):
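For reference, a hedged sketch of how the new train.py flags might be supplied (key names mirror lookups visible in this diff; the values are placeholders, and the real settings live in the repository's config files). Gradient checkpointing is read with a default of True, pin_memory with a default of False, and the data loader is now built after deepspeed.initialize via model_engine.deepspeed_io rather than by passing training_data into the initializer.

# Hypothetical model params dict; only a few of the keys referenced in the diff are shown.
params = {
    "seq_len": 1024,                  # placeholder values
    "n_layers": 12,
    "n_heads": 8,
    "dim_head": 64,
    "gradient_checkpointing": True,   # new flag: checkpointing stays on by default
    "pin_memory": False,              # used when building the loader with deepspeed_io
}

# train.py falls back to the same defaults when the keys are absent:
gradient_checkpointing = params.get("gradient_checkpointing", True)
pin_memory = params.get("pin_memory", False)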
