
Commit

updated puncLSTM infer_from_file, to add punc heuristic module and puncTransformer
plkmo committed Oct 12, 2019
1 parent 039d343 commit d67d03f
Showing 1 changed file with 79 additions and 62 deletions.
141 changes: 79 additions & 62 deletions nlptoolkit/punctuation_restoration/infer.py
@@ -4,15 +4,18 @@
@author: tsd
"""
import pandas as pd
import torch
from .preprocessing_funcs import load_dataloaders
from .train_funcs import load_model_and_optimizer
from .utils.bpe_vocab import Encoder
from .utils.misc import save_as_pickle, load_pickle
from .utils.word_char_level_vocab import tokener
from tqdm import tqdm
import time
import logging

tqdm.pandas(desc="prog-bar")
logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('__file__')
@@ -72,6 +75,7 @@ def __init__(self, args=None):
idx_mappings=self.idx_mappings,\
cuda=self.cuda)
self.net = net
self.net.eval()

def infer_from_data(self,):
_, train_loader, train_length, max_features_length, max_output_len = load_dataloaders(self.args)
@@ -144,10 +148,75 @@ def infer_from_data(self,):
punc[idx] = 5
break

#print(punc, self.trg2_vocab.punc2idx['word'], idx)
print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))

time.sleep(10)

def infer_sentence(self, sentence):
sent = torch.tensor(next(self.vocab.transform([sentence]))).unsqueeze(0)
if self.args.model_no == 0:
trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
src_mask, trg_mask = self.create_masks(sent, trg_input)
trg2_mask = self.create_trg_mask(trg2_input, ignore_idx=self.idx_mappings['pad'])
if self.cuda:
sent = sent.cuda().long(); trg_input = trg_input.cuda().long(); trg2_input = trg2_input.cuda().long()
src_mask = src_mask.cuda(); trg_mask = trg_mask.cuda(); trg2_mask = trg2_mask.cuda()
stepwise_translated_words, final_step_words, stepwise_translated_words2, final_step_words2 = self.net(sent, \
trg_input, \
trg2_input,\
src_mask, \
trg_mask, \
trg2_mask, \
True, \
self.vocab, \
self.trg2_vocab)

step_ = " ".join(stepwise_translated_words)
final_2 = " ".join(final_step_words2)
print("\nStepwise-translated:")
print(step_)
print()
print("\nFinal step translated words: ")
print(" ".join(final_step_words))
print()
print("\nStepwise-translated2:")
print(" ".join(stepwise_translated_words2))
print()
print("\nFinal step translated words2: ")
print(final_2)
print()
predicted = step_

elif self.args.model_no == 1:
sent = torch.nn.functional.pad(sent,[0, (self.args.max_encoder_len - sent.shape[1])], value=1)
trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
if self.cuda:
sent = sent.cuda().long(); trg_input = trg_input.cuda().long()
trg2_input = trg2_input.cuda().long()
outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.numpy().tolist()
punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
print(punc)
punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]

sent = sent[sent != 1]
sent = sent.cpu().numpy().tolist() if self.cuda else sent.numpy().tolist()
counter = 0
for idx, p in enumerate(punc):
if (p == 'word') and (counter < len(sent)):
punc[idx] = sent[counter]
counter += 1
elif (p == "eos"):
break
elif (counter >= len(sent)) and (p in ['word', 'sos']):
punc[idx] = 5
break

predicted = " ".join(self.vocab.inverse_transform([punc[:idx]]))
print("Predicted Label: ", predicted)
return predicted

def infer_from_input(self,):
self.net.eval()
@@ -156,67 +225,15 @@ def infer_from_input(self,):
sent = input("Input sentence to punctuate:\n")
if sent in ["quit", "exit"]:
break
sent = torch.tensor(next(self.vocab.transform([sent]))).unsqueeze(0)
if self.args.model_no == 0:
trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
src_mask, trg_mask = self.create_masks(sent, trg_input)
trg2_mask = self.create_trg_mask(trg2_input, ignore_idx=self.idx_mappings['pad'])
if self.cuda:
sent = sent.cuda().long(); trg_input = trg_input.cuda().long(); trg2_input = trg2_input.cuda().long()
src_mask = src_mask.cuda(); trg_mask = trg_mask.cuda(); trg2_mask = trg2_mask.cuda()
stepwise_translated_words, final_step_words, stepwise_translated_words2, final_step_words2 = self.net(sent, \
trg_input, \
trg2_input,\
src_mask, \
trg_mask, \
trg2_mask, \
True, \
self.vocab, \
self.trg2_vocab)

print("\nStepwise-translated:")
print(" ".join(stepwise_translated_words))
print()
print("\nFinal step translated words: ")
print(" ".join(final_step_words))
print()
print("\nStepwise-translated2:")
print(" ".join(stepwise_translated_words2))
print()
print("\nFinal step translated words2: ")
print(" ".join(final_step_words2))
print()

elif self.args.model_no == 1:
sent = torch.nn.functional.pad(sent,[0, (self.args.max_encoder_len - sent.shape[1])], value=1)
trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
if self.cuda:
sent = sent.cuda().long(); trg_input = trg_input.cuda().long()
trg2_input = trg2_input.cuda().long()
outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
#print(outputs, outputs2)
outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
print(punc)
punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]

sent = sent[sent != 1]
sent = sent.cpu().numpy().tolist() if self.cuda else sent.numpy().tolist()
counter = 0
for idx, p in enumerate(punc):
if (p == 'word') and (counter < len(sent)):
punc[idx] = sent[counter]
counter += 1
elif (p == "eos"):
break
elif (counter >= len(sent)) and (p in ['word', 'sos']):
punc[idx] = 5
break

#print(punc, self.trg2_vocab.punc2idx['word'], idx)
print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))
predicted = self.infer_sentence(sent)
return predicted

def infer_from_file(self, in_file="./data/input.txt", out_file="./data/output.txt"):
df = pd.read_csv(in_file, header=None, names=["sents"])
df['labels'] = df.progress_apply(lambda x: self.infer_sentence(x['sents']), axis=1)
df.to_csv(out_file, index=False)
logger.info("Done and saved as %s!" % out_file)
return

def infer(args, from_data=False):

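The splice loop in infer_sentence above is the punctuation heuristic named in the commit title: the second decoder emits punctuation labels in which 'word' marks a slot for the next source token and 'eos' terminates the sequence. Below is a standalone sketch of that logic with made-up label ids, token ids, and idx2punc mapping; the remapping of punctuation symbols through self.mappings is omitted here.

idx2punc = {1: "word", 2: ",", 3: ".", 4: "eos"}      # hypothetical label mapping
outputs2 = [[1, 1, 2, 1, 3, 4]]                       # hypothetical decoder output (label ids)
sent = [101, 102, 103]                                # hypothetical source token ids (pad id 1 already stripped)

punc = [idx2punc[i] for i in outputs2[0]]
counter = 0
for idx, p in enumerate(punc):
    if (p == 'word') and (counter < len(sent)):
        punc[idx] = sent[counter]                     # fill the placeholder with the next source token id
        counter += 1
    elif p == "eos":
        break                                         # stop at the end-of-sequence label
    elif (counter >= len(sent)) and (p in ['word', 'sos']):
        punc[idx] = 5                                 # fallback id once the source tokens are exhausted
        break

print(punc[:idx])                                     # [101, 102, ',', 103, '.'], later decoded via vocab.inverse_transform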

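A minimal sketch of how the new infer_from_file entry point might be driven. The class name infer_from_trained and its no-argument construction are assumptions standing in for however the inferer object in this file is actually built; only the in_file/out_file behaviour shown comes from the diff.

from nlptoolkit.punctuation_restoration.infer import infer_from_trained  # assumed class name

# infer_from_file reads one unpunctuated sentence per line
# (pd.read_csv(in_file, header=None, names=["sents"])).
with open("./data/input.txt", "w") as f:
    f.write("hello how are you today\n")
    f.write("nice to meet you my name is john\n")

inferer = infer_from_trained()                        # assumes a trained model is restored internally
inferer.infer_from_file(in_file="./data/input.txt",
                        out_file="./data/output.txt")
# output.txt is a CSV with the original "sents" column plus a "labels" column
# holding the punctuated predictions from infer_sentence.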