Fix deprecation warning #42

Merged · 7 commits · Jan 5, 2021

4 changes: 2 additions & 2 deletions configs/base_deepspeed.json
@@ -1,6 +1,6 @@
{
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"train_batch_size": 512,
"train_micro_batch_size_per_gpu": 8,
"gradient_clipping": 1.0,
"tensorboard": {
"enabled": true,
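The new values are related: DeepSpeed requires the global train_batch_size to equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × the data-parallel world size, so 512 with a micro batch of 8 and no accumulation implies 64 workers (the 64-GPU world size is an inference, not something this PR states). A minimal sketch of that consistency check, with a hypothetical helper name:

```python
import json


def check_deepspeed_batch_config(config_path, world_size):
    """Sanity-check train_batch_size == micro_batch * grad_accum * world_size."""
    with open(config_path) as f:
        cfg = json.load(f)
    micro = cfg["train_micro_batch_size_per_gpu"]
    # When gradient_accumulation_steps is omitted (as it is after this change),
    # treat it as 1 for this check; DeepSpeed itself would infer the value.
    accum = cfg.get("gradient_accumulation_steps", 1)
    expected = micro * accum * world_size
    if cfg["train_batch_size"] != expected:
        raise ValueError(f"train_batch_size={cfg['train_batch_size']}, "
                         f"expected {micro} * {accum} * {world_size} = {expected}")


# 512 == 8 * 1 * 64, so this config assumes 64 data-parallel GPUs.
check_deepspeed_batch_config("configs/base_deepspeed.json", world_size=64)
```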
4 changes: 4 additions & 0 deletions configs/base_model.json
@@ -1,4 +1,8 @@
{
"dataset": {
"name": "enwik8",
"path": "./data/enwik8.gz"
},
"num_epochs": 10,
"vocab_size": 256,
"batch_size": 4,
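This new dataset block is what the enwik8 helpers added elsewhere in this PR consume. A short sketch of how the two fields could be used together (the config-loading step is an assumption, mirroring how train.py loads model configs):

```python
import json

from gpt_neox import prepare_data, read_enwik8_data

with open("configs/base_model.json") as f:
    params = json.load(f)

dset = params["dataset"]
prepare_data(dset["name"])                             # downloads ./data/enwik8.gz if missing
data_train, data_val = read_enwik8_data(dset["path"])  # 90M/5M train/val byte split
```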
10 changes: 3 additions & 7 deletions configs/gpt3_small.json
@@ -6,19 +6,15 @@
},
"dataset": {
"name": "owt2",
"dir": "./data",
"train_path": "./data/owt2/train/*",
"eval_path": "./data/owt2/eval/*",
"seed": 1,
"shuffle_input_filenames": true,
"pretokenized": true,
"filetype": "tfrecords",
"mode": "chunks",
"save_progress_every": 10000,
"checkpoint_path": "gpt3_small_ckpt.txt",
"resume_from_checkpoint": true
"mode": "chunks"
},
"train_steps": 572300,
"batch_size": 256,
"eval_batch_size": 32,
"learning_rate": 0.0006,
"generate_every": 500,
"generate_length": 256,
4 changes: 2 additions & 2 deletions gpt_neox/__init__.py
@@ -1,6 +1,6 @@
from gpt_neox.autoregressive_wrapper import AutoregressiveWrapper
from gpt_neox.data_utils import get_tokenizer
from gpt_neox.data_utils import get_tokenizer, read_enwik8_data
from gpt_neox.datasets import TextSamplerDataset, GPT2Dataset
from gpt_neox.downloader import download_dataset
from gpt_neox.gpt_neox import GPTNeoX
from gpt_neox.utils import *
from gpt_neox.data_downloader_registry import prepare_data
129 changes: 129 additions & 0 deletions gpt_neox/data_downloader_registry.py
@@ -0,0 +1,129 @@
import os
import tarfile
from abc import ABC, abstractmethod
from glob import glob
import shutil
import random

"""
This registry is for automatically downloading and extracting datasets.
To register a class you need to inherit the DataDownloader class and provide name, filetype and url attributes, and
(optionally) provide download / extract / exists functions to check if the data exists, and, if it doesn't, download and
extract the data and move it to the correct directory.
When done, add it to the DATA_DOWNLOADERS dict. The function process_data runs the pre-processing for the selected
dataset.
"""


class DataDownloader(ABC):
"""Dataset registry class to automatically download / extract datasets"""

@property
def base_dir(self):
"""base data directory"""
return "./data"

@property
@abstractmethod
def name(self):
"""name of dataset"""
pass

@property
@abstractmethod
def filetype(self):
"""filetype of dataset"""
pass

@property
@abstractmethod
def url(self):
"""URL from which to download dataset"""
pass

def _extract(self):
self.path = os.path.join(self.base_dir, self.name)
os.makedirs(self.path, exist_ok=True)
tarfile_path = os.path.join(self.base_dir, os.path.basename(self.url))
with tarfile.open(tarfile_path, "r:gz") as dataset_tar:
print(f'Extracting files from {tarfile_path}...')
dataset_tar.extractall(self.path)

def extract(self):
"""extracts dataset and moves to the correct data dir if necessary"""
self._extract()

def exists(self):
"""Checks if the dataset is present"""
return os.path.isdir(f"{self.base_dir}/{self.name}")

def download(self):
"""downloads dataset"""
os.makedirs(self.base_dir, exist_ok=True)
os.system(f"wget {self.url} -O {os.path.join(self.base_dir, os.path.basename(self.url))}")

def prepare(self):
if not self.exists():
self.download()
self.extract()


class OWT2(DataDownloader):
name = "owt2"
filetype = "tfrecords"
url = "http:https://eaidata.bmk.sh/data/owt2_new.tar.gz"
seed = 1

def extract(self):
self._extract()
# the files are within nested subdirectories, and not split by train/test
# so we need to move them to the correct directories
all_files = glob(f"{self.path}/**/*.{self.filetype}", recursive=True)
print(all_files)
train_dir = f"{self.path}/train"
eval_dir = f"{self.path}/eval"
os.makedirs(train_dir, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)
total_tfrecords = len(all_files)
n_eval_tfrecords = total_tfrecords // 10
# owt2 doesn't have an official train/test split, so sample at random from tfrecords
random.seed(self.seed)
random.shuffle(all_files)
eval_set = all_files[:n_eval_tfrecords]
train_set = all_files[n_eval_tfrecords:]
for f in train_set:
shutil.move(f, train_dir)
for f in eval_set:
shutil.move(f, eval_dir)
dirs_to_remove = [f for f in glob(f"{self.path}/*") if f not in [train_dir, eval_dir]]
for d in dirs_to_remove:
shutil.rmtree(d)


class Enwik8(DataDownloader):
name = "owt2"
filetype = "gz"
url = "http:https://eaidata.bmk.sh/data/enwik8.gz"

def extract(self):
pass

def exists(self):
return os.path.isfile(f"{self.base_dir}/enwik8.gz")


DATA_DOWNLOADERS = {
"owt2": OWT2,
"enwik8": Enwik8
}


def prepare_data(dataset_name):
DownloaderClass = DATA_DOWNLOADERS.get(dataset_name, None)
if DownloaderClass is None:
raise NotImplementedError
else:
d = DownloaderClass()
d.prepare()
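New datasets are added by subclassing DataDownloader and registering the class in DATA_DOWNLOADERS. A minimal sketch for a hypothetical tarball dataset (the name and URL below are illustrative, not part of this PR):

```python
from gpt_neox.data_downloader_registry import DATA_DOWNLOADERS, DataDownloader, prepare_data


class MyCorpus(DataDownloader):
    # Illustrative values only; this dataset does not exist in the repository.
    name = "my_corpus"
    filetype = "tfrecords"
    url = "https://example.com/data/my_corpus.tar.gz"
    # download(), extract() and exists() are inherited from DataDownloader;
    # override them only if the archive layout needs post-processing, as OWT2 does above.


DATA_DOWNLOADERS["my_corpus"] = MyCorpus

prepare_data("my_corpus")  # downloads and extracts to ./data/my_corpus on first call, no-op afterwards
```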
11 changes: 10 additions & 1 deletion gpt_neox/data_utils.py
@@ -2,7 +2,9 @@
from itertools import islice
import re
from collections import OrderedDict

import gzip
import numpy as np
import torch

class FixedSizeOrderedDict(OrderedDict):
def __init__(self, *args, max=0, **kwargs):
@@ -39,3 +41,10 @@ def get_tokenizer(tokenizer_type=None, from_pretrained=True, add_padding_token=F
return tok
else:
raise NotImplementedError('TODO: add custom tokenizers')

def read_enwik8_data(data_path):
with gzip.open(data_path) as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
return data_train, data_val
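Note that np.fromstring is itself deprecated for binary input (NumPy recommends np.frombuffer), so this helper will still emit a DeprecationWarning. A sketch of the equivalent non-deprecated version, keeping the same 95M read and 90M/5M split:

```python
import gzip

import numpy as np
import torch


def read_enwik8_data_frombuffer(data_path):
    with gzip.open(data_path) as file:
        # np.frombuffer is the non-deprecated replacement for np.fromstring on raw bytes;
        # copy() because frombuffer returns a read-only view, which torch.from_numpy warns about.
        X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    return torch.from_numpy(trX), torch.from_numpy(vaX)
```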
22 changes: 12 additions & 10 deletions gpt_neox/datasets.py
@@ -28,7 +28,7 @@ def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True,
if self.filetype not in implemented_filetypes:
raise NotImplementedError

self.processed_files = FixedSizeOrderedDict(max=2) # storage for lazily loading data
self.processed_files = FixedSizeOrderedDict(max=1) # storage for lazily loading data

# parses the length of the files, either by encoding in the filenames or by iterating over them
self._get_lens()
@@ -71,17 +71,19 @@ def _get_lens(self):
lens.append(n_documents)
self.lens = lens
self._len = sum(self.lens)

def _parse_single_example(self, example):
data = tf.train.Example.FromString(example)
data = torch.tensor(list(data.features.feature["text"].int64_list.value), dtype=torch.long)
if self.mode == "chunks":
assert data.size(0) == self.seq_len + 1
return data
def _parse_function(self, example_proto):
features = {
"text": tf.io.VarLenFeature(tf.int64)
}
parsed_features = tf.io.parse_single_example(example_proto, features)
return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])

def _process_tfrecord(self, tfrecords_file, resume_idx=None):
for idx, example in enumerate(tf.io.tf_record_iterator(tfrecords_file)):
yield self._parse_single_example(example)
dataset = tf.data.TFRecordDataset([tfrecords_file])
dataset = dataset.map(self._parse_function, num_parallel_calls=1)
for example in dataset.as_numpy_iterator():
yield torch.tensor(example, dtype=torch.long)

def _maybe_process_tfrecord(self, file_idx):
if self.processed_files.get(file_idx) is None:
24 changes: 0 additions & 24 deletions gpt_neox/downloader.py

This file was deleted.

37 changes: 27 additions & 10 deletions gpt_neox/utils.py
@@ -1,18 +1,35 @@
import gzip
import os
import tarfile

import numpy as np
import torch
import argparse
import deepspeed
import json
from collections import defaultdict


# helpers
def prepare_enwik8_data(data_path):
with gzip.open(data_path) as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
trX, vaX = np.split(X, [int(90e6)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
return data_train, data_val
def get_args():
parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
# Include DeepSpeed configuration arguments
parser.add_argument('--model', type=str, default="gpt3_small")
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args


def get_params(model):
model_path = model if model.endswith(".json") else f"./configs/{model}.json"
with open(model_path) as f:
params = json.load(f)
return defaultdict(lambda: None, params)


def is_main(args):
"""
returns True if process is being run on the main GPU
"""
return args.local_rank in [0, -1]


def get_all_files(filetype, files_dir):
50 changes: 19 additions & 31 deletions train.py
@@ -1,35 +1,14 @@
import argparse
import json
import os
import random
from collections import defaultdict

import deepspeed
import torch
from torch.utils.data import DataLoader
from tqdm.auto import trange
import torch.distributed as distributed

from gpt_neox import (GPTNeoX, AutoregressiveWrapper, GPT2Dataset, extract_tarfile,
prepare_optimizer_parameters, get_tokenizer, download_dataset, get_all_files)


def get_args():
parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
# Include DeepSpeed configuration arguments
parser.add_argument('--model', type=str, default="gpt3_small")
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args


def get_params(model):
model_path = model if model.endswith(".json") else f"./configs/{model}.json"
with open(model_path) as f:
params = json.load(f)
return defaultdict(lambda: None, params)
prepare_optimizer_parameters, get_tokenizer, is_main, prepare_data)

from gpt_neox.utils import get_args, get_params

train_args = get_args()
params = get_params(train_args.model)
@@ -56,16 +35,20 @@ def get_params(model):
dset_params = params["dataset"]
assert dset_params is not None

data_path = download_dataset(dataset=dset_params["name"], dataset_dir=dset_params["dir"])
data_dir = os.path.dirname(data_path)
extract_tarfile(tarfile_path=data_path, extract_dir=data_dir)
files = get_all_files(filetype=dset_params["filetype"], files_dir=data_dir)
deepspeed.init_distributed(dist_backend='nccl')
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
if is_main(train_args):
prepare_data(dset_params["name"])
torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier
else:
torch.distributed.barrier()

train_dataset = GPT2Dataset(files=files,
train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
seq_len=params["seq_len"],
train=True,
**dset_params)
eval_dataset = GPT2Dataset(files=files,

eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
seq_len=params["seq_len"],
train=False,
**dset_params)
@@ -74,7 +57,11 @@ def get_params(model):
val_loader = iter(val_loader)

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])
if train_args.local_rank == -1: # non-deepspeed
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])
else:
optim = None # deepspeed will prepare the optimizer for us


# training
ds_model_params = prepare_optimizer_parameters(model)
@@ -86,6 +73,7 @@ def get_params(model):
model_parameters=ds_model_params,
training_data=train_dataset)

print("OPTIMIZER: ", optim)
pbar = trange(params.get("train_steps", 1), mininterval=10., desc='Training Model', dynamic_ncols=True)
for _ in pbar:
for i, data in enumerate(train_loader):
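The download-once logic above (prepare the data on the main process while the other ranks wait at a barrier) is a general idiom for any one-time setup step in a multi-process run. A condensed sketch of the same pattern, with setup_fn standing in for prepare_data:

```python
import torch.distributed as dist


def run_on_main_then_sync(setup_fn, local_rank):
    """Run setup_fn only on the main process; every rank blocks until it has finished."""
    is_main = local_rank in [0, -1]
    if dist.is_initialized():
        dist.barrier()   # all ranks line up before the setup step
    if is_main:
        setup_fn()       # e.g. prepare_data(dset_params["name"])
    if dist.is_initialized():
        dist.barrier()   # non-main ranks wait here until setup_fn has completed
```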