Add checkpoint saving / loading (#90)
* fix torch.utils.checkpoint error

* fix breaks in train_pipeline.py

* push fixes to train_pipeline.py

* push changes to zero1 config

* omnibus changes to *pipeline.py scripts

* add checkpoint saving / loading

* changed line 64

checkpoint_dirs = natural_sort(checkpoint_dir) rather than natural_sort(checkpoint_dirs)

* Update utils.py

* fix checkpoint saving / loading logic

* fix checkpoint saving logic

* Update gpt3_small.json

* Change params for OnebitAdam

Per my issue in deepspeed yesterday, I was told by a dev (microsoft/DeepSpeed#690 (comment)) that the error I was facing was due to the incorrect keyword.

* Fixing batch size

* Made consistent with ZeRO 1

* Made consistent with ZeRO 2

* Update deepspeed_zero2.json

* Undo previous commit

* Reverted back to normal adam from 1-bit-adam (#96)

* Create checkpoints_config.json

* Update train_gpt3small_pipeline.sh

Co-authored-by: Shivanshu Purohit <[email protected]>
Co-authored-by: Stella Biderman <[email protected]>
3 people committed Jan 28, 2021
1 parent 39972e6 commit 4aee002
Showing 13 changed files with 183 additions and 45 deletions.
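
The new checkpointing is configured through the model-params JSON rather than the DeepSpeed config. Judging from the diffs below, the keys involved are `checkpoint_dir` (required), plus the optional `checkpoint_save_frequency`, `keep_n_latest_checkpoints`, `save_rng`, `load_rng`, and `train_steps`. A minimal sketch of such a params dict, with hypothetical values:

```python
# Illustrative only: model-params keys read by the checkpointing code added in this commit.
params = {
    "checkpoint_dir": "./gpt3small",      # required; save/load_ds_checkpoint assert it is set
    "checkpoint_save_frequency": 1000,    # default used in the training loops
    "keep_n_latest_checkpoints": 5,       # older checkpoint dirs get pruned on save
    "save_rng": True,                     # also store python/numpy/torch/CUDA RNG states
    "load_rng": True,                     # restore those RNG states on resume
    "train_steps": 100000,                # default end of the trange loop
}
```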
3 changes: 3 additions & 0 deletions .gitignore
@@ -281,3 +281,6 @@ TSWLatexianTemp*

# Makeindex log files
*.lpz

# saved model files
*.pt
2 changes: 1 addition & 1 deletion configs/base_model.json
@@ -15,5 +15,5 @@
"n_layers": 6,
"n_heads": 8,
"dim_head": 64,
"train_batch_size": 8
"checkpoint_dir": "./enwik8_model"
}
45 changes: 45 additions & 0 deletions configs/checkpoints_config.json
@@ -0,0 +1,45 @@
{
"train_batch_size": 1280,
"gradient_accumulation_steps": 80,
"gradient_clipping": 1.0,
"wall_clock_breakdown": true,
"zero_allow_untested_optimizer": true,
"tensorboard": {
"enabled": true,
"output_path": "./logs",
"job_name": "gptneox"
},
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 2e-4,
"freeze_step":2,
"cuda_aware":true
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.00015,
"warmup_num_steps": 5000
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 1,
"contiguous_gradients" : true,
"cpu_offload": false
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": 1,
"synchronize_checkpoint_boundary": false,
"profile": false
}

}
15 changes: 10 additions & 5 deletions configs/deepspeed_zero1.json
@@ -9,11 +9,15 @@
"job_name": "gptneox"
},
"optimizer": {
"type": "OneBitAdam",
"type": "Adam",
"params": {
"lr": 2e-4,
"freeze_step":2,
"cuda-aware":true
"lr": 0.001,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
@@ -30,7 +34,8 @@
"zero_optimization": {
"stage": 1,
"contiguous_gradients" : true,
"cpu_offload": false
"cpu_offload": false,
"overlap_comm": false
},
"activation_checkpointing": {
"partition_activations": true,
13 changes: 6 additions & 7 deletions configs/deepspeed_zero2.json
@@ -1,7 +1,8 @@
{
"train_batch_size": 1028,
"gradient_accumulation_steps": 1,
"train_batch_size": 1280,
"gradient_accumulation_steps": 80,
"gradient_clipping": 1.0,
"wall_clock_breakdown": true,
"tensorboard": {
"enabled": true,
"output_path": "./logs",
@@ -26,7 +27,6 @@
"fp16": {
"enabled": true
},
"wall_clock_breakdown": true,
"zero_optimization": {
"stage": 2,
"contiguous_gradients" : true,
@@ -37,12 +37,11 @@
"steps_per_print": 100,
"wall_clock_breakdown": true
},
"activation_checkpointing": {
"comment": "to turn on activation checkpointing, set this to a positive integer. Do not touch other params.",
"partition_activations": false,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"number_checkpoints": 1,
"synchronize_checkpoint_boundary": false,
"profile": false
}
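A note on the batch-size change above: DeepSpeed requires `train_batch_size` to equal the per-GPU micro-batch size times `gradient_accumulation_steps` times the data-parallel world size, and the new values (1280 / 80) match the checkpoints config and, per the commit message, the ZeRO-1 config. A quick sanity check under an assumed world size of 16 GPUs:

```python
# Illustration of the DeepSpeed batch-size identity; the world size is an assumption.
train_batch_size = 1280
gradient_accumulation_steps = 80
world_size = 16  # hypothetical number of data-parallel ranks
micro_batch_per_gpu = train_batch_size // (gradient_accumulation_steps * world_size)
assert micro_batch_per_gpu * gradient_accumulation_steps * world_size == train_batch_size
print(micro_batch_per_gpu)  # 1 sample per GPU per micro-step under these assumptions
```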
1 change: 1 addition & 0 deletions configs/gpt3_small.json
@@ -23,5 +23,6 @@
"n_layers": 12,
"n_heads": 12,
"dim_head": 64,
"checkpoint_dir": "./gpt3small",
"train_batch_size": 256
}
1 change: 0 additions & 1 deletion gpt_neox/gpt_neox.py
@@ -275,6 +275,5 @@ def __init__(
LayerSpec(nn.Linear, dim, num_tokens),
lambda x: x.transpose(1, 2)
]
print(spec)
assert len(spec) % num_stages == 0, f"for optimal performance, depth + 4 ({len(spec)}) should be divisible by the number of pipeline stages ({num_stages})"
super().__init__(layers=spec, loss_fn=loss_fn, num_stages=num_stages, **kwargs)
76 changes: 76 additions & 0 deletions gpt_neox/utils.py
@@ -4,6 +4,11 @@
import deepspeed
import json
from collections import defaultdict
import shutil
import re
import random
import numpy as np
import torch


# helpers
@@ -32,6 +37,77 @@ def is_main(args):
return args.local_rank in [0, -1]


def natural_sort(l):
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
return sorted(l, key = alphanum_key)


def save_ds_checkpoint(iteration, model, params, keep_n_latest_checkpoints=None, is_main=None):
"""Save a model checkpoint."""
iteration = str(iteration)
sd = {}
sd['iteration'] = iteration
if keep_n_latest_checkpoints is not None:
assert is_main is not None
# rng states.
if params.get('save_rng', True):
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = model.mpu.get_cuda_rng_tracker().get_states()

checkpoint_dir = params.get('checkpoint_dir', None)
assert checkpoint_dir is not None, 'add "checkpoint_dir" to your model params to enable checkpoint saving'
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
if keep_n_latest_checkpoints is not None:
all_checkpoints = os.listdir(checkpoint_dir)
checkpoint_dirs = natural_sort(all_checkpoints)
checkpoint_dirs = [item for item in checkpoint_dirs if os.path.isdir(os.path.join(checkpoint_dir, item))]
checkpoint_dirs = [str(i) for i in checkpoint_dirs]
n = len(checkpoint_dirs) - keep_n_latest_checkpoints
n = 0 if n < 0 else n
to_delete = checkpoint_dirs[:n+1]
if to_delete:
if is_main:
print(f'WARNING: deleting checkpoint dirs {to_delete} in {checkpoint_dir}')
[shutil.rmtree(os.path.join(checkpoint_dir, item)) for item in to_delete]
model.save_checkpoint(checkpoint_dir, iteration, client_state=sd)


def load_ds_checkpoint(model, params, iteration=None):
"""Load a model checkpoint."""
if iteration is not None:
iteration = str(iteration)

checkpoint_dir = params.get('checkpoint_dir', None)
assert checkpoint_dir is not None, 'add "checkpoint_dir" to your model params to enable checkpoint loading'
print(f'Loading latest checkpoint from {checkpoint_dir}')

checkpoint_name, sd = model.load_checkpoint(checkpoint_dir, iteration)
if checkpoint_name is None:
print("Unable to load checkpoint.")
return iteration if iteration is not None else 0

# rng states.
if params.get('load_rng', True):
try:
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
torch.cuda.set_rng_state(sd['cuda_rng_state'])
model.mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
print(f'Unable to load rngs from checkpoint {checkpoint_name}, exiting. ')
exit()
torch.distributed.barrier()
print(f'successfully loaded {checkpoint_name}')
iteration = int(os.path.basename(os.path.dirname(checkpoint_name)))
return iteration


def get_all_files(filetype, files_dir):
files = []
for (dir_path, _, filenames) in os.walk(files_dir):
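The pruning in save_ds_checkpoint relies on `natural_sort` so that checkpoint directories, which are tagged with the iteration number, are ordered numerically rather than lexicographically. A small self-contained illustration of the difference:

```python
import re

# Same helper as added in gpt_neox/utils.py above.
def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)

dirs = ["10000", "2000", "500"]  # example iteration-named checkpoint dirs
print(sorted(dirs))              # ['10000', '2000', '500']   (lexicographic)
print(natural_sort(dirs))        # ['500', '2000', '10000']   (oldest first)
```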
3 changes: 2 additions & 1 deletion scripts/kill.sh
@@ -1 +1,2 @@
pkill -f "python -u train*"
pkill -f "python -u train*"
pkill -9 python
2 changes: 1 addition & 1 deletion scripts/train_enwik8_pipeline.sh
@@ -1,2 +1,2 @@
mkdir logs
NCCL_SHM_DISABLE=1 NCCL_DEBUG=info MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_enwik8_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero2.json
NCCL_SHM_DISABLE=1 NCCL_DEBUG=info MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_enwik8_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero1.json --model base_model
2 changes: 1 addition & 1 deletion scripts/train_gpt3small_pipeline.sh
@@ -1 +1 @@
MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero1.json
NCCL_SHM_DISABLE=1 MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/checkpoints_config.json
28 changes: 16 additions & 12 deletions train_enwik8_pipeline.py
@@ -11,7 +11,8 @@
from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset,
cycle, prepare_optimizer_parameters, decode_tokens, prepare_data,
GPTNeoX_Pipe)
from gpt_neox.utils import is_main, get_args, get_params
from gpt_neox.utils import is_main, get_args, get_params, load_ds_checkpoint, save_ds_checkpoint

from gpt_neox.data_utils import read_enwik8_data
import gpt_neox

@@ -41,6 +42,9 @@ def prepare_dataset(dset_params, train_args):
if __name__ == '__main__':
# arguments
train_args = get_args()

IS_MAIN = is_main(train_args)

params = get_params(train_args.model)
deepspeed.init_distributed(dist_backend='nccl')
model = gpt_neox.GPTNeoX_Pipe(
@@ -66,20 +70,20 @@ def prepare_dataset(dset_params, train_args):
# optimizer
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(ds_model_params, lr=params["learning_rate"])

# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)
configure_checkpointing(model_engine)

batches_to_train = 10000

pbar = trange(batches_to_train, mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):

loss = model_engine.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
configure_checkpointing(model)
current_iteration = load_ds_checkpoint(model, params, iteration=None)
pbar = trange(current_iteration, params.get('train_steps', 100000), mininterval=10., desc='Training Model', dynamic_ncols=True)
for i in pbar:
loss = model.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
if not i % params.get('checkpoint_save_frequency', 1000) and i != 0:
save_ds_checkpoint(i, model, params, params.get('keep_n_latest_checkpoints', 5), IS_MAIN)
37 changes: 21 additions & 16 deletions train_pipeline.py
@@ -13,7 +13,9 @@
GPTNeoX_Pipe)
from gpt_neox.datasets import GPT2Dataset
from gpt_neox.data_utils import get_tokenizer
from gpt_neox.utils import is_main, get_args, get_params

from gpt_neox.utils import is_main, get_args, get_params, save_ds_checkpoint, load_ds_checkpoint

import gpt_neox

WORLD_SIZE = os.getenv('WORLD_SIZE')
@@ -41,6 +43,9 @@ def prepare_dataset(dset_params, train_args):
if __name__ == '__main__':
# arguments
train_args = get_args()

IS_MAIN = is_main(train_args)

params = get_params(train_args.model)
deepspeed.init_distributed(dist_backend='nccl')

@@ -86,18 +91,18 @@ def prepare_dataset(dset_params, train_args):
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(ds_model_params, lr=params["learning_rate"])
# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)

configure_checkpointing(model_engine)

batches_to_train = 10000
pbar = trange(batches_to_train, mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):
loss = model_engine.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()

model, optim, train_loader, lr_scheduler = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)
configure_checkpointing(model)
current_iteration = load_ds_checkpoint(model, params, iteration=None)
pbar = trange(current_iteration, params.get('train_steps', 100000), mininterval=10., desc='Training Model', dynamic_ncols=True)
for i in pbar:
loss = model.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
if not i % params.get('checkpoint_save_frequency', 1000) and i != 0:
save_ds_checkpoint(i, model, params, params.get('keep_n_latest_checkpoints', 5), IS_MAIN)
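
On the save condition used in both training loops: `not i % freq and i != 0` is truthy exactly on positive multiples of the frequency, so checkpoints are written at steps 1000, 2000, ... with the default. A tiny check:

```python
# Illustration of which steps trigger a save with the loops' condition.
checkpoint_save_frequency = 1000  # default from params.get('checkpoint_save_frequency', 1000)
save_steps = [i for i in range(3001) if not i % checkpoint_save_frequency and i != 0]
print(save_steps)  # [1000, 2000, 3000]
```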
