Fix train pipeline #89

Merged · 6 commits · Jan 25, 2021
Changes from 1 commit
push fixes to train_pipeline.py
sdtblck committed Jan 23, 2021
commit 139f6b9e33cdffb07e8567f9dcd14ac2fcefb400
4 changes: 2 additions & 2 deletions configs/gpt3_small.json
@@ -11,9 +11,9 @@
"seed": 1,
"shuffle_input_filenames": true,
"pretokenized": true,
"filetype": "tfrecords",
"mode": "chunks"
"filetype": "tfrecords"
},
"num_epochs": 10,
"train_steps": 572300,
"eval_batch_size": 32,
"learning_rate": 0.0006,
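A note on this config change: train_pipeline.py forwards the whole "dataset" block as **dset_params while also passing an explicit mode='with_labels' (see the train_pipeline.py diff below), so leaving "mode": "chunks" in the JSON would hand the keyword to GPT2Dataset twice. A minimal sketch of that collision, using a hypothetical make_dataset stand-in rather than the real class:

def make_dataset(seq_len, mode="normal", **kwargs):
    # Hypothetical stand-in for GPT2Dataset, only to show the duplicate-keyword error.
    return {"seq_len": seq_len, "mode": mode, **kwargs}

old_dset_params = {"mode": "chunks", "filetype": "tfrecords"}  # old config block
try:
    make_dataset(seq_len=2048, mode="with_labels", **old_dset_params)
except TypeError as err:
    print(err)  # make_dataset() got multiple values for keyword argument 'mode'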
27 changes: 17 additions & 10 deletions gpt_neox/datasets.py
@@ -11,7 +11,7 @@
class GPT2Dataset(Dataset):

def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True, pretokenized=True,
filetype="tfrecords", mode="chunks", train=True, tokenizer=None, **kwargs):
filetype="tfrecords", mode="normal", train=True, tokenizer=None, **kwargs):

super().__init__()
self.files = glob.glob(glob_pattern) # glob pattern pointing to files
@@ -34,16 +34,13 @@ def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True,
self._get_lens()

self.seq_len = seq_len # set sequence length
self.mode = mode # set mode ["chunks"]
implemented_modes = ["chunks"]
if self.mode not in implemented_modes:
raise NotImplementedError


self.pretokenized = pretokenized
if not self.pretokenized:
raise NotImplementedError # TODO: tokenize text data on the fly

self.train = train
self.mode = mode

def _get_number_of_documents(self, filename):
# extracts number of files from a filename formatted "<name>_<num_documents>.{filetype}."
@@ -75,8 +72,6 @@ def _get_lens(self):
def _parse_single_example(self, example):
data = tf.train.Example.FromString(example)
data = torch.tensor(list(data.features.feature["text"].int64_list.value), dtype=torch.long)
if self.mode == "chunks":
assert data.size(0) == self.seq_len + 1
return data

def _process_tfrecord(self, tfrecords_file, resume_idx=None):
@@ -111,7 +106,17 @@ def __getitem__(self, idx):
seek_idx) # parses tfrecord file to a list *once* then stores in memory
else:
raise NotImplementedError
return chunk[remainder] # get item from current chunk
output = chunk[remainder]
assert output is not None
assert output.size(0) == (self.seq_len + 1), f"Output shape ({output.size(0)}) != the specified sequence length + 1 ({self.seq_len + 1})"
if self.mode == "normal":
return output
elif self.mode == 'with_labels':
x_seq = output[:self.seq_len]
y_seq = output[1:self.seq_len + 1]
return x_seq, y_seq
else:
raise ValueError(f'mode {self.mode} not recognized')

def __len__(self):
return self._len
@@ -130,10 +135,12 @@ def __getitem__(self, index):
if self.mode == "normal":
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq
else:
elif self.mode == "with_labels":
x_seq = self.data[rand_start: rand_start + self.seq_len].long()
y_seq = self.data[rand_start+1: rand_start + self.seq_len + 1].long()
return x_seq, y_seq
else:
raise ValueError(f'mode {self.mode} not recognized')

def __len__(self):
return self.data.size(0) // self.seq_len
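Taken together, the datasets.py changes drop the old "chunks"-only mode check, move the seq_len + 1 length assertion into __getitem__, and give both __getitem__ implementations two modes: "normal" returns the full (seq_len + 1)-token chunk, while "with_labels" returns an (inputs, targets) pair shifted by one token. A toy illustration of the two shapes, using a ready-made chunk instead of a real tfrecord:

import torch

seq_len = 8
chunk = torch.arange(seq_len + 1)     # stand-in for one parsed example of seq_len + 1 token ids

full_seq = chunk                      # mode="normal": caller does its own shifting
x_seq = chunk[:seq_len]               # mode="with_labels": model inputs
y_seq = chunk[1:seq_len + 1]          #                     next-token targets
assert torch.equal(y_seq, x_seq + 1)  # holds for this toy chunk of consecutive ids

The with_labels form is presumably what the pipeline path needs: DeepSpeed's pipeline engine consumes (inputs, labels) tuples from the dataset, and train_pipeline.py now builds both the train and eval datasets with mode='with_labels'.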
2 changes: 1 addition & 1 deletion scripts/train_gpt3small_pipeline.sh
@@ -1 +1 @@
MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero2.json
MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_pipeline.py --deepspeed --deepspeed_config configs/deepspeed_zero1.json
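The launch script now points at configs/deepspeed_zero1.json instead of deepspeed_zero2.json. DeepSpeed's pipeline-parallel engine supports only ZeRO stage 1 (stages 2 and 3 partition gradients and parameters in ways the pipeline schedule cannot handle), which is presumably the reason for the downgrade. The zero1 file itself is not part of this diff; a minimal stage-1 config might look like the sketch below, with the batch-size numbers as placeholders only:

import json

zero1_config = {
    "train_batch_size": 256,            # placeholder value
    "gradient_accumulation_steps": 1,   # placeholder value
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 1},  # optimizer-state partitioning only
}
print(json.dumps(zero1_config, indent=2))  # the rough shape of a configs/deepspeed_zero1.json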
133 changes: 63 additions & 70 deletions train_pipeline.py
@@ -18,82 +18,75 @@

WORLD_SIZE = os.getenv('WORLD_SIZE')

# arguments
train_args = get_args()
params = get_params(train_args.model)

# tokenizer
tokenizer = get_tokenizer(tokenizer_type=params["tokenizer"].get("type", None),
from_pretrained=params["tokenizer"].get("from_pretrained", True),
add_padding_token=params["tokenizer"].get("add_padding_token", False))
vocab_size = len(tokenizer) if params["vocab_size"] is None else params["vocab_size"]

# model
deepspeed.init_distributed(dist_backend='nccl')
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier

def loss_function(x, y):
losses = torch.nn.functional.cross_entropy(x, y, reduction='none')
loss = losses.mean()
return loss

model = GPTNeoX_Pipe(
num_tokens=params["vocab_size"],
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"],
loss_fn = loss_function,
num_stages = params.get("pipeline_num_stages", 2)
)
model = AutoregressiveWrapper(model)

# optimizer
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])

# prepare data
dset_params = params["dataset"]
assert dset_params is not None

if is_main(train_args):
prepare_data(dset_params["name"])
if __name__ == '__main__':
# arguments
train_args = get_args()
params = get_params(train_args.model)
deepspeed.init_distributed(dist_backend='nccl')

# tokenizer
tokenizer = get_tokenizer(tokenizer_type=params["tokenizer"].get("type", None),
from_pretrained=params["tokenizer"].get("from_pretrained", True),
add_padding_token=params["tokenizer"].get("add_padding_token", False))
vocab_size = len(tokenizer) if params["vocab_size"] is None else params["vocab_size"]

# model
model = GPTNeoX_Pipe(
num_tokens=vocab_size,
dim=params["hidden_dim"],
seq_len=params["seq_len"],
depth=params["n_layers"],
heads=params["n_heads"],
dim_head=params["dim_head"],
loss_fn = loss_function,
num_stages = params.get("pipeline_num_stages", 4),
activation_checkpoint_interval=1
)

# prepare data
dset_params = params["dataset"]
assert dset_params is not None
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
else:
torch.distributed.barrier()

# data loading
train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
if is_main(train_args):
prepare_data(dset_params["name"])
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
else:
torch.distributed.barrier()

train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
seq_len=params["seq_len"],
train=True,
mode='with_labels',
**dset_params)

eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
seq_len=params["seq_len"],
train=True,
train=False,
mode='with_labels',
**dset_params)

eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
seq_len=params["seq_len"],
train=False,
**dset_params)

val_loader = DataLoader(eval_dataset, batch_size=params["eval_batch_size"])
val_loader = iter(val_loader)

# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)


batches_to_train = 10000

pbar = trange(params["num_epochs"], mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):

is_main = model_engine.local_rank == 0

loss = model_engine.train_batch()

pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
val_loader = DataLoader(eval_dataset, batch_size=params["eval_batch_size"])
val_loader = cycle(val_loader)

# optimizer
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(ds_model_params, lr=params["learning_rate"])
# deepspeed loader
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
model=model,
optimizer=optim,
model_parameters=ds_model_params,
training_data=train_dataset)

batches_to_train = 10000
pbar = trange(params["num_epochs"], mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i in range(batches_to_train):
loss = model_engine.train_batch()
pbar.set_description(f'Training Loss: {loss.item():.4f}')
pbar.update()
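Two details of the restructured train_pipeline.py are worth calling out. First, the setup now runs under if __name__ == '__main__':, and the optimizer is built from prepare_optimizer_parameters(model) and handed to deepspeed.initialize only after the datasets are constructed. Second, iter(val_loader) becomes cycle(val_loader), so periodic evaluation can keep drawing batches after the DataLoader is exhausted. cycle is imported elsewhere in the file and not shown in this diff; a typical implementation (an assumption, not the repo's verbatim code) simply restarts the loader forever:

def cycle(loader):
    # Re-iterate the DataLoader indefinitely so next(val_loader) never raises StopIteration.
    while True:
        for batch in loader:
            yield batch

model_engine.train_batch() then drives each forward/backward/step itself, pulling batches from the loader that deepspeed.initialize builds from training_data.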