Commit

Merge pull request google-research#10 from google-research/dev-v1
V1.0 release
carlini committed Mar 11, 2022
2 parents 62d0320 + f86b161 commit 8c54c64
Showing 11 changed files with 1,513 additions and 1,087 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "dedup_dataset" # the name of the package
version = "0.1.0" # the current version, obeying semver
version = "1.0.0" # the current version, obeying semver

authors = ["Nicholas Carlini <[email protected]>"]

@@ -11,5 +11,5 @@ overflow-checks = false # Go FAAASSTTT!
[dependencies]
zstd = "0.5"
crossbeam = "0.3"
-fasthash = "0.4"
-filebuffer = "0.4"
+filebuffer = "0.4"
+clap = { version = "3.1.1", features = ["derive"] }
281 changes: 247 additions & 34 deletions README.md

Large diffs are not rendered by default.

23 changes: 15 additions & 8 deletions scripts/count_occurances.py → scripts/count_occurrences.py
@@ -16,17 +16,24 @@

import argparse

-parser = argparse.ArgumentParser(description='Count occurances of sequence.')
+parser = argparse.ArgumentParser(description='Count occurrences of sequence.')
parser.add_argument('--suffix', type=str, required=True)
parser.add_argument('--query', type=str)
parser.add_argument('--query_file', type=str)
parser.add_argument('--tokenize', action='store_true')
+parser.add_argument('--tokenizer', type=str, default="gpt2")

args = parser.parse_args()

if args.tokenize:
-    from transformers import GPT2Tokenizer
-    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    if args.tokenizer == 'gpt2':
+        from transformers import GPT2Tokenizer
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    elif args.tokenizer == 't5':
+        from transformers import T5Tokenizer
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+    else:
+        raise

assert args.query or args.query_file

@@ -37,13 +44,13 @@
        arr = args.query.encode('utf-8')
    print(arr)
    open("/tmp/fin","wb").write(arr)
-    print(os.popen("./target/debug/dedup_dataset count_occurances %s /tmp/fin"%(args.suffix)).read())
+    print(os.popen("./target/debug/dedup_dataset count-occurrences --data-file %s --query-file /tmp/fin"%(args.suffix)).read())
else:
-    q = open(args.query_file).read()
    if args.tokenize:
+        q = open(args.query_file).read()
        arr = np.array(tokenizer.encode(q), dtype=np.uint16).view(np.uint8).tobytes()
    else:
-        arr = q.encode('utf-8')
+        arr = open(args.query_file,"rb").read()
    print(arr)
-    open("/tmp/fin","wb").write(arr.tobytes())
-    print(os.popen("./target/debug/dedup_dataset count_occurances %s /tmp/fin"%(args.suffix)).read())
+    open("/tmp/fin","wb").write(arr)
+    print(os.popen("./target/debug/dedup_dataset count-occurrences --data-file %s --query-file /tmp/fin"%(args.suffix)).read())
4 changes: 4 additions & 0 deletions scripts/deduplicate_single_file.sh
@@ -0,0 +1,4 @@
python3 scripts/make_suffix_array.py $1
cargo run self-similar --data-file $1 --length-threshold $3 --cache-dir /tmp/cache --num-threads $4
cargo run collect --data-file $1 --cache-dir /tmp/cache --length-threshold $3 > /tmp/drop_tokens_file
python3 scripts/finish_single_file.py $1 /tmp/drop_tokens_file $2
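For reference, the script's positional parameters are: $1 the input data file, $2 the path where the deduplicated copy is written, $3 the length threshold, and $4 the number of threads. A hypothetical invocation would be bash scripts/deduplicate_single_file.sh data/mydata.train data/mydata.train.dedup 100 8, where the file names are placeholders.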
198 changes: 198 additions & 0 deletions scripts/finish_dedup_wiki40b.py
@@ -0,0 +1,198 @@
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import shutil
import json
from transformers import GPT2Tokenizer
import multiprocessing as mp
from collections import defaultdict
import tensorflow as tf
import tensorflow_datasets as tfds

import pickle

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_example(**feature):
    """
    Creates a tf.train.Example message ready to be written to a file.
    """
    # Create a dictionary mapping the feature name to the tf.train.Example-compatible
    # data type.
    feature = {
        'content-length': _bytes_feature(feature['content-length']),
        'content-type': _bytes_feature(feature['content-type']),
        'text': _bytes_feature(feature['text']),
        'timestamp': _bytes_feature(feature['timestamp']),
        'url': _bytes_feature(feature['url']),
    }

    # Create a Features message using tf.train.Example.

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

remove = defaultdict(list)

def run(args):
    this_idx, row = args
    new_row = {'text': row,
               'version_id': '',
               'wikidata_id': '',
               'timestamp': '',
               'url': '',
               'content-length': '',
               'content-type': ''}

    if this_idx in remove_ex:
        for start,end in remove_ex[this_idx][::-1]:
            #print(start,end)
            row = row[:start] + row[end:]

    new_row['text'] = row
    return new_row

class MyDataset(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder for my_dataset dataset."""

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
        '1.0.0': 'Initial release.',
    }

    def _info(self):
        """Dataset metadata (homepage, citation,...)."""
        return tfds.core.DatasetInfo(
            builder=self,
            features=tfds.features.FeaturesDict({
                'text': tfds.features.Text(),
                'version_id': tfds.features.Text(),
                'wikidata_id': tfds.features.Text(),
                'timestamp': tfds.features.Text(),
                'url': tfds.features.Text(),
                'content-length': tfds.features.Text(),
                'content-type': tfds.features.Text()
            }),
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Download the data and define splits."""
        # dl_manager returns pathlib-like objects with `path.read_text()`,
        # `path.iterdir()`,...
        splits = {args.split: self._generate_examples(args.split)}
        return splits

    def _generate_examples(self, split):
        """Generator of examples for each split."""
        global remove, ks

        BS = 2**16
        ds = tfds.load(args.name, split=split, shuffle_files=False, batch_size=BS,
                       data_dir=args.data_dir)


        p = mp.Pool(96)
        i = -1
        for batch in ds:
            i += 1
            #print('$$',i)
            ks = list(batch.keys())
            res = [(i*BS + j, row) for j,row in enumerate(batch['text'].numpy())]
            res = p.map(run, res)

            for j,new_row in enumerate(res):
                this_idx = i*BS + j
                yield str(this_idx), new_row

import argparse

parser = argparse.ArgumentParser(description='Dedup dataset')
parser.add_argument('--data_dir', type=str)
parser.add_argument('--save_dir', type=str)
parser.add_argument('--suffixarray_dir', type=str)
parser.add_argument('--name', type=str)
parser.add_argument('--split', type=str)
parser.add_argument('--remove', type=str)

args = parser.parse_args()

dataset = args.name
where = args.save_dir

remove = []
fin = open(args.remove)
for line in fin:
    if 'out' in line: break
for line in fin:
    remove.append(list(map(int,line.split())))

sizes = np.frombuffer(open(os.path.join(args.suffixarray_dir, args.name+"."+args.split+".size"), "rb").read(), dtype=np.uint64)

remove_ex = defaultdict(list)
ptr = 0
for i,byte_start in enumerate(sizes[:-1]):
    byte_end = sizes[i+1]
    #print(byte_start, byte_end, remove[ptr])
    while ptr < len(remove) and byte_start <= remove[ptr][0] < byte_end:
        #print(remove[ptr])
        assert remove[ptr][1] < byte_end+6
        # The magic value 6 is the per-example separator length: the 2-byte
        # \xff\xff marker followed by the 4-byte example index written by
        # sep() in load_dataset.py.
        remove_ex[i].append((max(int(remove[ptr][0] - byte_start - 6), 0),
                             min(int(remove[ptr][1] - byte_start), byte_end-byte_start)))
        ptr += 1

tfds.load("my_dataset", data_dir=where+"_dedup")


if dataset == "wiki40b":
en = os.path.join(where+"_dedup", "wiki40b")
if not os.path.exists(en):
os.mkdir(en)
en = os.path.join(en, "en")
if not os.path.exists(en):
os.mkdir(en)
en = os.path.join(en, "1.3.0")
if not os.path.exists(en):
os.mkdir(en)

root = os.path.join(where+"_dedup", "my_dataset", "1.0.0")
for f in os.listdir(root):
if "my_dataset" in f:
shutil.move(os.path.join(root, f),
os.path.join(en, f.replace("my_dataset", dataset)))
elif f == 'dataset_info.json' and os.path.exists(os.path.join(en, f)):
json_orig = json.loads(open(os.path.join(en, f)).read())
json_new = json.loads(open(os.path.join(root, f)).read())
json_orig['splits'].extend(json_new['splits'])
open(os.path.join(en, f),"w").write(json.dumps(json_orig))
else:
shutil.move(os.path.join(root, f),
os.path.join(en, f))
else:
raise

try:
os.unlink(os.path.join(where+"_dedup", "my_dataset", "1.0.0", "dataset_info.json"))
except:
pass
os.rmdir(os.path.join(where+"_dedup", "my_dataset", "1.0.0"))
os.rmdir(os.path.join(where+"_dedup", "my_dataset"))
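The remove_ex loop in finish_dedup_wiki40b.py maps duplicate ranges given as global byte offsets in the concatenated suffix-array file back to offsets inside a single example's text. A toy illustration of that arithmetic, with made-up numbers rather than output from a real run:

    # sizes[i] is the byte offset where example i (including its separator) starts.
    sizes = [0, 106, 250]                             # hypothetical cumulative offsets
    a, b = 40, 60                                     # one duplicate range, in global offsets
    i = 0                                             # it starts inside example 0 (0 <= 40 < 106)
    start = max(a - sizes[i] - 6, 0)                  # drop the 6-byte separator prefix
    end = min(b - sizes[i], sizes[i + 1] - sizes[i])  # clamp to the example boundary
    print(start, end)                                 # prints: 34 60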
37 changes: 37 additions & 0 deletions scripts/finish_single_file.py
@@ -0,0 +1,37 @@
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

original = sys.argv[1]
remove_file = sys.argv[2]
deduped = sys.argv[3]

remove = []
fin = open(remove_file)
for line in fin:
    if 'out' in line: break
for line in fin:
    remove.append(list(map(int,line.split())))
remove = remove[::-1]

ds = open(original,"rb")
new_ds = open(deduped,"wb")

start = 0
while len(remove) > 0:
    a,b = remove.pop()
    new_ds.write(ds.read(a-start))
    ds.seek(b)
    start = b
new_ds.write(ds.read())
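An in-memory sketch of what the copy loop above does, on toy data that is not from the repo: bytes outside the listed [start, end) ranges are kept, bytes inside are dropped.

    data = b"abcdefghij"
    ranges = [(2, 5), (7, 9)]          # hypothetical byte ranges to remove
    out, start = b"", 0
    for a, b in ranges:
        out += data[start:a]           # keep everything before the range
        start = b                      # then skip over it
    out += data[start:]                # keep the tail
    print(out)                         # b'abfgj'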
48 changes: 30 additions & 18 deletions scripts/load_dataset.py
@@ -16,7 +16,7 @@
import os
import struct
import numpy as np
-from transformers import GPT2Tokenizer
+from transformers import GPT2Tokenizer, T5Tokenizer
import multiprocessing as mp

import argparse
@@ -27,11 +27,18 @@
parser.add_argument('--name', type=str)
parser.add_argument('--split', type=str)
parser.add_argument('--tokenize', action='store_true')

+parser.add_argument('--tokenizer', type=str, default="gpt2")
+parser.add_argument('--pre_sep', type=bytes, default=b"\xff\xff")
+parser.add_argument('--post_sep', type=bytes, default=b"")
args = parser.parse_args()

if args.tokenize:
-    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    if args.tokenizer == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    elif args.tokenizer == 't5':
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+    else:
+        raise

split = args.split
data_dir = args.data_dir
@@ -44,11 +51,14 @@
assert isinstance(ds, tf.data.Dataset)
print(ds)

+pre_sep = args.pre_sep
+post_sep = args.post_sep

UID = 0
def sep():
    global UID
    UID += 1
-    return b"\xff\xff"+struct.pack("<I", UID)
+    return pre_sep+struct.pack("<I", UID)+post_sep

def tok(x):
    if args.tokenize:
@@ -59,22 +69,24 @@ def tok(x):
    return out


-fout = open(os.path.join(save_dir, dataset_name+"."+split), "wb")
-
-p = mp.Pool(96)
-
-i = 0
-sizes = [0]
-for b in ds:
-    print(i)
-
-    text = b['text'].numpy()
-    text = p.map(tok,text)
-
-    for x in text:
-        next_line = sep()+x
-        fout.write(next_line)
-        sizes.append(sizes[-1]+len(next_line))
-    i += 1
+if not os.path.exists(save_dir):
+    os.mkdir(save_dir)
+
+fout = open(os.path.join(save_dir, dataset_name+"."+split), "wb")
+
+with mp.Pool(mp.cpu_count()) as p:
+    i = 0
+    sizes = [0]
+    for b in ds:
+        print(i)
+
+        text = b['text'].numpy()
+        text = p.map(tok,text)
+
+        for x in text:
+            next_line = sep()+x
+            fout.write(next_line)
+            sizes.append(sizes[-1]+len(next_line))
+        i += 1

open(os.path.join(save_dir,dataset_name+"."+split+".size"), "wb").write(np.array(sizes,dtype=np.uint64).tobytes())
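The .size file written on the last line stores the cumulative byte offset of every example boundary as uint64. A hypothetical helper, not part of the repo, that uses those offsets to read one example back out of the concatenated file, assuming the default --pre_sep and --post_sep:

    import numpy as np

    def read_example(path, i):
        # path is the file written by load_dataset.py; path + ".size" holds the offsets.
        sizes = np.frombuffer(open(path + ".size", "rb").read(), dtype=np.uint64)
        with open(path, "rb") as f:
            f.seek(int(sizes[i]))
            raw = f.read(int(sizes[i + 1] - sizes[i]))
        # With the defaults, each example starts with the 2-byte \xff\xff marker
        # plus a 4-byte example counter; strip those 6 bytes to get the text/tokens.
        return raw[6:]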