Update scripts
leogao2 committed Dec 21, 2020
1 parent 1e14c4f commit 4a1d42b
Showing 4 changed files with 335 additions and 1 deletion.
4 changes: 3 additions & 1 deletion processing_scripts/README.md
@@ -9,7 +9,9 @@ Replication scripts are listed in approximate order needed for replication.

## Analysis & Ablation

- `lang_len_analysis.py`: Runs analysis for length in {chars, bytes, tokens, words} and language. Saves the results as .jsonl.zst files that need a second pass to aggregate, but this first pass is the more expensive one anyway, and keeping per-document records means we can make nice histograms and other plots. Should be run with `TOKENIZERS_PARALLELISM=false` for best performance, since that prevents thread thrashing. This script would be a useful template for other future analyses.
- `lang_len_analysis_pass1.py`: Runs analysis for length in {chars, bytes, tokens, words} and language. Saves the results as .jsonl.zst files that need a second pass to aggregate (a minimal reader sketch follows this list), but this first pass is the more expensive one anyway, and keeping per-document records means we can make nice histograms and other plots. Should be run with `TOKENIZERS_PARALLELISM=false` for best performance, since that prevents thread thrashing. This script would be a useful template for other future analyses.
- `lang_len_analysis_pass2.py`: Pass 2 for the length/language analysis. Aggregates the pass-1 output and makes plots.
- `profanity_analysis_pass1.py`: Profanity analysis pass 1. Records, for each document, per-sentence profanity flags and word counts.
- `ablation_dedupe/make_excludes_lambada_wikitext.py`: For ablation; detokenizes LAMBADA and wikitext in preparation for eval-dedupe. This script should be obsolete now; `write_out.py` in lm_evaluation_harness handles many more sets. TODO: write detailed guide on how to use `write_out.py`
- `ablation_dedupe/make_deduped.py`: For ablation; performs decontamination of training data against validation/test data. Run `make_excludes_lambada_wikitext` or `write_out.py` first. TODO: clean up and make official validation-dedupe script.
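
The pass-1 output can also be inspected directly. A minimal reader sketch (not part of the scripts; it assumes each shard is a zstd-compressed stream of concatenated JSON objects with the missing-newline quirk that `lang_len_analysis_pass2.py` works around, and the shard path below is hypothetical):

```python
import json
import zstandard

def iter_records(path):
    # Decompress the whole shard; records are flat JSON objects with no
    # newline between them, so split on '}' and re-append it.
    with open(path, 'rb') as fh:
        raw = zstandard.ZstdDecompressor().stream_reader(fh).read()
    for chunk in raw.split(b'}')[:-1]:
        yield json.loads(chunk + b'}')

# e.g. total tokens per Pile component in one shard:
# counts = {}
# for ob in iter_records('langlen_stage1/shard_00.jsonl.zst'):
#     counts[ob['pile_set_name']] = counts.get(ob['pile_set_name'], 0) + ob['len_tokens']
```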

File renamed without changes.
178 changes: 178 additions & 0 deletions processing_scripts/lang_len_analysis_pass2.py
@@ -0,0 +1,178 @@
from glob import glob
import random
import os
from tqdm import tqdm
import shutil
import zstandard
import json
import io

import sys
import math


def readf(f):
    with open(f, 'rb') as fh:
        cctx = zstandard.ZstdDecompressor()
        reader = io.BufferedReader(cctx.stream_reader(fh))
        yield from reader


rewritenames = {
    'CommonCrawl': 'Pile-CC',
    'Bibliotik': 'Books3',
    'USPTO': 'USPTO Backgrounds',
    'BookCorpus': 'BookCorpus2',
}


def rewrite_name(n):
    if n in rewritenames: return rewritenames[n]

    return n

import collections

dat = collections.defaultdict(list)

set_names = set()


# NOTE: the [:1] below limits processing to the first shard only.
for f in tqdm(glob('langlen_stage1/*')[:1]):
    # forgot to add \n in stage1 >.>
    for x in map(lambda x: x + b'}', (list(next(readf(f)).split(b'}'))[:-1])):
        ob = json.loads(x)
        setname = rewrite_name(ob['pile_set_name'])
        #print(ob)
        for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'lang']:
            dat[(setname, attr)].append(ob[attr])
            dat[('Pile', attr)].append(ob[attr])
        set_names.add(setname)

        if ob['len_tokens'] > 0:
            dat[(setname, 'bytes_per_token')].append(ob['len_utf8bytes'] / ob['len_tokens'])
            dat[('Pile', 'bytes_per_token')].append(ob['len_utf8bytes'] / ob['len_tokens'])

            dat[(setname, 'words_per_token')].append(ob['len_words'] / ob['len_tokens'])
            dat[('Pile', 'words_per_token')].append(ob['len_words'] / ob['len_tokens'])

set_names = list(set_names)
set_names.append('Pile')

def mean(x):
    return sum(x) / len(x)

def stddev(x):
    mu = mean(x)
    return math.sqrt(mean([(v - mu) ** 2 for v in x]))
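# Note: mean/stddev are population statistics (divide by N, not N - 1).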

def freqs(x):
    ret = collections.defaultdict(int)
    for v in x:
        ret[v] += 1

    return ret

def filter_freqs(x, minpass):
    total = sum(x.values())
    torm = []
    for k, v in x.items():
        if v / total < minpass:
            torm.append(k)

    for k in torm: del x[k]

    return x
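# Example (sketch): filter_freqs({'en': 990, 'de': 9, 'fr': 1}, 0.01) removes
# 'de' and 'fr' (each under 1% of the 1000 observations), leaving {'en': 990}.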

nicename = {
    'len_char': 'Length in characters',
    'len_utf8bytes': 'Length in bytes',
    'len_words': 'Length in words',
    'len_tokens': 'Length in tokens',
    'bytes_per_token': 'Mean bytes per token',
    'words_per_token': 'Mean words per token',
    'lang': 'Language'
}


import matplotlib.pyplot as plt
import numpy as np

def histogram(x, sname, attr):
    plt.clf()
    plt.cla()
    plt.hist(x, density=True, bins=100)
    #plt.ylabel('Probability Density')
    plt.xlabel('{} ({})'.format(nicename[attr], sname))
    plt.savefig('figures/analysis_{}_{}.png'.format(sname, attr), bbox_inches='tight')


def barplot(d, sname, attr, normalize=True, yerr=False):
    x, y = zip(*sorted(d.items(), key=lambda x: x[1], reverse=True))
    yerrs = None
    if yerr:
        yerrs = [v[1] for v in y]
        y = [v[0] for v in y]
    if normalize:
        total = sum(d.values())
        y = [val / total for val in y]
    plt.clf()
    plt.cla()
    if yerr:
        plt.errorbar(x, y, yerr=yerrs, fmt='o')
        plt.xticks(rotation=45, ha="right")

        #ymin = None
        #ymax = None

        #if attr == 'len_char':
        #    ymin, ymax = -30000, 1200000
        #if attr == 'len_tokens':
        #    ymin, ymax = -30000, 300000
        #if attr == 'len_utf8bytes':
        #    ymin, ymax = -30000, 1200000
        #if ymin and ymax:
        #    axes = plt.gca()
        #    axes.set_ylim([ymin, ymax])
    else:
        plt.bar(x, y)
    #plt.ylabel('Proportion')
    #plt.xlabel('{} ({})'.format(nicename[attr], sname))
    plt.xlabel('Pile component')
    plt.ylabel(nicename[attr])
    plt.savefig('figures/analysis_{}_{}.png'.format(sname, attr), bbox_inches='tight', dpi=600)


def format_freqs(d):
    res = []
    total = sum(d.values())
    for k, v in sorted(d.items(), key=lambda x: -x[1]):
        res.append(' {}: {:.2f}%'.format(k, v / total * 100))
    return '\n'.join(res)


def rm_outliers_trunc_1p(x):
    x = list(sorted(x))
    return x[:len(x)*99//100]
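# Example (sketch): rm_outliers_trunc_1p(range(200)) keeps the 198 smallest
# values, i.e. it truncates the top 1% before plotting a histogram.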


summary = collections.defaultdict(dict)

print('bytes per token, all:', sum(dat[('Pile', 'len_utf8bytes')]) / sum(dat[('Pile', 'len_tokens')]))

for sname in set_names:
    print('**' + sname + '**')
    for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'bytes_per_token', 'words_per_token']:
        mu, sigma = mean(dat[(sname, attr)]), stddev(dat[(sname, attr)])
        print('{}: {:.4f}±{:.4f}'.format(nicename[attr], mu, sigma))
        #histogram(rm_outliers_trunc_1p(dat[(sname,attr)]), sname, attr)
        if sname != 'Pile' and (sname != 'Ubuntu IRC' or 'len_' not in attr): summary[attr][sname] = (mu, sigma)

    #barplot(filter_freqs(freqs(dat[(sname,'lang')]), 0.001), sname, 'lang')

    print('Langs:')
    print(format_freqs(freqs(dat[(sname, 'lang')])))


for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'bytes_per_token', 'words_per_token']:
    barplot(summary[attr], 'overview', attr, normalize=False, yerr=True)
154 changes: 154 additions & 0 deletions processing_scripts/profanity_analysis_pass1.py
@@ -0,0 +1,154 @@
import lm_dataformat as lmd
from glob import glob
import os
import json
import collections
from tqdm import tqdm

import re
from best_download import download_file
import fasttext

import zstandard
import multiprocessing as mp
from profanity_check import predict

in_path = 'pile'
out_path = 'prof_analysis'

# From https://stackoverflow.com/a/31505798
alphabets = "([A-Za-z])"
prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = "(Inc|Ltd|Jr|Sr|Co)"
starters = "(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = "[.](com|net|org|io|gov)"

def split_into_sentences(text):
    text = " " + text + " "
    text = text.replace("\n"," ")
    text = re.sub(prefixes,"\\1<prd>",text)
    text = re.sub(websites,"<prd>\\1",text)
    if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
    text = re.sub("\s" + alphabets + "[.] "," \\1<prd> ",text)
    text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
    text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
    text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
    text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)
    if "”" in text: text = text.replace(".”","”.")
    if "\"" in text: text = text.replace(".\"","\".")
    if "!" in text: text = text.replace("!\"","\"!")
    if "?" in text: text = text.replace("?\"","\"?")
    text = text.replace(".",".<stop>")
    text = text.replace("?","?<stop>")
    text = text.replace("!","!<stop>")
    text = text.replace("<prd>",".")

    # return quotes to normal
    text = text.replace("\".", ".\"")
    text = text.replace("”.", ".”")
    text = text.replace("\"!", "!\"")
    text = text.replace("”!", "!”")
    text = text.replace("\"?", "?\"")
    text = text.replace("”?", "?”")
    sentences = text.split("<stop>")
    sentences = sentences[:-1]
    sentences = [s.strip() for s in sentences]
    return sentences
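# Example (sketch): split_into_sentences("Dr. Smith went home. It was late!")
# -> ['Dr. Smith went home.', 'It was late!']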

download_file('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin', '7e69ec5451bc261cc7844e49e4792a85d7f09c06789ec800fc4a44aec362764e')

langdet = fasttext.load_model("lid.176.bin")

def language(doc):
    details = langdet.predict(doc.replace('\n', ' '), k=1)

    return {
        'lang': details[0][0].replace('__label__', '')
    }

def is_english(doc): return doc != '' and language(doc)['lang'] == 'en'  # language() returns a dict; compare its 'lang' field

def words(sent): return re.split(r'\s+', sent)

def join(arr):
    ret = []
    for val in arr:
        ret.extend(val)

    return ret

def unjoin(arr, lens):
    ret = []

    last = 0
    for l in lens:
        ret.append(arr[last:last+l])
        last += l
    assert last == len(arr)

    return ret
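# Example (sketch): join([[1, 2], [3, 4, 5]]) -> [1, 2, 3, 4, 5], and
# unjoin([1, 2, 3, 4, 5], [2, 3]) -> [[1, 2], [3, 4, 5]] undoes it.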


def is_profane(docs):
    if len(docs) == 0: return []
    return list(map(int, predict(docs)))


def profanity(doc):
    sents = list(filter(is_english, split_into_sentences(doc)))
    p_sents = is_profane(sents)

    sentwords = list(map(words, sents))
    sentlens = list(map(len, sentwords))

    lwords = join(sentwords)
    # predict over the flat word list in one batch, then regroup per sentence
    p_words = unjoin(is_profane(lwords), sentlens)
    n_prof = list(map(sum, p_words))

    # one entry per sentence: (sentence_is_profane, word_count, profane_word_count)
    res = list(zip(p_sents, sentlens, n_prof))
    return {
        'sentences': res,
        'num_bytes': len(doc.encode('utf-8'))
    }

def writef(f, lines):
    with open(f, 'wb') as fh:
        cctx = zstandard.ZstdCompressor(level=3, threads=8)
        compressor = cctx.stream_writer(fh)
        for line in tqdm(lines):
            compressor.write(line)
        compressor.flush(zstandard.FLUSH_FRAME)
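# Note: writef adds no separator between records, so the aggregation pass has
# to split the stream itself (the same quirk noted in lang_len_analysis_pass2.py).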


def analyze(ob):
    doc, meta = ob
    res = {
        'pile_set_name': meta['pile_set_name']
    }
    for metric in metrics:
        res = {**res, **metric(doc)}
    return json.dumps(res).encode('utf-8')
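# Each output record is one JSON object, roughly (illustrative values):
# {"pile_set_name": "Pile-CC", "sentences": [[0, 12, 1], ...], "num_bytes": 4096}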


metrics = [
    profanity
]

pool = mp.Pool(24)


for f in tqdm(sorted(glob(in_path + '/*'))):
    if os.path.exists(out_path + '/analysis_' + f.split('/')[-1]): continue

    def meta_items():
        rdr = lmd.Reader(f)
        return pool.imap(analyze, rdr.stream_data(get_meta=True))

    writef(out_path + '/tmp_analysis_' + f.split('/')[-1], meta_items())
    os.rename(out_path + '/tmp_analysis_' + f.split('/')[-1], out_path + '/analysis_' + f.split('/')[-1])
