add dynamic dataset for processing/tokenizing examples lazily #46

Closed
wants to merge 7 commits into from
Changes from 1 commit
add dynamic dataset for processing/tokenizing examples lazily
- add pysimdjson as a requirement for memory-efficient, fast JSON parsing
- add a DynamicDataset class for jsonlines input, leveraging TF's TextLineDataset C++ I/O as an iterator (a short usage sketch follows the commit header below)
trisongz committed Jan 5, 2021
commit e3aec3bc5cd53a0754692b3ef380f553eea02fbc
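For reference, a minimal usage sketch of the new class (not part of the diff). It assumes a HuggingFace-style tokenizer, a directory of newline-delimited JSON files with a 'text' field, and that gpt_neox is importable as a package; the data/* glob, the gpt2 tokenizer name, and the sizes below are illustrative.

from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast

from gpt_neox.datasets import DynamicDataset

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
dataset = DynamicDataset(
    input_files='data/*',        # a glob ending in '*', a directory, or a list of files
    tokenizer=tokenizer,
    max_seq_len=1024,
    target_field='text',         # json key holding the raw text
    shuffle_files=True,
)

# Lines are read and tokenized lazily, one per __getitem__ call. Because the
# class pulls from a shared iterator and returns unpadded input_ids, keep
# num_workers=0 and batch_size=1 unless a padding collate_fn is supplied.
loader = DataLoader(dataset, batch_size=1, num_workers=0)
for input_ids in loader:
    pass  # forward input_ids through the model / training step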
71 changes: 71 additions & 0 deletions gpt_neox/datasets.py
@@ -7,6 +7,9 @@
import re
import logging
from itertools import cycle
import os
import subprocess
import simdjson as json

class GPT2Dataset(Dataset):

@@ -130,3 +133,71 @@ def __getitem__(self, index):

def __len__(self):
return self.data.size(0) // self.seq_len


class DynamicDataset(Dataset):
    """Lazily reads jsonlines files with tf.data.TextLineDataset and tokenizes examples on demand."""

    def __init__(self, input_files, tokenizer, max_seq_len, target_field='text', seed=1, shuffle_files=True, **kwargs):
        super().__init__()
        self.files = []
        self.setup_files(input_files)
        if shuffle_files:
            random.seed(seed)
            random.shuffle(self.files)
        self.create_pipeline()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.target_field = target_field
        self.parser = json.Parser()
        self.idx = 0

    def setup_files(self, input_files):
        # Accepts a glob pattern ending in '*', a directory, a single file, or a list of any of these.
        if isinstance(input_files, str):
            if input_files.endswith('*'):
                self.files = glob.glob(input_files)
            elif os.path.isdir(input_files):
                self.files = glob.glob(os.path.join(input_files, '*'))
            elif os.path.isfile(input_files):
                self.files = [input_files]
        elif isinstance(input_files, list):
            for file_path in input_files:
                if os.path.isfile(file_path):
                    self.files.append(file_path)
                elif file_path.endswith('*'):
                    self.files.extend(glob.glob(file_path))
                elif os.path.isdir(file_path):
                    self.files.extend(glob.glob(os.path.join(file_path, '*')))

        self.total_files = len(self.files)
        self.file_idx, self.total_lines = {}, 0
        for file_path in self.files:
            total_lines = self.total_lines_in_file(file_path)
            self.file_idx[file_path] = total_lines
            self.total_lines += total_lines
        logging.info(f'Total Files: {self.total_files}. Total Lines: {self.total_lines}')

    def create_pipeline(self):
        # Streams raw lines through TF's C++ io; note that as_numpy_iterator() is a TF 2.x (eager) API.
        self.pipeline = tf.data.TextLineDataset(self.files, num_parallel_reads=tf.data.experimental.AUTOTUNE).as_numpy_iterator()

    def parse_json(self, line):
        try:
            return self.parser.parse(line).as_dict()
        except ValueError:
            # Fall back to the raw line if it is not valid json.
            return line

    @classmethod
    def total_lines_in_file(cls, file_path):
        # Uses `wc -l`, so a POSIX environment is assumed.
        return int(subprocess.check_output(['wc', '-l', file_path]).split()[0])

    def tokenize_example(self, ex):
        self.idx += 1
        # Truncated but not padded, so the returned input_ids vary in length.
        return self.tokenizer(ex[self.target_field], max_length=self.max_seq_len, truncation=True, return_tensors='pt')['input_ids']

    def __getitem__(self, idx):
        # The index is ignored: examples are served sequentially from the shared iterator,
        # and the pipeline is rebuilt once exhausted so iteration wraps across epochs.
        try:
            ex = next(self.pipeline)
        except StopIteration:
            del self.pipeline
            self.create_pipeline()
            ex = next(self.pipeline)
        return self.tokenize_example(self.parse_json(ex))

    def __len__(self):
        return self.total_lines
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ transformers
tensorflow==1.15.2
ftfy
lm_dataformat
pysimdjson