Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecation warning #42

Merged
merged 7 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix deprecation warning in dataset
  • Loading branch information
sdtblck committed Jan 5, 2021
commit ff5a3088065491c8e790df1affe6e3c581966b82
4 changes: 2 additions & 2 deletions configs/base_deepspeed.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"train_batch_size": 8,
"gradient_accumulation_steps": 1,
"train_batch_size": 512,
"train_micro_batch_size_per_gpu": 8,
"gradient_clipping": 1.0,
"tensorboard": {
"enabled": true,
Expand Down
2 changes: 0 additions & 2 deletions configs/gpt3_small.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"mode": "chunks"
},
"train_steps": 572300,
"batch_size": 256,
"eval_batch_size": 32,
"learning_rate": 0.0006,
"generate_every": 500,
"generate_length": 256,
Expand Down
2 changes: 1 addition & 1 deletion gpt_neox/data_downloader_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def extract(self):

class Enwik8(DataDownloader):
name = "owt2"
filetype = "tar.gz"
filetype = "gz"
url = "https://eaidata.bmk.sh/data/enwik8.gz"

def extract(self):
Expand Down
22 changes: 12 additions & 10 deletions gpt_neox/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, glob_pattern, seq_len, seed=1, shuffle_input_filenames=True,
if self.filetype not in implemented_filetypes:
raise NotImplementedError

self.processed_files = FixedSizeOrderedDict(max=2) # storage for lazily loading data
self.processed_files = FixedSizeOrderedDict(max=1) # storage for lazily loading data

# parses the length of the files, either by encoding in the filenames or by iterating over them
self._get_lens()
Expand Down Expand Up @@ -71,17 +71,19 @@ def _get_lens(self):
lens.append(n_documents)
self.lens = lens
self._len = sum(self.lens)

def _parse_single_example(self, example):
data = tf.train.Example.FromString(example)
data = torch.tensor(list(data.features.feature["text"].int64_list.value), dtype=torch.long)
if self.mode == "chunks":
assert data.size(0) == self.seq_len + 1
return data
def _parse_function(self, example_proto):
features = {
"text": tf.io.VarLenFeature(tf.int64)
}
parsed_features = tf.io.parse_single_example(example_proto, features)
return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])

def _process_tfrecord(self, tfrecords_file, resume_idx=None):
for idx, example in enumerate(tf.io.tf_record_iterator(tfrecords_file)):
yield self._parse_single_example(example)
dataset = tf.data.TFRecordDataset([tfrecords_file])
dataset = dataset.map(self._parse_function, num_parallel_calls=1)
for example in dataset.as_numpy_iterator():
yield torch.tensor(example, dtype=torch.long)

def _maybe_process_tfrecord(self, file_idx):
if self.processed_files.get(file_idx) is None:
Expand Down