add download of owt2
steven-mi committed Jan 4, 2021
2 parents f4c17a0 + 1f8a58c commit 40bc140
Showing 8 changed files with 58 additions and 36 deletions.
11 changes: 10 additions & 1 deletion configs/base_deepspeed.json
@@ -1,6 +1,7 @@
 {
   "train_batch_size": 8,
   "gradient_accumulation_steps": 1,
+  "gradient_clipping": 1.0,
   "tensorboard": {
     "enabled": true,
     "output_path": "./logs",
@@ -12,12 +13,20 @@
"lr": 0.00015
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.00015,
"warmup_num_steps": 5000
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"contiguous_gradients" : true,
"cpu_offload": true
}
}
}
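The two blocks added here map onto familiar PyTorch-level behaviour once the JSON is passed to deepspeed.initialize. A rough sketch of the equivalents in plain PyTorch (an approximation for orientation, not DeepSpeed's internal implementation; the linear ramp assumes WarmupLR's warmup shape):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in model
    opt = torch.optim.Adam(model.parameters(), lr=0.00015)

    # "gradient_clipping": 1.0 -> clip the global gradient norm every step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # "scheduler": WarmupLR -> ramp the lr from warmup_min_lr (0) to
    # warmup_max_lr (0.00015) over warmup_num_steps (5000), then hold it
    warmup = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, step / 5000))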
4 changes: 2 additions & 2 deletions configs/gpt3_small.json
@@ -5,8 +5,8 @@
"add_padding_token": false
},
"dataset": {
"name": "OWT2",
"dir": "/root/data",
"name": "owt2",
"dir": "./data",
"seed": 1,
"shuffle_input_filenames": true,
"pretokenized": true,
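The lowercase rename is load-bearing: download_dataset (in gpt_neox/downloader.py below) looks the name up case-sensitively in the new DATASETS dict, so the old spelling "OWT2" would now fall through to NotImplementedError:

    from gpt_neox.downloader import DATASETS

    DATASETS.get("owt2", False)   # truthy URL string -> download proceeds
    DATASETS.get("OWT2", False)   # False -> raise NotImplementedError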
5 changes: 4 additions & 1 deletion gpt_neox/autoregressive_wrapper.py
@@ -86,5 +86,8 @@ def forward(self, x, **kwargs):
         kwargs.update(mask = mask)

         out = self.net(xi, **kwargs)
-        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
+
+        losses = F.cross_entropy(out.transpose(1, 2), xo, reduction='none', ignore_index = self.ignore_index)
+        loss = losses.mean()
+
         return loss
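A subtlety in this change: with an ignore_index, reduction='none' followed by .mean() averages over all positions (ignored targets contribute a loss of 0), while the old implicit reduction='mean' averaged only over the non-ignored ones, so the two losses differ whenever padding is present. A minimal check, assuming PyTorch's default ignore_index of -100 for illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 3, 4)             # (batch, vocab, seq)
    target = torch.tensor([[0, 1, 2, -100]])  # last position is ignored

    per_tok = F.cross_entropy(logits, target, reduction='none', ignore_index=-100)
    print(per_tok)         # ignored slot contributes 0.0
    print(per_tok.mean())  # sum / 4 -- all positions counted
    print(F.cross_entropy(logits, target, ignore_index=-100))  # sum / 3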
27 changes: 14 additions & 13 deletions gpt_neox/downloader.py
@@ -1,23 +1,24 @@
 import os
 import tarfile

+DATASETS = {
+    "owt2": "http://eaidata.bmk.sh/data/owt2_new.tar.gz",
+    "enwiki8": "http://eaidata.bmk.sh/data/enwik8.gz"
+}

-def download_dataset(dataset, dataset_dir="/root/data"):
-    if dataset == "OWT2":
-        _download_owt2(dataset_dir)
+
+def download_dataset(dataset, dataset_dir="./data"):
+    if DATASETS.get(dataset, False):
+        return _download_dataset(DATASETS[dataset], os.path.join(dataset_dir, dataset))
     else:
-        raise NotImplementedError # TODO: tokenize text data on the fly
+        raise NotImplementedError


-def _download_owt2(dataset_dir):
-    download_url = "http://eaidata.bmk.sh/data/owt2_new.tar.gz"
-    file_name = os.path.basename(download_url)
+def _download_dataset(dataset_url, dataset_dir):
+    file_name = os.path.basename(dataset_url)
     output_path = os.path.join(dataset_dir, file_name)

     if not os.path.isfile(output_path):
-        os.system('mkdir -p {}'.format(dir))
-        os.system('wget -O {}'.format(output_path))
+        os.system('mkdir -p {}'.format(dataset_dir))
+        os.system('wget {} -O {}'.format(dataset_url, output_path))

-    dataset_tar = tarfile.open(output_path)
-    dataset_tar.extractall(dataset_dir)
-    dataset_tar.close()
+    return output_path
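A sketch of the resulting contract (paths assume the defaults above):

    from gpt_neox import download_dataset

    archive = download_dataset("owt2")
    # mkdir -p ./data/owt2, wget the archive into it, and return
    # "./data/owt2/owt2_new.tar.gz"; extraction is no longer done here --
    # callers unpack it themselves (see extract_tarfile in gpt_neox/utils.py).

    archive = download_dataset("enwiki8")
    # returns "./data/enwiki8/enwik8.gz" -- a gzip, not a tar, read directly
    # by prepare_enwik8_data rather than extracted.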
2 changes: 1 addition & 1 deletion gpt_neox/gpt_neox.py
@@ -99,7 +99,7 @@ def forward(self, x, **kwargs):
         i, j = q.shape[-2], k.shape[-2]
         bool_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
         mask = torch.zeros(i, j, device=device).to(q)
-        mask_value = -torch.finfo(q.dtype).max
+        mask_value = -(torch.finfo(q.dtype).max / 2)
         mask.masked_fill_(bool_mask, mask_value)

         out = self.attn_fn(q, k, v, attn_mask=mask)
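Halving the fill value looks like an fp16-safety measure: the mask is added to the pre-softmax attention scores, and at -finfo(float16).max there is no headroom left, so adding a sufficiently negative score overflows to -inf (and an all--inf row turns into NaN after softmax). A quick illustration:

    import torch

    score = torch.tensor([-100.0], dtype=torch.float16)
    full = -torch.finfo(torch.float16).max        # -65504, no headroom
    half = -(torch.finfo(torch.float16).max / 2)  # -32752

    print(score + full)  # tensor([-inf]) -- overflowed
    print(score + half)  # finite, and still effectively 0 after softmax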
16 changes: 10 additions & 6 deletions gpt_neox/utils.py
@@ -1,17 +1,14 @@
 import gzip
 import os
+import tarfile

 import numpy as np
 import torch


 # helpers
-def prepare_enwik8_data():
-    if not os.path.isfile('./data/enwik8.gz'):
-        os.system('mkdir -p ./data')
-        os.system('wget http://eaidata.bmk.sh/data/enwik8.gz -O ./data/enwik8.gz')
-
-    with gzip.open('./data/enwik8.gz') as file:
+def prepare_enwik8_data(data_path):
+    with gzip.open(data_path) as file:
         X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
         trX, vaX = np.split(X, [int(90e6)])
         data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
@@ -28,6 +25,13 @@ def get_all_files(filetype, files_dir):
     return files


+def extract_tarfile(tarfile_path, extract_dir=None):
+    dataset_tar = tarfile.open(tarfile_path)
+    os.makedirs(extract_dir, exist_ok=False)
+    dataset_tar.extractall(extract_dir)
+    dataset_tar.close()
+
+
 def cycle(loader):
     while True:
         for data in loader:
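One caveat with the new helper: os.makedirs(extract_dir, exist_ok=False) raises FileExistsError when the directory already exists, and the extract_dir=None default would raise TypeError. As wired up in train.py below, download_dataset has already created the dataset directory before extract_tarfile runs, so exist_ok=True (or catching the error) may be what was intended:

    import os, tempfile

    target = os.path.join(tempfile.mkdtemp(), "owt2")
    os.makedirs(target, exist_ok=False)      # first call: fine
    try:
        os.makedirs(target, exist_ok=False)  # second call: FileExistsError
    except FileExistsError:
        pass
    os.makedirs(target, exist_ok=True)       # idempotent alternative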
10 changes: 6 additions & 4 deletions train.py
@@ -1,5 +1,6 @@
 import argparse
 import json
+import os
 import random
 from collections import defaultdict

@@ -8,7 +9,7 @@
 from torch.utils.data import DataLoader
 from tqdm.auto import trange

-from gpt_neox import (GPTNeoX, AutoregressiveWrapper, GPT2Dataset,
+from gpt_neox import (GPTNeoX, AutoregressiveWrapper, GPT2Dataset, extract_tarfile,
                       prepare_optimizer_parameters, get_tokenizer, download_dataset, get_all_files)


@@ -55,9 +56,10 @@ def get_params(model):
 dset_params = params["dataset"]
 assert dset_params is not None

-download_dataset(dataset=params["name"], dataset_dir=params["dir"])
-files = get_all_files(filetype=params["filetype"], files_dir=params["dir"])
-# TODO: SPLIT?
+data_path = download_dataset(dataset=dset_params["name"], dataset_dir=dset_params["dir"])
+data_dir = os.path.dirname(data_path)
+extract_tarfile(tarfile_path=data_path, extract_dir=data_dir)
+files = get_all_files(filetype=dset_params["filetype"], files_dir=data_dir)

 train_dataset = GPT2Dataset(files=files,
                             seq_len=params["seq_len"],
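Besides wiring in the downloader, this hunk fixes a lookup bug: the old lines read params["name"], params["dir"], and params["filetype"], but per configs/gpt3_small.json those keys live under params["dataset"], so they would have raised KeyError. A minimal reproduction (config abbreviated):

    params = {"dataset": {"name": "owt2", "dir": "./data"}}

    dset_params = params["dataset"]
    print(dset_params["name"])  # "owt2"
    print(params.get("name"))   # None -- the old params["name"] raised KeyError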
19 changes: 11 additions & 8 deletions train_enwik8.py
@@ -1,13 +1,15 @@
-from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset,
-                      cycle, prepare_optimizer_parameters, decode_tokens, prepare_enwik8_data)
+import argparse
+import json
+import random
+from collections import defaultdict

+import deepspeed
 import torch
 from torch.utils.data import DataLoader
-import deepspeed
 from tqdm.auto import trange
-import argparse
-import json
-from collections import defaultdict
+
+from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset, download_dataset,
+                      cycle, prepare_optimizer_parameters, decode_tokens, prepare_enwik8_data)


def get_args():
@@ -44,7 +46,8 @@ def get_params(model):
 model = AutoregressiveWrapper(model)

 # prepare enwik8 data
-data_train, data_val = prepare_enwik8_data()
+data_path = download_dataset(dataset="enwiki8")
+data_train, data_val = prepare_enwik8_data(data_path=data_path)
 train_dataset = TextSamplerDataset(data_train, params["seq_len"])
 val_dataset = TextSamplerDataset(data_val, params["seq_len"])
 val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))
@@ -60,7 +63,7 @@ def get_params(model):
     model=model,
     optimizer=optim,
     model_parameters=ds_model_params,
-        training_data=train_dataset)
+    training_data=train_dataset)

 pbar = trange(params["num_epochs"], mininterval=10., desc='Training Model', dynamic_ncols=True)
 for _ in pbar:
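End to end, the new enwik8 path is now (locations assume the downloader defaults; note that np.fromstring, used inside prepare_enwik8_data, is deprecated in NumPy in favour of np.frombuffer):

    from gpt_neox import download_dataset, prepare_enwik8_data

    data_path = download_dataset(dataset="enwiki8")  # "./data/enwiki8/enwik8.gz"
    data_train, data_val = prepare_enwik8_data(data_path=data_path)
    # 95e6 bytes are read from the gzip; the first 90e6 become the uint8
    # training tensor and the remaining 5e6 the validation tensor.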
