added corrector_module for punc, included some pre-trained models
plkmo committed Oct 14, 2019
1 parent d67d03f commit 35075f9
Showing 5 changed files with 122 additions and 37 deletions.
73 changes: 65 additions & 8 deletions README.md
@@ -52,6 +52,8 @@ The training data (default: train.csv) should be formatted into two columns 'text' and 'label'

The infer data (default: infer.csv) should contain at least one column 'text' holding the raw text, with one document per row. An optional 'label' column can be added, with the --train_test_split argument set to 1, to use infer.csv as the test set for model verification.

IMDB datasets for sentiment classification are available [here](https://drive.google.com/drive/folders/1a4tw3UsbwQViIgw08kwWn0jvtLOSnKZb?usp=sharing).
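
For illustration, a train.csv in this format might contain rows like the following (hypothetical examples; labels are integers starting from 0):

```bash
text,label
"A surprisingly heartfelt film with a strong cast.",1
"Two hours of my life I will never get back.",0
```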

### Running the model
Run classify.py with arguments below.

@@ -86,16 +88,35 @@ from nlptoolkit.classification.models.infer import infer_from_trained
config = Config(task='classification') # loads default argument parameters as above
config.train_data = './data/train.csv' # sets training data path
config.infer_data = './data/infer.csv' # sets infer data path
-config.num_classes = 5 # sets number of prediction classes
+config.num_classes = 2 # sets number of prediction classes
config.batch_size = 32
config.model_no = 1 # sets BERT model
config.lr = 0.001 # change learning rate
train_and_fit(config) # starts training with configured parameters
inferer = infer_from_trained(config) # initialize the infer object, which loads the trained model for inference
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
```

```python
inferer.infer_from_input()
```
Sample output:
```bash
Type input sentence (Type 'exit' or 'quit' to quit):
This is a good movie.
Predicted class: 1

Type input sentence (Type 'exit' or 'quit' to quit):
This is a bad movie.
Predicted class: 0

```
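
infer_from_file runs the same inference over a file. A hypothetical input/output pair, assuming one raw-text document per line in the input and one predicted class per document in the output (the exact file layout is defined by the package, so treat this as a sketch):

```bash
# ./data/input.txt (assumed: one raw-text document per line)
This is a good movie.
This is a bad movie.

# ./data/output.txt (assumed: one predicted class per document)
1
0
```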

### Pre-trained models
Download and unzip the contents of the downloaded folder into the ./data/ folder.
1. [BERT for IMDB sentiment analysis](https://drive.google.com/drive/folders/1JHOabZE4U4sfcnttIQsHcu9XNEEJwk0X?usp=sharing) (includes preprocessed data, vocab, and saved results files)
2. [XLNet for IMDB sentiment analysis](https://drive.google.com/drive/folders/1lk0N6DdgeEoVhoaCrC0vysL7GBJe9sAX?usp=sharing) (includes preprocessed data, vocab, and saved results files)
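
With these files extracted into ./data/, a minimal sketch of inference without retraining (assumed: the same Config usage as above, with model_no = 1 selecting the BERT checkpoint and model_no = 2 the XLNet one):

```python
from nlptoolkit.utils.config import Config
from nlptoolkit.classification.models.infer import infer_from_trained

config = Config(task='classification') # loads default argument parameters as above
config.num_classes = 2 # IMDB sentiment: positive/negative
config.model_no = 1 # sets BERT model (2 for XLNet)
inferer = infer_from_trained(config) # loads the trained model from ./data/
inferer.infer_from_input() # infer from user console input
```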
---

## 2) Automatic Speech Recognition
@@ -227,7 +248,7 @@ config.batch_size = 16
config.lr = 0.0001 # change learning rate
train_and_fit(config) # starts training with configured parameters
inferer = infer_from_trained(config) # initialize the infer object, which loads the trained model for inference
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
```
@@ -257,12 +278,14 @@ output = infer_from_pretrained(input_sent=None, tokens_len=100, top_k_beam=1)
## 6) Punctuation Restoration
Given unpunctuated (and perhaps un-capitalized) text, punctuation restoration aims to restore the punctuation of the text for easier readability. Applications include punctuating raw transcripts produced from audio speech data. Currently supports the following models:
-1. Transformer
-2. Bi-LSTM with attention
+1. Transformer (PuncTransformer)
+2. Bi-LSTM with attention (PuncLSTM)
### Format of dataset files
Currently only supports the TED talk transcript format, whereby punctuated text is annotated by \<transcript\> tags, e.g. \<transcript\> "punctuated text" \</transcript\>. The "punctuated text" is preprocessed and then used for training.
TED talks dataset can be downloaded [here.](https://drive.google.com/file/d/1fJpl-fF5bcAKbtZbTygipUSZYyJdYU11/view?usp=sharing)
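
For illustration, a data file in this format would contain entries like (hypothetical content):

```bash
<transcript> So, how are you? This is a good talk. </transcript>
<transcript> Thank you very much. It has been a pleasure. </transcript>
```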
### Running the model
Run punctuate.py with arguments below.
@@ -293,6 +316,40 @@ punctuate.py [-h]
```
Or, if used as a package,
```python
from nlptoolkit.utils.config import Config
from nlptoolkit.punctuation_restoration.trainer import train_and_fit
from nlptoolkit.punctuation_restoration.infer import infer_from_trained
config = Config(task='punctuation_restoration') # loads default argument parameters as above
config.data_path = "./data/train.tags.en-fr.en" # sets training data path
config.batch_size = 32
config.lr = 5e-5 # change learning rate
config.model_no = 1 # sets model to PuncLSTM
train_and_fit(config) # starts training with configured parameters
inferer = infer_from_trained(config) # initialize the infer object, which loads the trained model for inference
inferer.infer_from_input() # infer from user console input
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt") # infer from input file
```
```python
inferer.infer_from_input()
```
Sample output:
```bash
Input sentence to punctuate:
hi how are you
Predicted Label: Hi. How are you?
Input sentence to punctuate:
this is good thank you very much
Predicted Label: This is good. Thank you very much.
```
### Pre-trained models
Download and unzip the contents of the downloaded folder into the ./data/ folder.
1. [PuncLSTM](https://drive.google.com/drive/folders/1ftDQYj3wv0t9MVtAVod5RIDMrY-NhZ82?usp=sharing) (includes preprocessed data, vocab, and saved results files)
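
As with classification, a minimal sketch for running the PuncLSTM checkpoint without retraining (assumed: the extracted files sit in ./data/ and model_no = 1 selects PuncLSTM, as in the usage above):

```python
from nlptoolkit.utils.config import Config
from nlptoolkit.punctuation_restoration.infer import infer_from_trained

config = Config(task='punctuation_restoration') # loads default argument parameters as above
config.model_no = 1 # sets model to PuncLSTM
inferer = infer_from_trained(config) # loads the trained model from ./data/
inferer.infer_from_input() # punctuate sentences typed into the console
```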
---
## 7) Named Entity Recognition
@@ -339,7 +396,7 @@ config.lr = 5e-5 # change learning rate
config.model_no = 0 # sets model to BERT
train_and_fit(config) # starts training with configured parameters
inferer = infer_from_trained(config) # initialize the infer object, which loads the trained model for inference
-inferer.infer_from_input()
+inferer.infer_from_input() # infer from user console input
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
```
@@ -383,8 +440,8 @@ inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt"
# To do list
In order of priority:
-- [ ] Include package usage info for ~~classification~~, ASR, summarization, ~~translation~~, ~~generation~~, punctuation_restoration, ~~NER~~, POS
+- [ ] Include package usage info for ~~classification~~, ASR, summarization, ~~translation~~, ~~generation~~, ~~punctuation_restoration~~, ~~NER~~, POS
- [ ] Include benchmark results for ~~classification~~, ASR, summarization, translation, generation, punctuation_restoration, ~~NER~~, POS
-- [ ] Include pre-trained models + demo based on benchmark datasets for classification, ASR, summarization, translation, generation, punctuation_restoration, NER, POS
+- [ ] Include pre-trained models + demo based on benchmark datasets for ~~classification~~, ASR, summarization, translation, generation, punctuation_restoration, NER, POS
- [ ] Include more models for punctuation restoration, translation, NER
4 changes: 2 additions & 2 deletions classify.py
@@ -35,15 +35,15 @@
parser.add_argument("--hidden_size_1", type=int, default=330, help="Size of first GCN hidden weights")
parser.add_argument("--hidden_size_2", type=int, default=130, help="Size of second GCN hidden weights")
parser.add_argument("--tokens_length", type=int, default=200, help="Max tokens length for BERT")
parser.add_argument("--num_classes", type=int, default=5, help="Number of prediction classes (starts from integer 0)")
parser.add_argument("--num_classes", type=int, default=2, help="Number of prediction classes (starts from integer 0)")
parser.add_argument("--train_test_split", type=int, default=0, help="0: No, 1: Yes (Only activate if infer.csv contains labelled data)")
parser.add_argument("--test_ratio", type=float, default=0.1, help="GCN: Ratio of test to training nodes")
parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
parser.add_argument("--gradient_acc_steps", type=int, default=1, help="No. of steps of gradient accumulation")
parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
parser.add_argument("--num_epochs", type=int, default=40, help="No of epochs")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
parser.add_argument("--model_no", type=int, default=2, help="Model ID: (0: Graph Convolution Network (GCN), 1: BERT, 2: XLNet)")
parser.add_argument("--model_no", type=int, default=1, help="Model ID: (0: Graph Convolution Network (GCN), 1: BERT, 2: XLNet)")

parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
parser.add_argument("--infer", type=int, default=0, help="Infer input sentence labels from trained model")
74 changes: 52 additions & 22 deletions nlptoolkit/punctuation_restoration/infer.py
@@ -4,6 +4,7 @@
@author: tsd
"""
import re
import pandas as pd
import torch
from .preprocessing_funcs import load_dataloaders
@@ -32,6 +33,32 @@ def __init__(self, idx_mappings, mappings):
        self.punc2idx = map2
        self.idx2punc = {v: k for k, v in map2.items()}

def find(s, ch=(".", "!", "?")):
    # return the indices of all sentence-terminating characters in s
    return [i for i, ltr in enumerate(s) if ltr in ch]

def corrector_module(corrected, sent=None, cap_abbrev=True):
    corrected = corrected[0].upper() + corrected[1:] # capitalize first character
    corrected = re.sub(r" +([\.\?!,])", r"\1", corrected) # remove extra spaces before punctuation
    corrected = re.sub("' +", "'", corrected) # remove extra spaces after apostrophes (e.g. "' s")
    idxs = find(corrected)
    for idx in idxs: # capitalize the first letter of each new sentence
        if (idx + 3) < len(corrected):
            corrected = corrected[:idx + 2] + corrected[(idx + 2)].upper() + corrected[(idx + 3):]

    if cap_abbrev: # upper-case known abbreviations
        abbrevs = ["ntu", "nus", "smrt", "sutd", "sim", "smu", "i2r", "astar", "imda", "hdb", "edb", "lta", "cna",
                   "suss"]
        corrected = corrected.split()
        corrected1 = []
        for word in corrected:
            if word.lower().strip("!?.,") in abbrevs:
                corrected1.append(word.upper())
            else:
                corrected1.append(word)
        assert len(corrected) == len(corrected1)
        corrected = " ".join(corrected1)
    return corrected
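
# Worked example (hypothetical input/output):
#   corrector_module("hello . how are you ? i work at ntu")
#   -> "Hello. How are you? I work at NTU"
# i.e. spaces before punctuation are removed, sentence starts are
# capitalized, and whitelisted abbreviations are upper-cased.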

class infer_from_trained(object):
    def __init__(self, args=None):

@@ -85,16 +112,12 @@ def infer_from_data(self,):

        if self.args.model_no == 0:
            src_input, trg_input, trg2_input = data[0], data[1][:, :-1], data[2][:, :-1]
-            #labels = data[1][:,1:].contiguous().view(-1)
-            #labels2 = data[2][:,1:].contiguous().view(-1)
            src_mask, trg_mask = self.create_masks(src_input, trg_input)
            trg2_mask = self.create_trg_mask(trg2_input, ignore_idx=self.idx_mappings['pad'])
            if self.cuda:
-                src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long(); #labels = labels.cuda().long()
+                src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long();
                src_mask = src_mask.cuda(); trg_mask = trg_mask.cuda(); trg2_mask = trg2_mask.cuda()
-                trg2_input = trg2_input.cuda().long(); #labels2 = labels2.cuda().long()
-            # self, src, trg, trg2, src_mask, trg_mask=None, trg_mask2=None, infer=False, trg_vocab_obj=None, \
-            #trg2_vocab_obj=None
+                trg2_input = trg2_input.cuda().long();
            stepwise_translated_words, final_step_words, stepwise_translated_words2, final_step_words2 = self.net(src_input, \
                                                trg_input[:,0].unsqueeze(0), \
                                                trg2_input[:,0].unsqueeze(0),\
@@ -126,7 +149,6 @@ def infer_from_data(self,):
                src_input = src_input.cuda().long(); trg_input = trg_input.cuda().long(); labels = labels.cuda().long()
                trg2_input = trg2_input.cuda().long(); labels2 = labels2.cuda().long()
            outputs, outputs2 = self.net(src_input, trg_input, trg2_input, infer=True)
-            #print(outputs, outputs2)
            outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
            punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
            print(punc)
@@ -173,20 +195,27 @@ def infer_sentence(self, sentence):
                                                self.trg2_vocab)

step_ = " ".join(stepwise_translated_words)
final_2 = " ".join(final_step_words2)
print("\nStepwise-translated:")
print(step_)
print()
print("\nFinal step translated words: ")
print(" ".join(final_step_words))
print()
print("\nStepwise-translated2:")
print(" ".join(stepwise_translated_words2))
print()
print("\nFinal step translated words2: ")
print(final_2)
print()
predicted = step_
final_ = " ".join(final_step_words)
if not (step_ == '') or not (final_ == ''):
step_ = corrector_module(step_)
final_ = corrector_module(final_)
final_2 = " ".join(final_step_words2)
print("\nStepwise-translated:")
print(step_)
print()
print("\nFinal step translated words: ")
print(final_)
print()
print("\nStepwise-translated2:")
print(" ".join(stepwise_translated_words2))
print()
print("\nFinal step translated words2: ")
print(final_2)
print()
predicted = step_
else:
print("None, please try another sentence.")
predicted = "None"

        elif self.args.model_no == 1:
            sent = torch.nn.functional.pad(sent, [0, (self.args.max_encoder_len - sent.shape[1])], value=1)
@@ -198,7 +227,7 @@ def infer_sentence(self, sentence):
            outputs, outputs2 = self.net(sent, trg_input, trg2_input, infer=True)
            outputs2 = outputs2.cpu().numpy().tolist() if outputs2.is_cuda else outputs2.cpu().numpy().tolist()
            punc = [self.trg2_vocab.idx2punc[i] for i in outputs2[0]]
-            print(punc)
+            #print(punc)
            punc = [self.mappings[p] if p in ['!', '?', '.', ','] else p for p in punc]

            sent = sent[sent != 1]
@@ -215,6 +244,7 @@
                    break

            predicted = " ".join(self.vocab.inverse_transform([punc[:idx]]))
            predicted = corrector_module(predicted)
            print("Predicted Label: ", predicted)
            return predicted

6 changes: 2 additions & 4 deletions punctuate.py
@@ -4,7 +4,6 @@
@author: tsd
"""
-from nlptoolkit.punctuation_restoration.preprocessing_funcs import load_dataloaders
from nlptoolkit.punctuation_restoration.trainer import train_and_fit
from nlptoolkit.punctuation_restoration.infer import infer_from_trained
from nlptoolkit.utils.misc import save_as_pickle
@@ -44,9 +43,8 @@
args = parser.parse_args()
save_as_pickle("args.pkl", args)

-#df, train_loader, train_length, max_features_len, max_output_len = load_dataloaders(args)
if args.train:
    train_and_fit(args)
if args.infer:
-    infer_ = infer_from_trained()
-    infer_.infer_from_data()
+    inferer = infer_from_trained()
+    inferer.infer_from_data()
2 changes: 1 addition & 1 deletion translate.py
@@ -35,7 +35,7 @@

parser.add_argument("--train", type=int, default=1, help="Train model on dataset")
parser.add_argument("--evaluate", type=int, default=0, help="Evaluate the trained model on dataset")
parser.add_argument("--infer", type=int, default=1, help="Infer input sentences")
parser.add_argument("--infer", type=int, default=0, help="Infer input sentences")
args = parser.parse_args()

save_as_pickle("args.pkl", args)
