forked from EleutherAI/gpt-neox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_enwik8_pipeline.py
128 lines (102 loc) · 4.3 KB
/
train_enwik8_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import json
import random
from collections import defaultdict
import os
import deepspeed
import torch
from torch.utils.data import DataLoader
from tqdm.auto import trange
from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset,
cycle, prepare_optimizer_parameters, decode_tokens, prepare_data,
GPTNeoX_Pipe)
from gpt_neox.utils import is_main
from gpt_neox.data_utils import read_enwik8_data
import gpt_neox
# Raw value of the WORLD_SIZE environment variable: a string when set, None
# when unset (os.getenv performs no conversion).
WORLD_SIZE = os.getenv('WORLD_SIZE')
def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Two options are declared here (``--model`` and ``--local_rank``); the
    parser is then handed to ``deepspeed.add_config_arguments`` so DeepSpeed
    can register its own options before parsing.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    arg_parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
    arg_parser.add_argument('--model', type=str, default="base_model")
    arg_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local rank passed from distributed launcher',
    )
    # Include DeepSpeed configuration arguments
    arg_parser = deepspeed.add_config_arguments(arg_parser)
    return arg_parser.parse_args()
def get_params(model):
    """Load a model configuration from a JSON file.

    Args:
        model: Either a path ending in ``.json``, which is used as-is, or a
            bare configuration name, which is resolved to
            ``./configs/<model>.json``.

    Returns:
        A ``defaultdict`` holding the entries of the top-level JSON object.
        Indexing a key that is absent yields ``None`` (and stores it) instead
        of raising ``KeyError``; ``.get(key, default)`` still honours its
        explicit default. Nested objects remain plain ``dict`` instances.

    Raises:
        OSError: If the configuration file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    model_path = model if model.endswith(".json") else f"./configs/{model}.json"
    # JSON text is UTF-8 by specification; decode it explicitly rather than
    # relying on the platform's locale encoding.
    with open(model_path, encoding="utf-8") as f:
        params = json.load(f)
    return defaultdict(lambda: None, params)
# Parse the command line, then load the configuration selected by --model.
# Both names are module-level globals used by the rest of the script.
train_args = get_args()
params = get_params(train_args.model)
# instantiate GPT-like decoder model
# NOTE: the triple-quoted text below is a bare string expression, not live
# code. It preserves a disabled construction that used GPTNeoX wrapped in
# AutoregressiveWrapper.
'''model = GPTNeoX(
num_tokens=params["vocab_size"],
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"]
)
model = AutoregressiveWrapper(model)'''
# Initialise DeepSpeed's distributed state with its default arguments.
# NOTE(review): presumably this must run before the model is constructed
# further down -- confirm before moving or removing it.
deepspeed.init_distributed()
def loss_function(x, y):
    """Return the mean of the unreduced cross-entropy between ``x`` and ``y``.

    The cross-entropy is computed with ``reduction='none'`` and averaged
    afterwards, so every element of the unreduced loss tensor contributes to
    the denominator of the mean.

    Args:
        x: Predictions passed as the ``input`` of
            ``torch.nn.functional.cross_entropy``.
        y: Targets passed as the ``target`` of the same call.

    Returns:
        A scalar tensor containing the averaged loss.
    """
    per_element = torch.nn.functional.cross_entropy(x, y, reduction='none')
    return per_element.mean()
# Instantiate the GPTNeoX_Pipe model from the loaded configuration.
# ``params`` is a defaultdict, but ``.get`` bypasses its default factory, so
# a missing "pipeline_num_stages" entry falls back to 2 stages.
model = gpt_neox.GPTNeoX_Pipe(
    num_tokens=params["vocab_size"],
    dim=params["hidden_dim"],
    seq_len=params["seq_len"],
    depth=params["n_layers"],
    heads=params["n_heads"],
    dim_head=params["dim_head"],
    loss_fn=loss_function,  # torch.nn.CrossEntropyLoss(),
    num_stages=params.get("pipeline_num_stages", 2),
)

# prepare enwik8 data
dset_params = params["dataset"]
# NOTE(review): deepspeed.init_distributed() has already been called earlier
# in this script, so this second call looks redundant -- confirm before
# removing it.
deepspeed.init_distributed(dist_backend='nccl')

# Every process waits at the first barrier; the process for which
# is_main(train_args) is true then runs prepare_data while the others wait at
# the second barrier, so nobody reads the dataset before it is prepared.
torch.distributed.barrier()
if is_main(train_args):
    prepare_data(dset_params["name"])
torch.distributed.barrier()

data_train, data_val = read_enwik8_data(dset_params["path"])
train_dataset = TextSamplerDataset(data_train, params["seq_len"], mode="with_labels")
val_dataset = TextSamplerDataset(data_val, params["seq_len"], mode="with_labels")
# Only consumed by the disabled validation code at the bottom of the loop.
val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])
ds_model_params = prepare_optimizer_parameters(model)

# deepspeed loader: wraps the model and optimizer in an engine and builds the
# training data loader from ``train_dataset``.
model_engine, optim, train_loader, _ = deepspeed.initialize(
    args=train_args,
    model=model,
    optimizer=optim,
    model_parameters=ds_model_params,
    training_data=train_dataset,
)

# training
batches_to_train = 10000
# Computed once, under its own name, so the imported ``is_main`` helper is
# not rebound to a bool. Only the disabled code below reads it.
is_main_rank = model_engine.local_rank == 0

pbar = trange(params["num_epochs"], mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
    for i in range(batches_to_train):
        loss = model_engine.train_batch()
        # set_description refreshes the bar, so the loss stays live. The bar
        # itself advances once per epoch through iteration; a manual
        # pbar.update() per batch would overrun its ``num_epochs`` total.
        pbar.set_description(f'Training Loss: {loss.item():.4f}')

        # Periodic validation and sample generation are currently disabled;
        # the previous implementation is kept here for reference.
        #
        # if is_main_rank and i % params["validate_every"] == 0:
        #     model.eval()
        #     with torch.no_grad():
        #         val_data = next(val_loader).cuda()
        #         loss = model(val_data)
        #         pbar.write(f'Validation Loss: {loss.item()}')
        #
        # if is_main_rank and i % params["generate_every"] == 0:
        #     model.eval()
        #     inp = random.choice(val_dataset)[:-1]
        #     prime = decode_tokens(inp)
        #     pbar.write(f"{prime} \n\n {'*' * 100}")
        #     sample = model.generate(inp.cuda(), params["generate_length"])
        #     output_str = decode_tokens(sample)
        #     pbar.write(output_str)