Skip to content

Commit

Permalink
updated PuncLSTM infer for punctuation
Browse files Browse the repository at this point in the history
  • Loading branch information
plkmo committed Oct 12, 2019
1 parent c9e2f1c commit 039d343
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
38 changes: 35 additions & 3 deletions nlptoolkit/punctuation_restoration/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,14 @@ def infer_from_data(self,):
if (p == 'word') and (counter < len(src_input)):
punc[idx] = src_input[counter]
counter += 1
elif (p == "eos") or (counter >= len(src_input)):
elif (p == "eos"):
break
else:
elif (counter >= len(src_input)) and (p in ['word', 'sos']):
punc[idx] = 5
break

#print(punc, self.trg2_vocab.punc2idx['word'], idx)
print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:(idx + 1)]])))
print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))

time.sleep(10)

Expand Down Expand Up @@ -185,6 +187,36 @@ def infer_from_input(self,):
print("\nFinal step translated words2: ")
print(" ".join(final_step_words2))
print()

elif self.args.model_no == 1:
sent = torch.nn.functional.pad(sent,[0, (self.args.max_encoder_len - sent.shape[1])], value=1)
trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
if self.cuda:
sent = sent.cuda().long(); trg_input = trg_input.cuda().long()
trg2_input = trg2_input.cuda().long()
outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
#print(outputs, outputs2)
outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
print(punc)
punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]

sent = sent[sent != 1]
sent = sent.cpu().numpy().tolist() if self.cuda else sent.numpy().tolist()
counter = 0
for idx, p in enumerate(punc):
if (p == 'word') and (counter < len(sent)):
punc[idx] = sent[counter]
counter += 1
elif (p == "eos"):
break
elif (counter >= len(sent)) and (p in ['word', 'sos']):
punc[idx] = 5
break

#print(punc, self.trg2_vocab.punc2idx['word'], idx)
print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))

def infer(args, from_data=False):

Expand Down
4 changes: 2 additions & 2 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
parser.add_argument("--max_encoder_len", type=int, default=200, help="Max src length")
parser.add_argument("--max_decoder_len", type=int, default=200, help="Max trg length")
parser.add_argument("--num_epochs", type=int, default=500, help="No of epochs")
parser.add_argument("--lr", type=float, default=0.00007, help="learning rate")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--gradient_acc_steps", type=int, default=2, help="Number of steps of gradient accumulation")
parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
parser.add_argument("--model_no", type=int, default=0, help="Model ID (0: Transformer)")

parser.add_argument("--train", type=int, default=0, help="Train model on dataset")
parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
parser.add_argument("--evaluate", type=int, default=0, help="Evaluate the trained model on dataset")
parser.add_argument("--infer", type=int, default=1, help="Infer input sentences")
args = parser.parse_args()
Expand Down

0 comments on commit 039d343

Please sign in to comment.