added mixed precision (fp16) training for ASR and summarization (experimental)
plkmo committed Oct 15, 2019
1 parent aac1056 commit 6ffd0fe
Showing 17 changed files with 437 additions and 159 deletions.
21 changes: 20 additions & 1 deletion README.md
@@ -150,6 +150,7 @@ speech.py [-h]
[--num NUM (default: 6)]
[--n_heads N_HEADS(default: 4)]
[--batch_size BATCH_SIZE (default: 30)]
[--fp16 FP16 (default:1)]
[--num_epochs NUM_EPOCHS (default: 8000)]
[--lr LR default=0.003]
[--gradient_acc_steps GRADIENT_ACC_STEPS (default: 4)]
@@ -187,7 +188,8 @@ summarize.py [-h]
[--n_heads N_HEADS(default: 4)]
[--LAS_embed_dim LAS_EMBED_DIM (default: 128)]
[--LAS_hidden_size LAS_HIDDEN_SIZE (default: 128)]
[--batch_size BATCH_SIZE (default: 30)]
[--batch_size BATCH_SIZE (default: 32)]
[--fp16 FP16 (default: 1)]
[--num_epochs NUM_EPOCHS (default: 8000)]
[--lr LR default=0.003]
[--gradient_acc_steps GRADIENT_ACC_STEPS (default: 4)]
@@ -199,6 +201,23 @@ summarize.py [-h]
```
Or if used as a package:
```python
from nlptoolkit.utils.config import Config
from nlptoolkit.summarization.trainer import train_and_fit
from nlptoolkit.summarization.infer import infer_from_trained
config = Config(task='summarization') # loads default argument parameters as above
config.data_path = "./data/cnn_stories/cnn/stories/"
config.batch_size = 32
config.lr = 0.0001 # change learning rate
config.model_no = 0 # set model as Transformer
train_and_fit(config) # starts training with configured parameters
inferer = infer_from_trained(config) # initialize the inference object; loads the trained model for inference
inferer.infer_from_input() # infer from user console input
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")
```
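The new `--fp16` option should be reachable from the package interface as well. A minimal sketch, assuming the config attribute is named `fp16` to mirror the CLI flag and that NVIDIA apex is installed (both are assumptions, not confirmed by this commit):
```python
from nlptoolkit.utils.config import Config
from nlptoolkit.summarization.trainer import train_and_fit

config = Config(task='summarization')
config.fp16 = 1        # assumption: mirrors the new --fp16 CLI flag; requires NVIDIA apex
config.batch_size = 32
train_and_fit(config)  # with fp16 enabled, training wraps model/optimizer via apex amp (opt_level 'O2')
```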
---
## 4) Machine Translation
27 changes: 27 additions & 0 deletions nlptoolkit/ASR/evaluate.py → nlptoolkit/ASR/infer.py
@@ -24,6 +24,33 @@
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger('__file__')

class infer_from_trained(object):
def __init__(self, args=None):
if args is None:
self.args = load_pickle("args.pkl")
else:
self.args = args
self.cuda = torch.cuda.is_available()
self.args.batch_size = 1

def infer_sentence(self, sent):
return

def infer_from_input(self):
self.net.eval()
while True:
user_input = input("Type input sentence (Type \'exit' or \'quit' to quit):\n")
if user_input in ["exit", "quit"]:
break
predicted = self.infer_sentence(user_input)
return predicted

def infer_from_file(self, in_file="./data/input.txt", out_file="./data/output.txt"):
df = pd.read_csv(in_file, header=None, names=["sents"])
df['labels'] = df.progress_apply(lambda x: self.infer_sentence(x['sents']), axis=1)
df.to_csv(out_file, index=False)
logger.info("Done and saved as %s!" % out_file)
return

def infer(file_path=None, speaker=None, pyTransformer=False):
if pyTransformer:
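For orientation, the new ASR `infer_from_trained` class mirrors the summarization inference interface. A rough usage sketch — note that in this commit `infer_sentence` is still a stub and the network is not yet attached in `__init__`, so this shows the intended interface rather than working code:
```python
from nlptoolkit.ASR.infer import infer_from_trained

inferer = infer_from_trained()                 # falls back to the args.pkl written during training
inferer.infer_from_input()                     # interactive loop; type 'exit' or 'quit' to stop
inferer.infer_from_file(in_file="./data/input.txt",
                        out_file="./data/output.txt")
```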
22 changes: 18 additions & 4 deletions nlptoolkit/ASR/models/LAS/LAS_model.py
@@ -7,6 +7,7 @@

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

@@ -168,18 +169,30 @@ def forward(self, x, trg_input, infer=False):
return x

@classmethod
def load_model(cls, path):
def load_model(cls, path, args, cuda=True, amp=None):
checkpoint = torch.load(path)
model = cls(listener_input_size=checkpoint['listener_input_size'], \
listener_hidden_size=checkpoint['listener_hidden_size'], \
output_class_dim=checkpoint['output_class_dim'],\
max_label_len = 100)
model.listener.flatten_parameters()
#model.speller.flatten_parameters()
if cuda:
model.cuda()

if amp is not None:
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
amp.load_state_dict(checkpoint['amp'])
#optimizer.load_state_dict(checkpoint['optimizer']) # dynamic loss scaling spikes if we load this! waiting for fix from nvidia apex
print("Loaded amp state dict!")
else:
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
optimizer.load_state_dict(checkpoint['optimizer'])
model.load_state_dict(checkpoint['state_dict'])
return model
return model, optimizer

def save_state(self, epoch, optimizer, scheduler, best_acc, path):
def save_state(self, epoch, optimizer, scheduler, best_acc, path, amp=None):
state = {
'epoch': epoch + 1,\
'state_dict': self.state_dict(),\
@@ -188,7 +201,8 @@ def save_state(self, epoch, optimizer, scheduler, best_acc, path):
'scheduler' : scheduler.state_dict(),\
'listener_input_size' : self.listener_input_size,\
'listener_hidden_size': self.listener_hidden_size,\
'output_class_dim': self.output_class_dim
'output_class_dim': self.output_class_dim,\
'amp': amp.state_dict()
}
torch.save(state, path)
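The save/load pair above follows the standard apex amp checkpointing pattern: `amp.state_dict()` is stored next to the model weights, and on restore `amp.initialize` must be called before `amp.load_state_dict` (the optimizer state is deliberately skipped, per the comment about dynamic loss scaling). A condensed sketch of that pattern with placeholder names, not the repo's exact API:
```python
import torch
import torch.optim as optim
from apex import amp  # NVIDIA apex mixed-precision utilities

def save_checkpoint(model, optimizer, path):
    torch.save({'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict()}, path)

def load_checkpoint(model, path, lr=0.003):
    checkpoint = torch.load(path)
    model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')  # must precede amp.load_state_dict
    amp.load_state_dict(checkpoint['amp'])
    # optimizer.load_state_dict(checkpoint['optimizer']) is skipped: loss scaling spikes (apex issue)
    model.load_state_dict(checkpoint['state_dict'])
    return model, optimizer
```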

22 changes: 18 additions & 4 deletions nlptoolkit/ASR/models/Transformer/transformer_model.py
Expand Up @@ -7,6 +7,7 @@

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import math
@@ -351,7 +352,7 @@ def forward(self, src, trg, src_mask, trg_mask=None, g_mask1=None, g_mask2=None,
#return x

@classmethod
def load_model(cls, path):
def load_model(cls, path, args, cuda=True, amp=None):
checkpoint = torch.load(path)
model = cls(src_vocab=checkpoint["src_vocab"], \
trg_vocab=checkpoint["trg_vocab"], \
@@ -362,10 +363,22 @@ def load_model(cls, path):
max_encoder_len=checkpoint["max_encoder_len"], \
max_decoder_len=checkpoint["max_decoder_len"], \
use_conv=True)
if cuda:
model.cuda()

if amp is not None:
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
amp.load_state_dict(checkpoint['amp'])
#optimizer.load_state_dict(checkpoint['optimizer']) # dynamic loss scaling spikes if we load this! waiting for fix from nvidia apex
print("Loaded amp state dict!")
else:
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
optimizer.load_state_dict(checkpoint['optimizer'])
model.load_state_dict(checkpoint['state_dict'])
return model
return model, optimizer

def save_state(self, epoch, optimizer, scheduler, best_acc, path):
def save_state(self, epoch, optimizer, scheduler, best_acc, path, amp=None):
state = {
'epoch': epoch + 1,\
'state_dict': self.state_dict(),\
@@ -379,6 +392,7 @@ def save_state(self, epoch, optimizer, scheduler, best_acc, path):
'num': self.num,\
'n_heads': self.n_heads,\
'max_encoder_len': self.max_encoder_len,\
'max_decoder_len': self.max_decoder_len,
'max_decoder_len': self.max_decoder_len,\
'amp': amp.state_dict()
}
torch.save(state, path)
44 changes: 25 additions & 19 deletions nlptoolkit/ASR/train_funcs.py
@@ -17,7 +17,7 @@
datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger(__file__)

def load_model_and_optimizer(args, vocab, max_features_length, max_seq_length, cuda, pyTransformer=False):
def load_model_and_optimizer(args, vocab, max_features_length, max_seq_length, cuda, amp=None, pyTransformer=False):

if pyTransformer:
from .models.Transformer.py_Transformer import pyTransformer as SpeechTransformer, \
@@ -45,16 +45,23 @@ def load_model_and_optimizer(args, vocab, max_features_length, max_seq_length, c
nn.init.xavier_uniform_(p)

criterion = nn.CrossEntropyLoss(ignore_index=1) # ignore padding tokens
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)

net, optimizer, start_epoch, acc, loaded_opt = load_state(net, cuda, args, load_best=False, \
amp=amp)

if cuda and (not loaded_opt):
net.cuda()

if (not loaded_opt):
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = CosineWithRestarts(optimizer, T_max=args.T_max)

model = SpeechTransformer if (args.model_no == 0) else LAS
model, loaded_optimizer, loaded_scheduler, start_epoch, acc = load_state(model, args, load_best=False, load_scheduler=False)
if (args.fp16) and (not loaded_opt) and (amp is not None):
logger.info("Using fp16...")
net, optimizer = amp.initialize(net, optimizer, opt_level='O2')
scheduler = CosineWithRestarts(optimizer, T_max=args.T_max)

if start_epoch != 0:
net = model; optimizer = loaded_optimizer; scheduler = loaded_scheduler
if cuda:
net.cuda()
logger.info("Done setting up model, optimizer and scheduler.")

if args.model_no == 0:
'''
@@ -79,13 +86,13 @@ def load_model_and_optimizer(args, vocab, max_features_length, max_seq_length, c

return net, criterion, optimizer, scheduler, start_epoch, acc, g_mask1, g_mask2

def load_state(net, args, load_best=False, load_scheduler=False):
def load_state(net, cuda, args, load_best=False, amp=None):
""" Loads saved model and optimizer states if exists """
loaded_opt = False
base_path = "./data/"
checkpoint_path = os.path.join(base_path,"test_checkpoint_%d.pth.tar" % args.model_no)
best_path = os.path.join(base_path,"test_model_best_%d.pth.tar" % args.model_no)
start_epoch, best_pred, checkpoint = 0, 0, None
optimizer, scheduler = None, None
if (load_best == True) and os.path.isfile(best_path):
checkpoint = torch.load(best_path)
logger.info("Loaded best model.")
Expand All @@ -96,16 +103,15 @@ def load_state(net, args, load_best=False, load_scheduler=False):
start_epoch = checkpoint['epoch']
best_pred = checkpoint['best_acc']
if load_best:
net = net.load_model(best_path)
net, optimizer = net.load_model(best_path, args, cuda, amp)
else:
net = net.load_model(checkpoint_path)
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = CosineWithRestarts(optimizer, T_max=args.T_max)
if load_scheduler:
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
logger.info("Loaded model and optimizer.")
return net, optimizer, scheduler, start_epoch, best_pred
net, optimizer = net.load_model(checkpoint_path, args, cuda, amp)

logger.info("Loaded model and optimizer.")
loaded_opt = True
else:
optimizer = None
return net, optimizer, start_epoch, best_pred, loaded_opt

def load_results(model_no=0):
""" Loads saved results if exists """
24 changes: 20 additions & 4 deletions nlptoolkit/ASR/trainer.py
@@ -31,14 +31,21 @@ def train_and_fit(args, pyTransformer=False):
print("Max sequence length: %d" % max_seq_length)
vocab = load_pickle("vocab.pkl")

if args.fp16:
from apex import amp
else:
amp = None

logger.info("Loading model and optimizers...")
cuda = torch.cuda.is_available()
net, criterion, optimizer, scheduler, start_epoch, acc, g_mask1, g_mask2 = load_model_and_optimizer(args, vocab, \
max_features_length, \
max_seq_length, cuda,\
pyTransformer)
losses_per_epoch, accuracy_per_epoch = load_results(model_no=args.model_no)
batch_update_steps = 2
batch_update_steps = int(train_length/(args.batch_size*10))

optimizer.zero_grad()
logger.info("Starting training process...")
for e in range(start_epoch, args.num_epochs):
#l_rate = lrate(e + 1, d_model=32, k=10, warmup_n=25000)
@@ -65,9 +72,18 @@
outputs = outputs.view(-1, outputs.size(-1))
loss = criterion(outputs, labels);
loss = loss/args.gradient_acc_steps
loss.backward()
if pyTransformer:
clip_grad_norm_(net.parameters(), args.max_norm)

if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()

if args.fp16:
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
else:
grad_norm = clip_grad_norm_(net.parameters(), args.max_norm)

if (i % args.gradient_acc_steps) == 0:
optimizer.step()
optimizer.zero_grad()
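The hunk above is the heart of the fp16 change: the plain backward pass is replaced by `amp.scale_loss`, and gradient clipping is applied to `amp.master_params(optimizer)` instead of the model parameters. Stripped of the repo specifics, the per-batch pattern looks roughly like this (a sketch, assuming `amp.initialize` was already called during setup):
```python
from apex import amp
from torch.nn.utils import clip_grad_norm_

def training_step(i, net, criterion, optimizer, inputs, labels, args, fp16=True):
    outputs = net(inputs)
    loss = criterion(outputs.view(-1, outputs.size(-1)), labels) / args.gradient_acc_steps

    if fp16:
        # scale the loss so fp16 gradients do not underflow, then backprop through the scaled loss
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # clip the fp32 master copies of the gradients, not the fp16 model weights
        clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
    else:
        loss.backward()
        clip_grad_norm_(net.parameters(), args.max_norm)

    if (i % args.gradient_acc_steps) == 0:  # gradient accumulation
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
```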
89 changes: 0 additions & 89 deletions nlptoolkit/summarization/evaluate.py

This file was deleted.

