add dynamic dataset for processing/tokenizing examples lazily #46

Closed
wants to merge 7 commits
update dataset implementation
trisongz committed Jan 11, 2021
commit 80c65bfcc0a497d7632777d597ad936983fa45a1
51 changes: 34 additions & 17 deletions gpt_neox/datasets.py
@@ -136,21 +136,31 @@ def __getitem__(self, index):
     def __len__(self):
         return self.data.size(0) // self.seq_len
 
+class LineSeekableFile:
+    def __init__(self, seekable):
+        self.fin = seekable
+        self.line_map = list()
+        self.line_map.append(0)
+        while seekable.readline():
+            self.line_map.append(seekable.tell())
+
+    def __getitem__(self, index):
+        self.fin.seek(self.line_map[index])
+        return self.fin.readline()
 
 class DynamicDataset(Dataset):
-    def __init__(self, input_files, tokenizer, max_seq_len, target_field='text', seed=1, shuffle_files=True, **kwargs):
+    def __init__(self, input_files, tokenizer, max_seq_len, target_field='text', seed=1, shuffle_files=True, debug=False, **kwargs):
         super().__init__()
         self.files = []
         self.setup_files(input_files)
         if shuffle_files:
             random.seed(seed)
             random.shuffle(self.files)
-        self.create_pipeline()
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
         self.target_field = target_field
         self.parser = json.Parser()
-        self.idx = 0
+        self.debug = debug
 
     def setup_files(self, input_files):
         if isinstance(input_files, str):
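
For intuition, here is a small, hypothetical illustration of the LineSeekableFile helper added above (not part of the diff): it records the byte offset of each line start once, then seeks to serve random line reads. io.StringIO stands in for the tf.io.gfile.GFile handle used in setup_files, and the example assumes the gpt_neox package and its dependencies are importable.

# Hypothetical demo only; assumes gpt_neox.datasets (as of this commit) is importable.
import io
from gpt_neox.datasets import LineSeekableFile

handle = io.StringIO("alpha\nbeta\ngamma\n")  # stand-in for tf.io.gfile.GFile(path, 'r')
lsf = LineSeekableFile(handle)                # builds line_map = [0, 6, 11, 17]
print(lsf[2])                                 # -> 'gamma\n', served via seek(), no rescan
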
@@ -169,14 +179,25 @@ def setup_files(self, input_files):
 
         self.total_files = len(self.files)
         self.file_idx, self.total_lines = {}, 0
-        for file_path in self.files:
+        for x, file_path in enumerate(self.files):
             total_lines = self.total_lines_in_file(file_path)
-            self.file_idx[file_path] = total_lines
+            self.file_idx[x] = {
+                'start': self.total_lines, 'stop': (self.total_lines + total_lines - 1),
+                'file': file_path, 'reader': LineSeekableFile(tf.io.gfile.GFile(file_path, 'r'))
+            }
+            if self.debug:
+                logging.debug(f'File IDX Start: {self.total_lines} - File IDX End: {self.total_lines + total_lines - 1}')
             self.total_lines += total_lines
-        logging.info(f'Total Files: {self.total_files}. Total Lines: {self.total_lines}')
+        if self.debug:
+            logging.debug(f'Total Files: {self.total_files}. Total Lines: {self.total_lines}')
 
-    def create_pipeline(self):
-        self.pipeline = tf.data.TextLineDataset(self.files, num_parallel_reads=tf.data.experimental.AUTOTUNE).as_numpy_iterator()
+    def get_file_line(self, idx):
+        for x in range(len(self.files)):
+            if idx in range(self.file_idx[x]['start'], self.file_idx[x]['stop']):
+                fidx = idx - self.file_idx[x]['start']
+                if self.debug:
+                    logging.debug(f'File IDX: {fidx}')
+                return self.file_idx[x]['reader'][fidx]
 
     def parse_json(self, line):
         try:
@@ -189,17 +210,13 @@ def total_lines_in_file(cls, file_path):
         return int(subprocess.check_output(['wc', '-l', file_path]).split()[0])
 
     def tokenize_example(self, ex):
-        self.idx += 1
-        return self.tokenizer(ex[self.target_field], max_length=self.max_seq_len, truncation=True, return_tensors='pt')['input_ids']
+        return self.tokenizer(ex[self.target_field], max_length=self.max_seq_len, padding='max_length', truncation=True, return_tensors='pt')['input_ids']
 
     def __getitem__(self, idx):
-        try:
-            ex = next(self.pipeline)
-        except StopIteration:
-            del self.pipeline
-            self.create_pipeline()
-            ex = next(self.pipeline)
+        if self.debug:
+            logging.debug(f'Getting IDX: {idx}')
+        ex = self.get_file_line(idx)
         return self.tokenize_example(self.parse_json(ex))
 
     def __len__(self):
-        return self.total_lines
+        return self.total_lines
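
For reference, a minimal, hypothetical usage sketch of the class as it stands after this commit. It is not part of the diff: the file paths, tokenizer choice, and DataLoader settings are placeholder assumptions, and it presumes JSONL inputs readable through tf.io.gfile.

# Hypothetical usage sketch (not part of this PR). Assumes a HuggingFace
# tokenizer and locally available JSONL files; paths are placeholders.
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast

from gpt_neox.datasets import DynamicDataset

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; needed for padding='max_length'

dataset = DynamicDataset(
    input_files=['data/train_00.jsonl', 'data/train_01.jsonl'],  # placeholder paths
    tokenizer=tokenizer,
    max_seq_len=1024,
    target_field='text',
    shuffle_files=True,
    debug=False,
)

# __getitem__ now maps an absolute line index to a (file, line) pair via the
# LineSeekableFile readers built in setup_files, so the dataset supports random
# access and can be driven by a standard shuffled DataLoader instead of the
# previous shared tf.data iterator.
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch.shape)  # expected (8, 1, 1024): each example is a (1, max_seq_len) tensor
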