From 35075f9a92cbf4171734b986bc727038768f34bf Mon Sep 17 00:00:00 2001
From: plkmo
Date: Mon, 14 Oct 2019 11:23:09 +0800
Subject: [PATCH] added corrector_module for punc, included some pre-trained models

---
 README.md                                   | 73 +++++++++++++++++---
 classify.py                                 |  4 +-
 nlptoolkit/punctuation_restoration/infer.py | 74 +++++++++++++++------
 punctuate.py                                |  6 +-
 translate.py                                |  2 +-
 5 files changed, 122 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index cb3050f..c6e03b2 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,8 @@ The training data (default: train.csv) should be formatted into two columns 'tex
 The infer data (default: infer.csv) should be formatted into at least one column 'text' being the raw text and rows being the documents index. Optional column 'label' can be added and --train_test_split argument set to 1 to use infer.csv as the test set for model verification.
 
+The IMDB dataset for sentiment classification is available [here](https://drive.google.com/drive/folders/1a4tw3UsbwQViIgw08kwWn0jvtLOSnKZb?usp=sharing).
+
 ### Running the model
 Run classify.py with arguments below.
@@ -86,16 +88,35 @@ from nlptoolkit.classification.models.infer import infer_from_trained
 
 config = Config(task='classification') # loads default argument parameters as above
 config.train_data = './data/train.csv' # sets training data path
 config.infer_data = './data/infer.csv' # sets infer data path
-config.num_classes = 5 # sets number of prediction classes
+config.num_classes = 2 # sets number of prediction classes
 config.batch_size = 32
 config.model_no = 1 # sets BERT model
 config.lr = 0.001 # change learning rate
 train_and_fit(config) # starts training with configured parameters
 inferer = infer_from_trained(config) # initiate infer object, which loads the model for inference, after training model
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
 inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
 ```
+```python
+inferer.infer_from_input()
+```
+Sample output:
+```bash
+Type input sentence (Type 'exit' or 'quit' to quit):
+This is a good movie.
+Predicted class: 1
+
+Type input sentence (Type 'exit' or 'quit' to quit):
+This is a bad movie.
+Predicted class: 0
+
+```
+
+### Pre-trained models
+Download and unzip the contents of the downloaded folder into the ./data/ folder.
+1. [BERT for IMDB sentiment analysis](https://drive.google.com/drive/folders/1JHOabZE4U4sfcnttIQsHcu9XNEEJwk0X?usp=sharing) (includes preprocessed data, vocab, and saved results files)
+2. [XLNet for IMDB sentiment analysis](https://drive.google.com/drive/folders/1lk0N6DdgeEoVhoaCrC0vysL7GBJe9sAX?usp=sharing) (includes preprocessed data, vocab, and saved results files)
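+
+To run inference directly from a downloaded pre-trained model, a minimal sketch (this assumes the downloaded contents have been placed in the ./data/ folder, so that infer_from_trained can load the saved weights without calling train_and_fit first):
+```python
+from nlptoolkit.utils.config import Config
+from nlptoolkit.classification.models.infer import infer_from_trained
+
+config = Config(task='classification') # loads default argument parameters as above
+config.num_classes = 2 # IMDB sentiment is binary
+config.model_no = 1 # 1: BERT, 2: XLNet
+inferer = infer_from_trained(config) # loads the saved model from ./data/ for inference
+inferer.infer_from_input() # infer from user console input
+```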
 
 ---
 ## 2) Automatic Speech Recognition
@@ -227,7 +248,7 @@ config.batch_size = 16
 config.lr = 0.0001 # change learning rate
 train_and_fit(config) # starts training with configured parameters
 inferer = infer_from_trained(config) # initiate infer object, which loads the model for inference, after training model
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
 inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
 ```
@@ -257,12 +278,14 @@ output = infer_from_pretrained(input_sent=None, tokens_len=100, top_k_beam=1)
 ## 6) Punctuation Restoration
 Given unpunctuated (and perhaps un-capitalized) text, punctuation restoration aims to restore the punctuation of the text for easier readability.
 Applications include punctuating raw transcripts from audio speech data etc. Currently supports the following models:
-1. Transformer
-2. Bi-LSTM with attention
+1. Transformer (PuncTransformer)
+2. Bi-LSTM with attention (PuncLSTM)
 
 ### Format of dataset files
 Currently only supports TED talk transcripts format, whereby punctuated text is annotated by \<transcript\> tags. Eg. \<transcript\> "punctuated text" \</transcript\>. The "punctuated text" is preprocessed and then used for training.
 
+The TED talks dataset can be downloaded [here](https://drive.google.com/file/d/1fJpl-fF5bcAKbtZbTygipUSZYyJdYU11/view?usp=sharing).
+
 ### Running the model
 Run punctuate.py
 
@@ -293,6 +316,40 @@ punctuate.py [-h]
 ```
 
+Or, if used as a package,
+```python
+from nlptoolkit.utils.config import Config
+from nlptoolkit.punctuation_restoration.trainer import train_and_fit
+from nlptoolkit.punctuation_restoration.infer import infer_from_trained
+
+config = Config(task='punctuation_restoration') # loads default argument parameters as above
+config.data_path = "./data/train.tags.en-fr.en" # sets training data path
+config.batch_size = 32
+config.lr = 5e-5 # change learning rate
+config.model_no = 1 # sets model to PuncLSTM
+train_and_fit(config) # starts training with configured parameters
+inferer = infer_from_trained(config) # initiate infer object, which loads the trained model for inference
+inferer.infer_from_input() # infer from user console input
+inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt") # infer from input file
+```
+
+```python
+inferer.infer_from_input()
+```
+Sample output:
+```bash
+Input sentence to punctuate:
+hi how are you
+Predicted Label: Hi. How are you?
+
+Input sentence to punctuate:
+this is good thank you very much
+Predicted Label: This is good. Thank you very much.
+```
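+
+File-based inference follows the same pattern. A minimal sketch (the file contents shown in the comments are hypothetical, and assume one raw, unpunctuated sentence per line of in_file):
+```python
+# ./data/input.txt (hypothetical contents):
+#   hi how are you
+#   this is good thank you very much
+inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
+# ./data/output.txt should then contain the punctuated text:
+#   Hi. How are you?
+#   This is good. Thank you very much.
+```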
+
+### Pre-trained models
+Download and unzip the contents of the downloaded folder into the ./data/ folder.
+1. [PuncLSTM](https://drive.google.com/drive/folders/1ftDQYj3wv0t9MVtAVod5RIDMrY-NhZ82?usp=sharing) (includes preprocessed data, vocab, and saved results files)
 
 ---
 
 ## 7) Named Entity Recognition
@@ -339,7 +396,7 @@ config.lr = 5e-5 # change learning rate
 config.model_no = 0 # sets model to BERT
 train_and_fit(config) # starts training with configured parameters
 inferer = infer_from_trained(config) # initiate infer object, which loads the model for inference, after training model
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
 inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
 ```
@@ -383,8 +440,8 @@ inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt"
 
 # To do list
 In order of priority:
-- [ ] Include package usage info for ~~classification~~, ASR, summarization, ~~translation~~, ~~generation~~, punctuation_restoration, ~~NER~~, POS
+- [ ] Include package usage info for ~~classification~~, ASR, summarization, ~~translation~~, ~~generation~~, ~~punctuation_restoration~~, ~~NER~~, POS
 - [ ] Include benchmark results for ~~classification~~, ASR, summarization, translation, generation, punctuation_restoration, ~~NER~~, POS
-- [ ] Include pre-trained models + demo based on benchmark datasets for classification, ASR, summarization, translation, generation, punctuation_restoration, NER, POS
+- [ ] Include pre-trained models + demo based on benchmark datasets for ~~classification~~, ASR, summarization, translation, generation, punctuation_restoration, NER, POS
 - [ ] Include more models for punctuation restoration, translation, NER
 
diff --git a/classify.py b/classify.py
index d46c06b..b4fac4b 100644
--- a/classify.py
+++ b/classify.py
@@ -35,7 +35,7 @@
     parser.add_argument("--hidden_size_1", type=int, default=330, help="Size of first GCN hidden weights")
     parser.add_argument("--hidden_size_2", type=int, default=130, help="Size of second GCN hidden weights")
     parser.add_argument("--tokens_length", type=int, default=200, help="Max tokens length for BERT")
-    parser.add_argument("--num_classes", type=int, default=5, help="Number of prediction classes (starts from integer 0)")
+    parser.add_argument("--num_classes", type=int, default=2, help="Number of prediction classes (starts from integer 0)")
     parser.add_argument("--train_test_split", type=int, default=0, help="0: No, 1: Yes (Only activate if infer.csv contains labelled data)")
     parser.add_argument("--test_ratio", type=float, default=0.1, help="GCN: Ratio of test to training nodes")
     parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
@@ -43,7 +43,7 @@
     parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
     parser.add_argument("--num_epochs", type=int, default=40, help="No of epochs")
     parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
-    parser.add_argument("--model_no", type=int, default=2, help="Model ID: (0: Graph Convolution Network (GCN), 1: BERT, 2: XLNet)")
+    parser.add_argument("--model_no", type=int, default=1, help="Model ID: (0: Graph Convolution Network (GCN), 1: BERT, 2: XLNet)")
     parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
     parser.add_argument("--infer", type=int, default=0, help="Infer input sentence labels from trained model")
 
diff --git a/nlptoolkit/punctuation_restoration/infer.py b/nlptoolkit/punctuation_restoration/infer.py
index 24bf1eb..7abb1f6 100644
--- a/nlptoolkit/punctuation_restoration/infer.py
+++ b/nlptoolkit/punctuation_restoration/infer.py
@@ -4,6 +4,7 @@
 
 @author: tsd
 """
+import re
 import pandas as pd
 import torch
 from .preprocessing_funcs import load_dataloaders
@@ -32,6 +33,32 @@ def __init__(self, idx_mappings, mappings):
         self.punc2idx = map2
         self.idx2punc = {v:k for k,v in map2.items()}
 
+def find(s, ch=(".", "!", "?")):
+    return [i for i, ltr in enumerate(s) if ltr in ch] # indices of sentence-ending punctuation
+
+def corrector_module(corrected, sent=None, cap_abbrev=True):
+    corrected = corrected[0].upper() + corrected[1:] # capitalize first letter
+    corrected = re.sub(r" +([\.\?!,])", r"\1", corrected) # remove spaces before punctuation
+    corrected = re.sub("' +", "'", corrected) # remove extra spaces after apostrophes
+    idxs = find(corrected)
+    for idx in idxs: # capitalize the letter following end-of-sentence punctuation
+        if (idx + 3) < len(corrected):
+            corrected = corrected[:idx + 2] + corrected[(idx + 2)].upper() + corrected[(idx + 3):]
+
+    if cap_abbrev: # upper-case known abbreviations
+        abbrevs = ["ntu", "nus", "smrt", "sutd", "sim", "smu", "i2r", "astar", "imda", "hdb", "edb", "lta", "cna",\
+                   "suss"]
+        corrected = corrected.split()
+        corrected1 = []
+        for word in corrected:
+            if word.lower().strip("!?.,") in abbrevs:
+                corrected1.append(word.upper())
+            else:
+                corrected1.append(word)
+        assert len(corrected) == len(corrected1)
+        corrected = " ".join(corrected1)
+    return corrected
+
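+# Example of the intended behaviour (illustrative):
+#   corrector_module("hello world . how are you ? i am from ntu")
+#   returns 'Hello world. How are you? I am from NTU'
+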
 class infer_from_trained(object):
     def __init__(self, args=None):
@@ -85,16 +112,12 @@ def infer_from_data(self,):
         if self.args.model_no == 0:
             src_input, trg_input, trg2_input = data[0], data[1][:, :-1], data[2][:, :-1]
-            #labels = data[1][:,1:].contiguous().view(-1)
-            #labels2 = data[2][:,1:].contiguous().view(-1)
             src_mask, trg_mask = self.create_masks(src_input, trg_input)
             trg2_mask = self.create_trg_mask(trg2_input, ignore_idx=self.idx_mappings['pad'])
             if self.cuda:
-                src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long(); #labels = labels.cuda().long()
+                src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long();
                 src_mask = src_mask.cuda(); trg_mask = trg_mask.cuda(); trg2_mask = trg2_mask.cuda()
-                trg2_input = trg2_input.cuda().long(); #labels2 = labels2.cuda().long()
-            # self, src, trg, trg2, src_mask, trg_mask=None, trg_mask2=None, infer=False, trg_vocab_obj=None, \
-            #trg2_vocab_obj=None
+                trg2_input = trg2_input.cuda().long();
             stepwise_translated_words, final_step_words, stepwise_translated_words2, final_step_words2 = self.net(src_input, \
                                                                                                                   trg_input[:,0].unsqueeze(0), \
                                                                                                                   trg2_input[:,0].unsqueeze(0),\
@@ -126,7 +149,6 @@
                 src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long(); labels = labels.cuda().long()
                 trg2_input = trg2_input.cuda().long(); labels2 = labels2.cuda().long()
             outputs, outputs2 = self.net(src_input, trg_input, trg2_input, infer=True)
-            #print(outputs, outputs2)
             outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
             punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
             print(punc)
@@ -173,20 +195,27 @@ def infer_sentence(self, sentence):
                                                                              self.trg2_vocab)
 
             step_ = " ".join(stepwise_translated_words)
-            final_2 = " ".join(final_step_words2)
-            print("\nStepwise-translated:")
-            print(step_)
-            print()
-            print("\nFinal step translated words: ")
-            print(" ".join(final_step_words))
-            print()
-            print("\nStepwise-translated2:")
-            print(" ".join(stepwise_translated_words2))
-            print()
-            print("\nFinal step translated words2: ")
-            print(final_2)
-            print()
-            predicted = step_
+            final_ = " ".join(final_step_words)
+            if not (step_ == '') and not (final_ == ''): # both must be non-empty; corrector_module fails on an empty string
+                step_ = corrector_module(step_)
+                final_ = corrector_module(final_)
+                final_2 = " ".join(final_step_words2)
+                print("\nStepwise-translated:")
+                print(step_)
+                print()
+                print("\nFinal step translated words: ")
+                print(final_)
+                print()
+                print("\nStepwise-translated2:")
+                print(" ".join(stepwise_translated_words2))
+                print()
+                print("\nFinal step translated words2: ")
+                print(final_2)
+                print()
+                predicted = step_
+            else:
+                print("None, please try another sentence.")
+                predicted = "None"
 
         elif self.args.model_no == 1:
             sent = torch.nn.functional.pad(sent,[0, (self.args.max_encoder_len - sent.shape[1])], value=1)
@@ -198,7 +227,7 @@
             outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
             outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
             punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
-            print(punc)
+            #print(punc)
             punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]
 
             sent = sent[sent != 1]
@@ -215,6 +244,7 @@
                     break
 
         predicted = " ".join(self.vocab.inverse_transform([punc[:idx]]))
+        predicted = corrector_module(predicted)
         print("Predicted Label: ", predicted)
         return predicted
 
diff --git a/punctuate.py b/punctuate.py
index 777a4fc..f2fd672 100644
--- a/punctuate.py
+++ b/punctuate.py
@@ -4,7 +4,6 @@
 
 @author: tsd
 """
-from nlptoolkit.punctuation_restoration.preprocessing_funcs import load_dataloaders
 from nlptoolkit.punctuation_restoration.trainer import train_and_fit
 from nlptoolkit.punctuation_restoration.infer import infer_from_trained
 from nlptoolkit.utils.misc import save_as_pickle
@@ -44,9 +43,8 @@
     args = parser.parse_args()
     save_as_pickle("args.pkl", args)
-    #df, train_loader, train_length, max_features_len, max_output_len = load_dataloaders(args)
 
     if args.train:
        train_and_fit(args)
     if args.infer:
-        infer_ = infer_from_trained()
-        infer_.infer_from_data()
+        inferer = infer_from_trained()
+        inferer.infer_from_data()
 
diff --git a/translate.py b/translate.py
index fcaede5..a0bdf05 100644
--- a/translate.py
+++ b/translate.py
@@ -35,7 +35,7 @@
     parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
     parser.add_argument("--evaluate", type=int, default=0, help="Evaluate the trained model on dataset")
-    parser.add_argument("--infer", type=int, default=1, help="Infer input sentences")
+    parser.add_argument("--infer", type=int, default=0, help="Infer input sentences")
 
     args = parser.parse_args()
     save_as_pickle("args.pkl", args)