add dynamic dataset for processing/tokenizing examples lazily #46

Status: Closed · wants to merge 7 commits
146 changes: 146 additions & 0 deletions gpt_neox/datasets.py
@@ -7,6 +7,13 @@
import re
import logging
from itertools import cycle
import os
import subprocess
import simdjson as json

# Silence the flood of tokenizer warnings that would otherwise be emitted.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "true"

class GPT2Dataset(Dataset):

@@ -132,3 +139,142 @@ def __getitem__(self, index):

    def __len__(self):
        return self.data.size(0) // self.seq_len

class LineSeekableFile:
    """Indexes the byte offset of every line so single lines can be read via seek()."""

    def __init__(self, seekable):
        self.fin = seekable
        self.line_map = [0]  # byte offset at which each line starts
        while seekable.readline():
            self.line_map.append(seekable.tell())

    def __getitem__(self, index):
        self.fin.seek(self.line_map[index])
        return self.fin.readline()

class DynamicDataset(Dataset):
    """Reads and tokenizes JSONL examples lazily, one line at a time, instead of pre-tokenizing the whole corpus."""

    def __init__(self, input_files, tokenizer, max_seq_len, target_field='text', seed=1, shuffle_files=True, debug=False, **kwargs):
        super().__init__()
        self.files = []
        self.debug = debug
        self.setup_files(input_files)
        if shuffle_files:
            random.seed(seed)
            random.shuffle(self.files)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.target_field = target_field
        self.token_cache = []
        self.text_chunks = ''
        self.sep_token = tokenizer.eos_token_id
        self.sep_word = tokenizer.eos_token
        self.pad_token = tokenizer.pad_token_id
        self.parser = json.Parser()


    def setup_files(self, input_files):
        # Accept a glob pattern, a directory, a single file, or a list mixing all three.
        if isinstance(input_files, str):
            if input_files.endswith('*'):
                self.files = glob.glob(input_files)
            elif os.path.isdir(input_files):
                self.files = glob.glob(os.path.join(input_files, '*'))
            elif os.path.isfile(input_files):
                self.files = [input_files]
        elif isinstance(input_files, list):
            for file_path in input_files:
                if os.path.isfile(file_path):
                    self.files.append(file_path)
                elif file_path.endswith('*'):
                    self.files.extend(glob.glob(file_path))
                elif os.path.isdir(file_path):
                    self.files.extend(glob.glob(os.path.join(file_path, '*')))

        self.total_files = len(self.files)
        self.file_idx, self.total_lines = {}, 0
        # Record the global line range covered by each file along with a seekable reader.
        for x, file_path in enumerate(self.files):
            total_lines = self.total_lines_in_file(file_path)
            self.file_idx[x] = {
                'start': self.total_lines, 'stop': (self.total_lines + total_lines),
                'file': file_path, 'reader': LineSeekableFile(tf.io.gfile.GFile(file_path, 'rb'))
            }
            if self.debug:
                logging.debug(f'File IDX Start: {self.total_lines} - File IDX End: {self.total_lines + total_lines}')
            self.total_lines += total_lines
        if self.debug:
            logging.debug(f'Total Files: {self.total_files}. Total Lines: {self.total_lines}')

    def get_file_line(self, idx):
        # Map a global line index to the file containing it, then read that line.
        for x in range(len(self.files)):
            if self.file_idx[x]['start'] <= idx < self.file_idx[x]['stop']:
                fidx = idx - self.file_idx[x]['start']
                if self.debug:
                    logging.debug(f'File IDX: {fidx}')
                return self.file_idx[x]['reader'][fidx]

    def parse_json(self, line):
        try:
            return self.parser.parse(line).as_dict()[self.target_field]
        except (ValueError, TypeError, KeyError):
            # Fall back to the raw line when it is not valid JSON or lacks the target field.
            return line.decode('utf-8', errors='ignore') if isinstance(line, bytes) else line

    @classmethod
    def total_lines_in_file(cls, file_path):
        # Use `wc -l` for a fast line count without loading the file into memory.
        return int(subprocess.check_output(['wc', '-l', file_path]).split()[0])

    def tokenize_example(self, ex):
        tokenized = self.tokenizer(ex)['input_ids']
        if self.token_cache:
            if len(self.token_cache) > self.max_seq_len:
                # The cache alone already fills a sequence; carry the leftover cache
                # plus the newly tokenized example (and a separator) forward.
                out = self.token_cache[:self.max_seq_len]
                self.token_cache = self.token_cache[self.max_seq_len:] + tokenized + [self.sep_token]

            else:
                out = self.token_cache[:]
                _to_slice = self.max_seq_len - len(out)
                out.extend(tokenized[:_to_slice])
                remainder = tokenized[_to_slice:]
                self.token_cache = remainder + [self.sep_token] if remainder else []

        else:
            out = tokenized[:self.max_seq_len]
            remainder = tokenized[self.max_seq_len:]
            self.token_cache = remainder + [self.sep_token] if remainder else []

        if len(out) < self.max_seq_len:
            _to_pad = self.max_seq_len - len(out)
            out.extend([self.pad_token] * _to_pad)

        return torch.tensor(out, dtype=torch.long)

    def _get_example(self, idx):
        # Clamp out-of-range indices to a valid random line.
        idx = idx if idx < self.total_lines else random.randint(0, self.total_lines - 1)
        ex = self.get_file_line(idx)
        if not ex:
            # Empty read (e.g. a blank line): retry random indices until a line is found.
            while True:
                new_idx = random.randint(0, self.total_lines - 1)
                if self.debug:
                    logging.debug(f'Bad IDX: {idx} - New Random IDX: {new_idx}')
                ex = self.get_file_line(new_idx)
                if ex:
                    break
        return self.parse_json(ex).strip()

    def _seq_len(self, ex):
        # True if the whitespace-split example already spans more than max_seq_len words.
        return len(ex.split()) > self.max_seq_len

    def __getitem__(self, idx):
        if self.debug:
            logging.debug(f'Getting IDX: {idx}')
        ex = self._get_example(idx)
        if self._seq_len(ex):
            return self.tokenize_example(ex)
        # The example is shorter than max_seq_len: keep appending the following
        # examples, separated by the EOS token, until the chunk is long enough.
        offset = 1
        while True:
            self.text_chunks = self.text_chunks + ex + ' ' + self.sep_word + ' '
            if self._seq_len(self.text_chunks):
                out = self.tokenize_example(self.text_chunks)
                self.text_chunks = ''
                break
            ex = self._get_example(idx + offset)
            offset += 1
        return out

    def __len__(self):
        return self.total_lines
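For reference, a minimal usage sketch of the new DynamicDataset (not part of the diff). It assumes a transformers GPT-2 tokenizer and a local directory of JSONL shards with a "text" field; the paths and hyperparameters are illustrative only.

# Usage sketch only; `./data/` and the settings below are assumptions, not part of the PR.
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

from gpt_neox.datasets import DynamicDataset

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default

dataset = DynamicDataset(
    input_files='./data/*',   # glob pattern, directory, or list of files
    tokenizer=tokenizer,
    max_seq_len=1024,
    target_field='text',
    shuffle_files=True,
)

# Each item is read and tokenized lazily on access and returned as a LongTensor of length max_seq_len.
loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))  # shape: [4, 1024]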
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ transformers
tensorflow==2.4.0
ftfy
lm_dataformat
pysimdjson
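
The pysimdjson dependency backs the `simdjson.Parser()` call used in datasets.py for fast per-line JSON parsing. A minimal sketch of that call pattern, with an illustrative line and the "text" field assumed:

import simdjson as json

parser = json.Parser()
line = b'{"text": "an example document", "meta": {"source": "demo"}}'
doc = parser.parse(line).as_dict()   # parse a single JSONL line (bytes or str)
print(doc['text'])                   # "an example document"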