Fix deprecation warning #42

Merged · 7 commits · Jan 5, 2021
Changes from 1 commit
add class for automatically downloading datasets
sdtblck committed Jan 5, 2021
commit 7e3a0486519237434f4c984e16df5bb33ff2e0d1
1 change: 1 addition & 0 deletions configs/base_model.json
@@ -1,4 +1,5 @@
 {
+    "data_path": "./data/enwik8.tar.gz",
     "num_epochs": 10,
     "vocab_size": 256,
     "batch_size": 4,
8 changes: 3 additions & 5 deletions configs/gpt3_small.json
@@ -6,15 +6,13 @@
     },
     "dataset": {
         "name": "owt2",
-        "dir": "./data",
+        "train_path": "./data/train/*",
+        "eval_path": "./data/eval/*",
         "seed": 1,
         "shuffle_input_filenames": true,
         "pretokenized": true,
         "filetype": "tfrecords",
-        "mode": "chunks",
-        "save_progress_every": 10000,
-        "checkpoint_path": "gpt3_small_ckpt.txt",
-        "resume_from_checkpoint": true
+        "mode": "chunks"
     },
     "train_steps": 572300,
     "batch_size": 256,
2 changes: 1 addition & 1 deletion gpt_neox/__init__.py
@@ -1,6 +1,6 @@
 from gpt_neox.autoregressive_wrapper import AutoregressiveWrapper
 from gpt_neox.data_utils import get_tokenizer
 from gpt_neox.datasets import TextSamplerDataset, GPT2Dataset
-from gpt_neox.downloader import download_dataset
 from gpt_neox.gpt_neox import GPTNeoX
 from gpt_neox.utils import *
+from gpt_neox.data_downloader_registry import prepare_data
127 changes: 127 additions & 0 deletions gpt_neox/data_downloader_registry.py
@@ -0,0 +1,127 @@
import os
import tarfile
from abc import ABC, abstractmethod
from glob import glob
import shutil
import random

"""
This registry is for automatically downloading and extracting datasets.

To register a class you need to inherit the DataDownloader class and provide name, filetype and url attributes, and (optionally)
provide download / extract / exists functions to check if the data exists, and, if it doesn't, download and extract the data
and move it to the correct directory.

When done, add it to the DATA_DOWNLOADERS dict. The function prepare_data runs the pre-processing for the selected dataset.
"""

class DataDownloader(ABC):
    """Dataset registry class to automatically download / extract datasets"""

    @property
    def base_dir(self):
        """base data directory"""
        return "./data"

    @property
    @abstractmethod
    def name(self):
        """name of dataset"""
        pass

    @property
    @abstractmethod
    def filetype(self):
        """filetype of dataset"""
        pass

    @property
    @abstractmethod
    def url(self):
        """URL from which to download dataset"""
        pass

    def _extract(self):
        self.path = os.path.join(self.base_dir, self.name)
        os.makedirs(self.path, exist_ok=True)
        tarfile_path = os.path.join(self.base_dir, os.path.basename(self.url))
        with tarfile.open(tarfile_path, "r:gz") as dataset_tar:
            print(f'Extracting files from {tarfile_path}...')
            dataset_tar.extractall(self.path)

    def extract(self):
        """extracts dataset and moves to the correct data dir if necessary"""
        self._extract()

    def exists(self):
        """Checks if the dataset is present"""
        return os.path.isdir(f"{self.base_dir}/{self.name}")

    def download(self):
        """downloads dataset"""
        os.makedirs(self.base_dir, exist_ok=True)
        os.system(f"wget {self.url} -O {os.path.join(self.base_dir, os.path.basename(self.url))}")

    def prepare(self):
        if not self.exists():
            self.download()
            self.extract()

class OWT2(DataDownloader):
    name = "owt2"
    filetype = "tfrecords"
    url = "https://eaidata.bmk.sh/data/owt2_new.tar.gz"
    seed = 1

    def extract(self):
        self._extract()
        # the files are within nested subdirectories, and not split by train/test
        # so we need to move them to the correct directories
        all_files = glob(f"{self.path}/**/*.{self.filetype}", recursive=True)
        print(all_files)
        train_dir = f"{self.path}/train"
        eval_dir = f"{self.path}/eval"
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(eval_dir, exist_ok=True)
        total_tfrecords = len(all_files)
        n_eval_tfrecords = total_tfrecords // 10
        # owt2 doesn't have an official train/test split, so sample at random from tfrecords
        random.seed(self.seed)
        random.shuffle(all_files)
        eval_set = all_files[:n_eval_tfrecords]
        train_set = all_files[n_eval_tfrecords:]
        for f in train_set:
            shutil.move(f, train_dir)
        for f in eval_set:
            shutil.move(f, eval_dir)
        dirs_to_remove = [f for f in glob(f"{self.path}/*") if f not in [train_dir, eval_dir]]
        for d in dirs_to_remove:
            shutil.rmtree(d)

class Enwik8(DataDownloader):
    name = "enwik8"
    filetype = "tar.gz"
    url = "https://eaidata.bmk.sh/data/enwik8.gz"

    def extract(self):
        pass

    def exists(self):
        return os.path.isfile(f"{self.base_dir}/enwik8.gz")


DATA_DOWNLOADERS = {
    "owt2": OWT2,
    "enwik8": Enwik8
}

def prepare_data(dataset_name):
    DownloaderClass = DATA_DOWNLOADERS.get(dataset_name, None)
    if DownloaderClass is None:
        raise NotImplementedError
    else:
        d = DownloaderClass()
        d.prepare()
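
Usage note (not part of the diff): the registry docstring above describes how to register a new dataset, and a minimal sketch is shown below. The dataset name and URL here are hypothetical; only the name, filetype and url attributes are required, since DataDownloader already provides download / extract / exists / prepare.

from gpt_neox.data_downloader_registry import DataDownloader, DATA_DOWNLOADERS, prepare_data

class MyCorpus(DataDownloader):
    # hypothetical dataset: `name` becomes the target directory under ./data,
    # `url` points at the .tar.gz archive that download() fetches with wget
    name = "my_corpus"
    filetype = "tfrecords"
    url = "https://example.com/my_corpus.tar.gz"

DATA_DOWNLOADERS["my_corpus"] = MyCorpus

prepare_data("my_corpus")  # downloads and extracts only if ./data/my_corpus is missing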
7 changes: 7 additions & 0 deletions gpt_neox/data_utils.py
@@ -39,3 +39,10 @@ def get_tokenizer(tokenizer_type=None, from_pretrained=True, add_padding_token=F
         return tok
     else:
         raise NotImplementedError('TODO: add custom tokenizers')
+
+def read_enwik8_data(data_path):
+    with gzip.open(data_path) as file:
+        X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+        trX, vaX = np.split(X, [int(90e6)])
+        data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+    return data_train, data_val
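
Side note (not part of the diff): np.fromstring is deprecated for binary input in recent NumPy releases and emits a DeprecationWarning, presumably the warning this PR's title refers to. A sketch of the same helper using np.frombuffer instead:

import gzip

import numpy as np
import torch

def read_enwik8_data(data_path):
    with gzip.open(data_path) as file:
        # np.frombuffer is the non-deprecated replacement for np.fromstring on raw bytes;
        # .copy() yields a writable array so torch.from_numpy does not warn about it
        X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    return torch.from_numpy(trX), torch.from_numpy(vaX)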
24 changes: 0 additions & 24 deletions gpt_neox/downloader.py

This file was deleted.

34 changes: 26 additions & 8 deletions gpt_neox/utils.py
@@ -1,19 +1,37 @@
import gzip
import os
import tarfile

import numpy as np
import torch
import argparse
import deepspeed
import json
from collections import defaultdict


# helpers
def prepare_enwik8_data(data_path):
    with gzip.open(data_path) as file:
        X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
        trX, vaX = np.split(X, [int(90e6)])
        data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
    return data_train, data_val

def get_args():
    parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
    # Include DeepSpeed configuration arguments
    parser.add_argument('--model', type=str, default="gpt3_small")
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args


def get_params(model):
    model_path = model if model.endswith(".json") else f"./configs/{model}.json"
    with open(model_path) as f:
        params = json.load(f)
    return defaultdict(lambda: None, params)

def is_main(args):
    """
    returns True if process is being run on the main GPU
    """
    return args.local_rank in [0, -1]

def get_all_files(filetype, files_dir):
    files = []
12 changes: 12 additions & 0 deletions prepare_dataset.py
@@ -0,0 +1,12 @@
from gpt_neox.utils import get_args, get_params
from gpt_neox import prepare_data

args = get_args()
if args.model == "enwik8":
prepare_data("enwik8")
else:
params = get_params(args.model)
# prepare data
dset_params = params["dataset"]
assert dset_params is not None
prepare_data(dset_params["name"])
1 change: 1 addition & 0 deletions scripts/train_enwik8.sh
@@ -1,2 +1,3 @@
 mkdir logs
+python3 prepare_dataset.py --model enwik8
 NCCL_SHM_DISABLE=1 NCCL_DEBUG=info MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train_enwik8.py --deepspeed --deepspeed_config configs/base_deepspeed.json
1 change: 1 addition & 0 deletions scripts/train_gpt3small.sh
@@ -1,2 +1,3 @@
 mkdir logs
+python3 prepare_dataset.py
 NCCL_SHM_DISABLE=1 NCCL_DEBUG=info MASTER_ADDR=127.0.0.1 MASTER_PORT=2000 deepspeed train.py --deepspeed --deepspeed_config configs/base_deepspeed.json
38 changes: 14 additions & 24 deletions train.py
@@ -8,30 +8,17 @@
 import torch
 from torch.utils.data import DataLoader
 from tqdm.auto import trange
+import torch.distributed as distributed

 from gpt_neox import (GPTNeoX, AutoregressiveWrapper, GPT2Dataset, extract_tarfile,
-                      prepare_optimizer_parameters, get_tokenizer, download_dataset, get_all_files)
+                      prepare_optimizer_parameters, get_tokenizer, is_main, prepare_data)

+from gpt_neox.utils import prepare_enwik8_data, get_args, get_params

-def get_args():
-    parser = argparse.ArgumentParser(description='GPTNeox Deepspeed Training Script')
-    # Include DeepSpeed configuration arguments
-    parser.add_argument('--model', type=str, default="gpt3_small")
-    parser.add_argument('--local_rank', type=int, default=-1,
-                        help='local rank passed from distributed launcher')
-    parser = deepspeed.add_config_arguments(parser)
-    args = parser.parse_args()
-    return args


-def get_params(model):
-    model_path = model if model.endswith(".json") else f"./configs/{model}.json"
-    with open(model_path) as f:
-        params = json.load(f)
-    return defaultdict(lambda: None, params)


 train_args = get_args()
+print("RANK: ", train_args.local_rank)
 params = get_params(train_args.model)

 # tokenizer
@@ -56,16 +43,19 @@ def get_params(model):
 dset_params = params["dataset"]
 assert dset_params is not None

-data_path = download_dataset(dataset=dset_params["name"], dataset_dir=dset_params["dir"])
-data_dir = os.path.dirname(data_path)
-extract_tarfile(tarfile_path=data_path, extract_dir=data_dir)
-files = get_all_files(filetype=dset_params["filetype"], files_dir=data_dir)

-train_dataset = GPT2Dataset(files=files,
+torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
+if is_main(train_args):
+    prepare_data(dset_params["name"])
+    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
+else:
+    torch.distributed.barrier()

+train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
                             seq_len=params["seq_len"],
                             train=True,
                             **dset_params)
-eval_dataset = GPT2Dataset(files=files,

+eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
                             seq_len=params["seq_len"],
                             train=False,
                             **dset_params)
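Aside (not part of the diff): the barrier logic added above is the usual rank-0-prepares, everyone-waits pattern for distributed data setup. A standalone sketch of the same idea, assuming a process group has already been initialized (e.g. by DeepSpeed or torch.distributed.launch); the helper name is hypothetical:

import torch.distributed as dist

def prepare_on_main_only(prepare_fn, local_rank):
    # only the main process mutates the filesystem; other ranks block at the
    # barrier until preparation has finished, then everyone proceeds together
    if local_rank in [0, -1]:
        prepare_fn()
    if dist.is_available() and dist.is_initialized():
        dist.barrier()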
10 changes: 4 additions & 6 deletions train_enwik8.py
@@ -9,7 +9,7 @@
 from tqdm.auto import trange

 from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset, download_dataset,
-                      cycle, prepare_optimizer_parameters, decode_tokens, prepare_enwik8_data)
+                      cycle, prepare_optimizer_parameters, decode_tokens, prepare_enwik8_data, is_main)


 def get_args():
@@ -46,8 +46,7 @@ def get_params(model):
 model = AutoregressiveWrapper(model)

 # prepare enwik8 data
-data_path = download_dataset(dataset="enwiki8")
-data_train, data_val = prepare_enwik8_data(data_path=data_path)
+data_train, data_val = prepare_enwik8_data(params["data_path"])
 train_dataset = TextSamplerDataset(data_train, params["seq_len"])
 val_dataset = TextSamplerDataset(data_val, params["seq_len"])
 val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))
@@ -69,7 +68,6 @@ def get_params(model):
 for _ in pbar:
     for i, data in enumerate(train_loader):
         model_engine.train()
-        is_main = model_engine.local_rank == 0
         data = data.to(model_engine.local_rank)

         loss = model_engine(data)
@@ -79,14 +77,14 @@ def get_params(model):
         pbar.set_description(f'Training Loss: {loss.item():.4f}')
         pbar.update()

-        if is_main and i % params["validate_every"] == 0:
+        if is_main(train_args) and i % params["validate_every"] == 0:
             model.eval()
             with torch.no_grad():
                 val_data = next(val_loader).cuda()
                 loss = model(val_data)
                 pbar.write(f'Validation Loss: {loss.item()}')

-        if is_main and i % params["generate_every"] == 0:
+        if is_main(train_args) and i % params["generate_every"] == 0:
             model.eval()
             inp = random.choice(val_dataset)[:-1]
             prime = decode_tokens(inp)