Fix train pipeline #89

Merged
merged 6 commits on Jan 25, 2021
9 changes: 5 additions & 4 deletions configs/deepspeed_zero1.json
@@ -1,7 +1,8 @@
{
"train_batch_size": 256,
"gradient_accumulation_steps": 1,
"train_batch_size": 1280,
"gradient_accumulation_steps": 80,
"gradient_clipping": 1.0,
"wall_clock_breakdown": true,
"tensorboard": {
"enabled": true,
"output_path": "./logs",
@@ -28,11 +29,11 @@
},
"zero_optimization": {
"stage": 1,
"contiguous_gradients" : false,
"contiguous_gradients" : true,
"cpu_offload": false
},
"activation_checkpointing": {
"partition_activations": false,
"partition_activations": true,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": 1,
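For context on the new values: DeepSpeed requires that `train_batch_size` equal the per-GPU micro-batch size times `gradient_accumulation_steps` times the number of data-parallel workers, so 1280 with 80 accumulation steps implies a small per-GPU micro-batch. A minimal sketch of that arithmetic, assuming a hypothetical world size of 8 (this PR does not pin a GPU count):

```python
# Sanity-check the relationship DeepSpeed enforces at initialization:
#   train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size
# world_size is an assumption for illustration; it is not specified in this PR.
train_batch_size = 1280
gradient_accumulation_steps = 80
world_size = 8  # hypothetical number of data-parallel workers

micro_batch_per_gpu = train_batch_size // (gradient_accumulation_steps * world_size)
assert micro_batch_per_gpu * gradient_accumulation_steps * world_size == train_batch_size
print(micro_batch_per_gpu)  # -> 2 under these assumptions
```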
4 changes: 2 additions & 2 deletions configs/gpt3_small.json
@@ -11,9 +11,9 @@
"seed": 1,
"shuffle_input_filenames": true,
"pretokenized": true,
"filetype": "tfrecords",
"mode": "chunks"
"filetype": "tfrecords"
},
"num_epochs": 10,
"train_steps": 572300,
"eval_batch_size": 32,
"learning_rate": 0.0006,
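The `"mode": "chunks"` key is dropped because `GPT2Dataset` (see the datasets.py changes below) now defaults to `"normal"` instead of requiring `"chunks"`. A hedged sketch of how the trimmed config block maps onto the constructor, with the glob pattern and sequence length as placeholders rather than values from this PR:

```python
from gpt_neox.datasets import GPT2Dataset

dataset = GPT2Dataset(
    glob_pattern="./data/*_1000.tfrecords",  # assumed path; filenames encode document counts
    seq_len=2048,                            # assumed sequence length
    pretokenized=True,
    filetype="tfrecords",
    # no mode argument: falls back to the new default, mode="normal"
)
```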
27 changes: 17 additions & 10 deletions gpt_neox/datasets.py
@@ -11,7 +11,7 @@
class GPT2Dataset(Dataset):

def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True, pretokenized=True,
filetype="tfrecords", mode="chunks", train=True, tokenizer=None, **kwargs):
filetype="tfrecords", mode="normal", train=True, tokenizer=None, **kwargs):

super().__init__()
self.files = glob.glob(glob_pattern) # glob pattern pointing to files
@@ -34,16 +34,13 @@ def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True,
self._get_lens()

self.seq_len = seq_len # set sequence length
self.mode = mode # set mode ["chunks"]
implemented_modes = ["chunks"]
if self.mode not in implemented_modes:
raise NotImplementedError


self.pretokenized = pretokenized
if not self.pretokenized:
raise NotImplementedError # TODO: tokenize text data on the fly

self.train = train
self.mode = mode

def _get_number_of_documents(self, filename):
# extracts number of files from a filename formatted "<name>_<num_documents>.{filetype}."
@@ -75,8 +72,6 @@ def _get_lens(self):
def _parse_single_example(self, example):
data = tf.train.Example.FromString(example)
data = torch.tensor(list(data.features.feature["text"].int64_list.value), dtype=torch.long)
if self.mode == "chunks":
assert data.size(0) == self.seq_len + 1
return data

def _process_tfrecord(self, tfrecords_file, resume_idx=None):
@@ -111,7 +106,17 @@ def __getitem__(self, idx):
seek_idx) # parses tfrecord file to a list *once* then stores in memory
else:
raise NotImplementedError
return chunk[remainder] # get item from current chunk
output = chunk[remainder]
assert output is not None
assert output.size(0) == (self.seq_len + 1), f"Output shape ({output.size(0)}) != the specified sequence length + 1 ({self.seq_len + 1})"
if self.mode == "normal":
return output
elif self.mode == 'with_labels':
x_seq = output[:-1]
y_seq = output[1:]
return x_seq, y_seq
else:
raise ValueError(f'mode {self.mode} not recognized')

def __len__(self):
return self._len
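The new branch in `__getitem__` is worth spelling out: `"normal"` returns the full `seq_len + 1` chunk, while `"with_labels"` splits it into an input sequence and next-token labels. A small illustration with toy values (the tensor contents are made up; the slicing mirrors the code above):

```python
import torch

seq_len = 8
output = torch.arange(seq_len + 1)  # stands in for one parsed tfrecord chunk of length seq_len + 1

# mode="normal": the whole chunk, as before
full_chunk = output

# mode="with_labels": inputs and labels, each of length seq_len, offset by one token
x_seq, y_seq = output[:-1], output[1:]
assert x_seq.size(0) == y_seq.size(0) == seq_len
assert torch.equal(x_seq[1:], y_seq[:-1])  # y is x shifted left by one position
```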
@@ -130,10 +135,12 @@ def __getitem__(self, index):
if self.mode == "normal":
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq
else:
elif self.mode == "with_labels":
x_seq = self.data[rand_start: rand_start + self.seq_len].long()
y_seq = self.data[rand_start+1: rand_start + self.seq_len + 1].long()
return x_seq, y_seq
else:
raise ValueError(f'mode {self.mode} not recognized')

def __len__(self):
return self.data.size(0) // self.seq_len
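Both datasets now share the same `mode` contract, which matters for pipeline training: DeepSpeed's pipeline engine expects each batch as an `(inputs, labels)` pair, which is exactly what `"with_labels"` yields. A hedged sketch of the wiring (the token tensor and sequence length are placeholders, not values from this PR):

```python
import torch
from torch.utils.data import DataLoader
from gpt_neox import TextSamplerDataset

data_val = torch.randint(0, 256, (100_000,))  # placeholder token ids; read_enwik8_data would supply these
val_dataset = TextSamplerDataset(data_val, 1024, mode="with_labels")  # seq_len assumed to be 1024

x_batch, y_batch = next(iter(DataLoader(val_dataset, batch_size=32)))
assert x_batch.shape == y_batch.shape == (32, 1024)
```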
2 changes: 1 addition & 1 deletion scripts/train_gpt3small_pipeline.sh
@@ -1 +1 @@
MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero2.json
MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero1.json
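The launcher now points at the ZeRO stage 1 config updated above rather than `deepspeed_zero2.json`. A quick, hypothetical consistency check of the file it references, using only keys visible in this PR:

```python
import json

with open("configs/deepspeed_zero1.json") as f:
    cfg = json.load(f)

assert cfg["zero_optimization"]["stage"] == 1
assert cfg["train_batch_size"] == 1280
assert cfg["gradient_accumulation_steps"] == 80
```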
165 changes: 61 additions & 104 deletions train_enwik8_pipeline.py
@@ -11,118 +11,75 @@
from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset,
cycle, prepare_optimizer_parameters, decode_tokens, prepare_data,
GPTNeoX_Pipe)
from gpt_neox.utils import is_main
from gpt_neox.utils import is_main, get_args, get_params
from gpt_neox.data_utils import read_enwik8_data
import gpt_neox

WORLD_SIZE = os.getenv('WORLD_SIZE')

def get_args():
parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
# Include DeepSpeed configuration arguments
parser.add_argument('--model', type=str, default="base_model")
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args


def get_params(model):
model_path = model if model.endswith(".json") else f"./configs/{model}.json"
with open(model_path) as f:
params = json.load(f)
return defaultdict(lambda: None, params)


train_args = get_args()
params = get_params(train_args.model)

# instantiate GPT-like decoder model
'''model = GPTNeoX(
num_tokens=params["vocab_size"],
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"]
)

model = AutoregressiveWrapper(model)'''

deepspeed.init_distributed()

def loss_function(x, y):
losses = torch.nn.functional.cross_entropy(x, y, reduction='none')
loss = losses.mean()
return loss

model = gpt_neox.GPTNeoX_Pipe(
num_tokens=params["vocab_size"],
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"],
loss_fn = loss_function,#torch.nn.CrossEntropyLoss(),
num_stages = params.get("pipeline_num_stages", 2)
)

# prepare enwik8 data
dset_params = params["dataset"]
deepspeed.init_distributed(dist_backend='nccl')
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
if is_main(train_args):
prepare_data(dset_params["name"])
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
else:
torch.distributed.barrier()

data_train, data_val = read_enwik8_data(dset_params["path"])
train_dataset = TextSamplerDataset(data_train, params["seq_len"], mode="with_labels")
val_dataset = TextSamplerDataset(data_val, params["seq_len"], mode="with_labels")
val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])
def configure_checkpointing(model_engine):
deepspeed.checkpointing.configure(model_engine.mpu, deepspeed_config=train_args.deepspeed_config)
model_engine.mpu.checkpoint = deepspeed.checkpointing.checkpoint
model_engine.mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
model_engine.mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
assert deepspeed.checkpointing.is_configured()

# training
ds_model_params = prepare_optimizer_parameters(model)


# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)


batches_to_train = 10000

pbar = trange(params["num_epochs"], mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):

is_main = model_engine.local_rank == 0

loss = model_engine.train_batch()

pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()

'''if is_main and i % params["validate_every"] == 0:
model.eval()
with torch.no_grad():
val_data = next(val_loader).cuda()
loss = model(val_data)
pbar.write(f'Validation Loss: {loss.item()}')

if is_main and i % params["generate_every"] == 0:
model.eval()
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
pbar.write(f"{prime} \n\n {'*' * 100}")
sample = model.generate(inp.cuda(), params["generate_length"])
output_str = decode_tokens(sample)
pbar.write(output_str)'''
def prepare_dataset(dset_params, train_args):
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
if is_main(train_args):
prepare_data(dset_params["name"])
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
else:
torch.distributed.barrier()


if __name__ == '__main__':
# arguments
train_args = get_args()
params = get_params(train_args.model)
deepspeed.init_distributed(dist_backend='nccl')
model = gpt_neox.GPTNeoX_Pipe(
num_tokens=params["vocab_size"],
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"],
loss_fn = loss_function,
num_stages = params.get("pipeline_num_stages", 2),
activation_checkpoint_interval=params.get('activation_checkpoint_interval', 1)
)

# prepare enwik8 data
dset_params = params["dataset"]
prepare_dataset(dset_params, train_args)
data_train, data_val = read_enwik8_data(dset_params["path"])
train_dataset = TextSamplerDataset(data_train, params["seq_len"], mode="with_labels")
val_dataset = TextSamplerDataset(data_val, params["seq_len"], mode="with_labels")
val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))

# optimizer
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(ds_model_params, lr=params["learning_rate"])
# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)
configure_checkpointing(model_engine)

batches_to_train = 10000

pbar = trange(batches_to_train, mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):

loss = model_engine.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
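Two details of the rewritten loop are easy to miss: `loss_function` is just mean-reduced cross-entropy written out explicitly, and each `model_engine.train_batch()` call internally consumes `gradient_accumulation_steps` micro-batches from the loader returned by `deepspeed.initialize` before taking a single optimizer step. A small sketch of the loss equivalence (toy shapes, not values from this PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 50257)           # (batch, vocab) -- toy shapes for illustration
labels = torch.randint(0, 50257, (4,))

explicit = F.cross_entropy(logits, labels, reduction='none').mean()  # what loss_function computes
builtin = F.cross_entropy(logits, labels)                            # default 'mean' reduction
assert torch.allclose(explicit, builtin)
```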