Commit a38537e: fix enwik8 errors

sdtblck committed Jan 5, 2021
1 parent ee7aad0 commit a38537e
Showing 3 changed files with 4 additions and 4 deletions.
configs/base_model.json (2 changes: 1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 {
   "dataset": {
     "name": "enwik8",
-    "path": "./data/enwik8.tar.gz"
+    "path": "./data/enwik8.gz"
   },
   "num_epochs": 10,
   "vocab_size": 256,
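The path now points at a plain gzip file rather than a tar archive, which matters for how the data gets opened. Below is a minimal sketch, assuming the config is read as ordinary JSON and the enwik8 dump is loaded into a byte array; the loader is illustrative only, not the repository's actual code.

# Illustrative loader for the updated config (not the repo's code).
import gzip
import json

import numpy as np

with open("configs/base_model.json") as f:
    params = json.load(f)

dset_params = params["dataset"]

# A plain .gz can be opened with gzip directly; the old "enwik8.tar.gz" name
# implied a tar archive, which would have needed tarfile instead.
with gzip.open(dset_params["path"], "rb") as f:
    data = np.frombuffer(f.read(), dtype=np.uint8)

print(dset_params["name"], data.shape)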
gpt_neox/__init__.py (2 changes: 1 addition, 1 deletion)
@@ -1,5 +1,5 @@
 from gpt_neox.autoregressive_wrapper import AutoregressiveWrapper
-from gpt_neox.data_utils import get_tokenizer
+from gpt_neox.data_utils import get_tokenizer, read_enwik8_data
 from gpt_neox.datasets import TextSamplerDataset, GPT2Dataset
 from gpt_neox.gpt_neox import GPTNeoX
 from gpt_neox.utils import *
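read_enwik8_data is now re-exported from the package root so train_enwik8.py can import it straight from gpt_neox. Its body is not part of this diff; the stand-in below is a hypothetical sketch based on the usual enwik8 recipe (roughly the first 90M bytes for training, the next 5M for validation) and on how train_enwik8.py unpacks its return value.

import gzip

import numpy as np
import torch


def read_enwik8_data(path):
    # Hypothetical stand-in: load gzipped enwik8 and split it into
    # train/validation byte tensors (the 90M/5M split is an assumption).
    with gzip.open(path, "rb") as f:
        data = np.frombuffer(f.read(int(95e6)), dtype=np.uint8)
    train, valid = np.split(data, [int(90e6)])
    return torch.from_numpy(train.copy()), torch.from_numpy(valid.copy())

Working at the byte level is also why base_model.json sets "vocab_size": 256: each raw byte is its own token.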
train_enwik8.py (4 changes: 2 additions, 2 deletions)
@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader
 from tqdm.auto import trange

-from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset, download_dataset,
+from gpt_neox import (GPTNeoX, AutoregressiveWrapper, TextSamplerDataset,
                       cycle, prepare_optimizer_parameters, decode_tokens, read_enwik8_data, is_main, prepare_data)

@@ -54,7 +54,7 @@ def get_params(model):
     torch.distributed.barrier()

 # prepare enwik8 data
-data_train, data_val = read_enwik8_data(dset_params["data_path"])
+data_train, data_val = read_enwik8_data(dset_params["path"])
 train_dataset = TextSamplerDataset(data_train, params["seq_len"])
 val_dataset = TextSamplerDataset(data_val, params["seq_len"])
 val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))
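The second hunk changes only the lookup key: dset_params is presumably the "dataset" block of configs/base_model.json shown above, and that block names the field "path", not "data_path". A tiny hypothetical illustration of the mismatch this commit fixes:

# The "dataset" block, with values copied from the config diff above.
dset_params = {"name": "enwik8", "path": "./data/enwik8.gz"}

# dset_params["data_path"]  # would raise KeyError -- the lookup before this commit
print(dset_params["path"])   # "./data/enwik8.gz" -- the corrected lookup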
