From 039d343ce9a24eebe8e53aa2e14b89095f5e1922 Mon Sep 17 00:00:00 2001
From: plkmo
Date: Sat, 12 Oct 2019 22:33:30 +0800
Subject: [PATCH] updated PuncLSTM infer for punctuation

---
 nlptoolkit/punctuation_restoration/infer.py | 38 +++++++++++++++++++--
 translate.py                                |  4 +--
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/nlptoolkit/punctuation_restoration/infer.py b/nlptoolkit/punctuation_restoration/infer.py
index 1abc027..4b92eae 100644
--- a/nlptoolkit/punctuation_restoration/infer.py
+++ b/nlptoolkit/punctuation_restoration/infer.py
@@ -138,12 +138,14 @@ def infer_from_data(self,):
                 if (p == 'word') and (counter < len(src_input)):
                     punc[idx] = src_input[counter]
                     counter += 1
-                elif (p == "eos") or (counter >= len(src_input)):
+                elif (p == "eos"):
                     break
-                else:
+                elif (counter >= len(src_input)) and (p in ['word', 'sos']):
                     punc[idx] = 5
+                    break
+
             #print(punc, self.trg2_vocab.punc2idx['word'], idx)
-            print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:(idx + 1)]])))
+            print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))
             time.sleep(10)
 
 
@@ -185,6 +187,36 @@ def infer_from_input(self,):
                 print("\nFinal step translated words2: ")
                 print(" ".join(final_step_words2))
                 print()
+
+            elif self.args.model_no == 1:
+                sent = torch.nn.functional.pad(sent,[0, (self.args.max_encoder_len - sent.shape[1])], value=1)
+                trg_input = torch.tensor([self.vocab.word_vocab['__sos']]).unsqueeze(0)
+                trg2_input = torch.tensor([self.idx_mappings['sos']]).unsqueeze(0)
+                if self.cuda:
+                    sent = sent.cuda().long(); trg_input = trg_input.cuda().long()
+                    trg2_input = trg2_input.cuda().long()
+                outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
+                #print(outputs, outputs2)
+                outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
+                punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
+                print(punc)
+                punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]
+
+                sent = sent[sent != 1]
+                sent = sent.cpu().numpy().tolist() if self.cuda else sent.numpy().tolist()
+                counter = 0
+                for idx, p in enumerate(punc):
+                    if (p == 'word') and (counter < len(sent)):
+                        punc[idx] = sent[counter]
+                        counter += 1
+                    elif (p == "eos"):
+                        break
+                    elif (counter >= len(sent)) and (p in ['word', 'sos']):
+                        punc[idx] = 5
+                        break
+
+                #print(punc, self.trg2_vocab.punc2idx['word'], idx)
+                print("Predicted Label: ", " ".join(self.vocab.inverse_transform([punc[:idx]])))
 
 
 def infer(args, from_data=False):
diff --git a/translate.py b/translate.py
index e32b34e..fcaede5 100644
--- a/translate.py
+++ b/translate.py
@@ -28,12 +28,12 @@
     parser.add_argument("--max_encoder_len", type=int, default=200, help="Max src length")
     parser.add_argument("--max_decoder_len", type=int, default=200, help="Max trg length")
     parser.add_argument("--num_epochs", type=int, default=500, help="No of epochs")
-    parser.add_argument("--lr", type=float, default=0.00007, help="learning rate")
+    parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
     parser.add_argument("--gradient_acc_steps", type=int, default=2, help="Number of steps of gradient accumulation")
     parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
     parser.add_argument("--model_no", type=int, default=0, help="Model ID (0: Transformer)")
-    parser.add_argument("--train", type=int, default=0, help="Train model on dataset")
+    parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
     parser.add_argument("--evaluate", type=int, default=0, help="Evaluate the trained model on dataset")
     parser.add_argument("--infer", type=int, default=1, help="Infer input sentences")
     args = parser.parse_args()
 