Commit

init gec
plkmo committed Jun 14, 2020
1 parent aebd823 commit 3919b57
Showing 16 changed files with 3,555 additions and 0 deletions.
81 changes: 81 additions & 0 deletions gec.py
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 18 13:09:24 2019
@author: WT
"""
from nlptoolkit.gec.infer import infer_from_trained
from nlptoolkit.utils.misc import save_as_pickle
from argparse import ArgumentParser
import logging

logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger(__file__)

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--model_no", type=int, default=0, help="0: GECToR")
parser.add_argument('--model_path', type=str, default=['./data/gec/gector/roberta_1_gector.th'],
help='Path to the model file.', nargs='+')
    parser.add_argument('--vocab_path', type=str, default='./data/gec/gector/output_vocabulary/',
                        help='Path to the vocabulary directory.')
#parser.add_argument('--input_file', type=str, default='./data/gec/gector/input.txt',
# help='Path to the evalset file')
#parser.add_argument('--output_file', type=str, default='./data/gec/gector/output.txt',
# help='Path to the output file')
    parser.add_argument('--max_len',
                        type=int,
                        help='The max sentence length '
                             '(all longer will be truncated)',
                        default=50)
    parser.add_argument('--min_len',
                        type=int,
                        help='The minimum sentence length '
                             '(all shorter will be returned w/o changes)',
                        default=3)
    parser.add_argument('--batch_size',
                        type=int,
                        help='The batch size for inference.',
                        default=128)
parser.add_argument('--lowercase_tokens',
type=int,
help='Whether to lowercase tokens.',
default=0)
parser.add_argument('--transformer_model',
choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'],
help='Name of the transformer model.',
default='roberta')
parser.add_argument('--iteration_count',
type=int,
help='The number of iterations of the model.',
default=5)
    parser.add_argument('--additional_confidence',
                        type=float,
                        help='How much probability to add to the $KEEP token.',
                        default=0)
parser.add_argument('--min_probability',
type=float,
default=0.0)
parser.add_argument('--min_error_probability',
type=float,
default=0.0)
parser.add_argument('--special_tokens_fix',
type=int,
help='Whether to fix problem with [CLS], [SEP] tokens tokenization. '
'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.',
default=1)
parser.add_argument('--is_ensemble',
type=int,
help='Whether to do ensembling.',
default=0)
parser.add_argument('--weights',
help='Used to calculate weighted average', nargs='+',
default=None)
args = parser.parse_args()

save_as_pickle("args.pkl", args)

inferer = infer_from_trained(args)
inferer.infer_from_file(input_file='./data/gec/gector/input.txt', \
output_file='./data/gec/gector/output.txt', batch_size=32)
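
For reference, a minimal usage sketch of this script (a hypothetical driver: it assumes the pretrained roberta_1_gector.th weights and the output_vocabulary/ folder have already been downloaded into ./data/gec/gector/, and the sample sentences are purely illustrative):

import subprocess

# gec.py reads ./data/gec/gector/input.txt, one sentence per line.
sentences = [
    "He go to school every days .",
    "She have three childs .",
]
with open("./data/gec/gector/input.txt", "w") as f:
    f.write("\n".join(sentences) + "\n")

# Run inference with the defaults defined above (RoBERTa, 5 correction iterations).
subprocess.run(["python", "gec.py"], check=True)

# Corrected sentences are written, one per line, to output.txt.
with open("./data/gec/gector/output.txt") as f:
    print(f.read())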
Empty file added nlptoolkit/gec/__init__.py
Empty file.
133 changes: 133 additions & 0 deletions nlptoolkit/gec/infer.py
@@ -0,0 +1,133 @@
import argparse

from .models.gector.utils.helpers import read_lines
from .models.gector.gec_model import GecBERTModel

class infer_from_trained(object):
def __init__(self, args):
self.args = args
self.model = GecBERTModel(vocab_path=args.vocab_path,
model_paths=args.model_path,
max_len=args.max_len, min_len=args.min_len,
iterations=args.iteration_count,
min_error_probability=args.min_error_probability,
                                  min_probability=args.min_probability,
lowercase_tokens=args.lowercase_tokens,
model_name=args.transformer_model,
special_tokens_fix=args.special_tokens_fix,
log=False,
confidence=args.additional_confidence,
is_ensemble=args.is_ensemble,
                                  weigths=args.weights)  # 'weigths' [sic] matches the vendored GecBERTModel's parameter name

def infer_from_file(self, input_file='./data/gec/gector/input.txt', \
output_file='./data/gec/gector/output.txt', batch_size=32):
test_data = read_lines(input_file)
predictions = []
cnt_corrections = 0
batch = []
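        # accumulate whitespace-tokenized sentences; run the model once per full batch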
for sent in test_data:
batch.append(sent.split())
if len(batch) == batch_size:
preds, cnt = self.model.handle_batch(batch)
predictions.extend(preds)
cnt_corrections += cnt
batch = []
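        # run any leftover sentences in the final, partial batch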
if batch:
preds, cnt = self.model.handle_batch(batch)
predictions.extend(preds)
cnt_corrections += cnt

with open(output_file, 'w') as f:
f.write("\n".join([" ".join(x) for x in predictions]) + '\n')
return cnt_corrections


def main(args):
    # build the model via the wrapper class above and run file-to-file inference
    inferer = infer_from_trained(args)
    cnt_corrections = inferer.infer_from_file(input_file=args.input_file,
                                              output_file=args.output_file,
                                              batch_size=args.batch_size)
    # evaluate with m2 or ERRANT
    print(f"Produced overall corrections: {cnt_corrections}")


if __name__ == '__main__':
# read parameters
parser = argparse.ArgumentParser()
parser.add_argument('--model_path',
help='Path to the model file.', nargs='+',
required=True)
    parser.add_argument('--vocab_path',
                        help='Path to the vocabulary directory.',
                        default='data/output_vocabulary'  # to use pretrained models
                        )
parser.add_argument('--input_file',
help='Path to the evalset file',
required=True)
parser.add_argument('--output_file',
help='Path to the output file',
required=True)
    parser.add_argument('--max_len',
                        type=int,
                        help='The max sentence length '
                             '(all longer will be truncated)',
                        default=50)
    parser.add_argument('--min_len',
                        type=int,
                        help='The minimum sentence length '
                             '(all shorter will be returned w/o changes)',
                        default=3)
    parser.add_argument('--batch_size',
                        type=int,
                        help='The batch size for inference.',
                        default=128)
parser.add_argument('--lowercase_tokens',
type=int,
help='Whether to lowercase tokens.',
default=0)
parser.add_argument('--transformer_model',
choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'],
help='Name of the transformer model.',
default='roberta')
parser.add_argument('--iteration_count',
type=int,
help='The number of iterations of the model.',
default=5)
    parser.add_argument('--additional_confidence',
                        type=float,
                        help='How much probability to add to the $KEEP token.',
                        default=0)
parser.add_argument('--min_probability',
type=float,
default=0.0)
parser.add_argument('--min_error_probability',
type=float,
default=0.0)
parser.add_argument('--special_tokens_fix',
type=int,
help='Whether to fix problem with [CLS], [SEP] tokens tokenization. '
'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.',
default=1)
parser.add_argument('--is_ensemble',
type=int,
help='Whether to do ensembling.',
default=0)
parser.add_argument('--weights',
help='Used to calculate weighted average', nargs='+',
default=None)
args = parser.parse_args()
main(args)
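
A minimal sketch of driving infer_from_trained programmatically rather than through either command line; the Namespace fields mirror the parser defaults above, and the model/vocabulary paths are assumed download locations:

from argparse import Namespace
from nlptoolkit.gec.infer import infer_from_trained

# Field names match what infer_from_trained.__init__ reads from args.
args = Namespace(
    model_path=["./data/gec/gector/roberta_1_gector.th"],
    vocab_path="./data/gec/gector/output_vocabulary/",
    max_len=50,
    min_len=3,
    iteration_count=5,
    min_probability=0.0,
    min_error_probability=0.0,
    lowercase_tokens=0,
    transformer_model="roberta",
    special_tokens_fix=1,  # 1 for RoBERTa, 0 for BERT/XLNet
    additional_confidence=0.0,
    is_ensemble=0,
    weights=None,
)

inferer = infer_from_trained(args)
cnt = inferer.infer_from_file(input_file="./data/gec/gector/input.txt",
                              output_file="./data/gec/gector/output.txt",
                              batch_size=32)
print(f"Produced overall corrections: {cnt}")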