Skip to content

Commit

Permalink
updated translate BLEU
Browse files Browse the repository at this point in the history
  • Loading branch information
plkmo committed Sep 23, 2019
1 parent 6f6ef8b commit 633ee5e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Tasks:
[References](#references)

## Pre-requisites
torch==1.2.0 ; spacy==2.1.8 ; seqeval==0.0.12
torch==1.2.0 ; spacy==2.1.8 ; torchtext==0.4.0 ; seqeval==0.0.12

** Pre-trained models (XLNet, BERT, GPT-2) are courtesy of huggingface (https://github.com/huggingface/pytorch-transformers)

Expand Down
14 changes: 10 additions & 4 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@author: WT
"""
from translation.trainer import train_and_fit
from translation.evaluate import infer
from translation.evaluate import infer, evaluate_corpus_bleu
from utils.misc import save_as_pickle
from argparse import ArgumentParser
import logging
Expand All @@ -30,11 +30,17 @@
parser.add_argument("--gradient_acc_steps", type=int, default=1, help="Number of steps of gradient accumulation")
parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
parser.add_argument("--model_no", type=int, default=0, help="Model ID (0: Transformer)")

parser.add_argument("--evaluate_only", type=int, default=0, help="Only evaluate the trained model on dataset")
parser.add_argument("--infer_only", type=int, default=0, help="Only infer input sentences")
args = parser.parse_args()

save_as_pickle("args.pkl", args)

'''PyTorch's transformer module runs much slower'''
train_and_fit(args, pytransformer=False)

#infer(args, True)
if (not args.evaluate_only) and (not args.infer_only):
train_and_fit(args, pytransformer=False)
elif args.evaluate_only:
evaluate_corpus_bleu(args)
elif args.infer_only:
infer(args, True)
26 changes: 19 additions & 7 deletions translation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@
from .models.Transformer.Transformer import create_masks
from .train_funcs import load_model_and_optimizer
from .preprocessing_funcs import tokener, load_dataloaders
from tqdm import tqdm
import time
import logging

logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('__file__')

def dum_tokenizer(sent):
    """Trivial whitespace tokenizer: split *sent* on runs of whitespace.

    Returns a list of tokens; an empty/whitespace-only string yields [].
    """
    tokens = sent.split()
    return tokens

def calculate_bleu(src, trg, weights=(0.25, 0.25, 0.25, 0.25), corpus_level=False):
def calculate_bleu(src, trg, corpus_level=False, weights=(0.25, 0.25, 0.25, 0.25)):
    """Compute a BLEU score with NLTK.

    src = [[sent words1], [sent words2], ...], trg = [sent words]
    When corpus_level is True, src/trg are instead the corpus-shaped
    arguments expected by bleu_score.corpus_bleu.
    weights are the n-gram weights passed straight through to NLTK.
    """
    if corpus_level:
        return bleu_score.corpus_bleu(src, trg, weights=weights)
    return bleu_score.sentence_bleu(src, trg, weights=weights)

def evaluate_bleu(args):
def evaluate_corpus_bleu(args, early_stopping=True, stop_no=1000):
args.batch_size = 1
#tokenizer_en = tokener("en")
train_iter, FR, EN, train_length = load_dataloaders(args)
Expand All @@ -38,9 +44,10 @@ def evaluate_bleu(args):
trg_init = FR.vocab.stoi["<sos>"]
trg_init = Variable(torch.LongTensor([trg_init])).unsqueeze(0)

logger.info("Evaluating corpus bleu...")
refs = []; hyps = []
with torch.no_grad():
for i, data in enumerate(train_iter):
for i, data in tqdm(enumerate(train_iter), total=len(train_iter)):
trg_input = trg_init
labels = data.FR[:,1:].contiguous().view(-1)
src_mask, trg_mask = create_masks(data.EN, trg_input)
Expand All @@ -50,9 +57,13 @@ def evaluate_bleu(args):
stepwise_translated_words, final_step_words = net(data.EN, trg_input, src_mask, None,\
infer=True, trg_vocab_obj=FR)
refs.append([stepwise_translated_words]) # need to remove <eos> tokens
hyps.append([FR.vocab.itos[i] for i in labels])

return
hyps.append([FR.vocab.itos[i] for i in labels[:-1]])
if early_stopping and ((i + 1) % stop_no == 0):
print(refs); print(hyps)
break
score = calculate_bleu(refs, hyps, corpus_level=True)
print("Corpus bleu score: %.5f" % score)
return score

def infer(args, from_data=False):
args.batch_size = 1
Expand Down Expand Up @@ -100,6 +111,7 @@ def infer(args, from_data=False):
stepwise_translated_words, final_step_words = net(data.EN, trg_input, src_mask, None,\
infer=True, trg_vocab_obj=FR)
score = calculate_bleu([stepwise_translated_words], [FR.vocab.itos[i] for i in labels])
print([stepwise_translated_words]); print([FR.vocab.itos[i] for i in labels])
print("\n\nInput:")
print(" ".join(EN.vocab.itos[i] for i in data.EN[0]))
print("\nStepwise-translated:")
Expand All @@ -108,7 +120,7 @@ def infer(args, from_data=False):
print(" ".join(final_step_words))
print("\nGround Truth:")
print(" ".join(FR.vocab.itos[i] for i in labels))
print("Bleu score (stepwise-translated sentence level): %.3f" % score)
print("Bleu score (stepwise-translated sentence level): %.5f" % score)
time.sleep(7)

else:
Expand Down
15 changes: 14 additions & 1 deletion translation/preprocessing_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
@author: WT
"""

import sys
import csv
import pandas as pd
import os
import re
Expand All @@ -13,6 +14,18 @@
import spacy
import logging

# The csv module rejects very long fields by default (131072 chars); raise
# the limit as high as the platform allows so large dataset rows parse.
maxInt = sys.maxsize

while True:
    # csv.field_size_limit raises OverflowError when the value does not fit
    # in a C long (e.g. sys.maxsize on some platforms/builds), so back off
    # by factors of 10 until a value is accepted.
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        # Use integer floor division: int(maxInt/10) routes the value
        # through float division and loses precision for ints > 2**53.
        maxInt //= 10

logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger(__file__)
Expand Down

0 comments on commit 633ee5e

Please sign in to comment.