fix the scheduling operation
jadore801120 committed Jun 24, 2017
1 parent c11e69e commit 7ea7a9f
Showing 3 changed files with 45 additions and 21 deletions.
33 changes: 13 additions & 20 deletions train.py
@@ -10,8 +10,10 @@
import torch.optim as optim
import transformer.Constants as Constants
from transformer.Models import Transformer
from transformer.Optim import ScheduledOptim
from DataLoader import DataLoader


def get_performance(crit, pred, gold, smoothing=False, num_class=None):
''' Apply label smoothing if needed '''

@@ -57,14 +59,15 @@ def train_epoch(model, training_data, crit, optimizer):

# update parameters
optimizer.step()
optimizer.update_learning_rate()

# note keeping
n_words = gold.data.ne(Constants.PAD).sum()
n_total_words += n_words
n_total_correct += n_correct
total_loss += loss.data[0] / len(training_data)
total_loss += loss.data[0]

return total_loss, n_total_correct/n_total_words
return total_loss/n_total_words, n_total_correct/n_total_words

def eval_epoch(model, validation_data, crit):
''' Epoch operation in evaluation phase '''
@@ -91,31 +94,18 @@ def eval_epoch(model, validation_data, crit):
n_words = gold.data.ne(Constants.PAD).sum()
n_total_words += n_words
n_total_correct += n_correct
total_loss += loss.data[0] / len(validation_data)
total_loss += loss.data[0]

return total_loss, n_total_correct/n_total_words
return total_loss/n_total_words, n_total_correct/n_total_words


def train(model, training_data, validation_data, crit, optimizer, opt):
''' Start training '''

def update_learning_rate(n_steps):
''' Learning rate scheduling '''

n_steps += 1
new_lr = np.power(opt.d_model, -0.5) * np.min([
np.power(n_steps, -0.5),
np.power(opt.n_warmup_steps, -1.5) * n_steps])

for param_group in optimizer.param_groups:
param_group['lr'] = new_lr

valid_accus = []
for epoch_i in range(opt.epoch):
print('[ Epoch', epoch_i, ']')

update_learning_rate(epoch_i)

train_loss, train_accu = train_epoch(model, training_data, crit, optimizer)
print(' - (Training) loss: {loss: 8.5f}, accuracy: {accu:3.3} %'.format(
loss=train_loss, accu=100*train_accu))
@@ -222,9 +212,12 @@ def main():

#print(transformer)

optimizer = optim.Adam(
transformer.get_trainable_parameters(),
betas=(0.9, 0.98), eps=1e-09)
optimizer = ScheduledOptim(
optim.Adam(
transformer.get_trainable_parameters(),
betas=(0.9, 0.98), eps=1e-09),
opt.d_model, opt.n_warmup_steps)


def get_criterion(vocab_size):
''' With PAD token zero weight '''
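For context, here is a minimal sketch (not part of the commit) of how the wrapped optimizer is driven inside the per-batch loop of train_epoch after this change: the learning rate is now advanced once per optimizer step instead of once per epoch. The batch unpacking, model call, and get_performance return values are assumptions based on the surrounding code, not verified against the full file.

    for batch in training_data:
        src, tgt = batch
        gold = tgt[:, 1:]  # assumed: gold targets exclude the BOS token

        # forward
        optimizer.zero_grad()
        pred = model(src, tgt)
        loss, n_correct = get_performance(crit, pred, gold)

        # backward
        loss.backward()

        # update parameters, then advance the schedule by one step
        optimizer.step()
        optimizer.update_learning_rate()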
30 changes: 30 additions & 0 deletions transformer/Optim.py
@@ -0,0 +1,30 @@
'''A wrapper class for optimizer '''
import numpy as np

class ScheduledOptim(object):
'''A simple wrapper class for learning rate scheduling'''

def __init__(self, optimizer, d_model, n_warmup_steps):
self.optimizer = optimizer
self.d_model = d_model
self.n_warmup_steps = n_warmup_steps
self.n_current_steps = 0

def step(self):
"Step by the inner optimizer"
self.optimizer.step()

def zero_grad(self):
"Zero out the gradients by the inner optimizer"
self.optimizer.zero_grad()

def update_learning_rate(self):
''' Learning rate scheduling per step '''

self.n_current_steps += 1
new_lr = np.power(self.d_model, -0.5) * np.min([
np.power(self.n_current_steps, -0.5),
np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr
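As a quick sanity check on the schedule above, the sketch below evaluates the same formula at a few step counts. The values d_model=512 and n_warmup_steps=4000 are illustrative (the "Attention Is All You Need" base-model defaults), not values taken from this commit.

    import numpy as np

    def scheduled_lr(step, d_model=512, n_warmup_steps=4000):
        # lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)
        return np.power(d_model, -0.5) * min(
            np.power(step, -0.5),
            np.power(n_warmup_steps, -1.5) * step)

    for step in (1, 1000, 4000, 16000):
        print(step, scheduled_lr(step))

    # The rate rises linearly during warmup, peaks near 7e-4 at step 4000
    # for these values, then decays in proportion to 1/sqrt(step).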
3 changes: 2 additions & 1 deletion transformer/__init__.py
@@ -5,8 +5,9 @@
import transformer.Models
import transformer.Translator
import transformer.Beam
import transformer.Optim

__all__ = [
transformer.Constants, transformer.Modules, transformer.Layers,
transformer.SubLayers, transformer.Models,
transformer.SubLayers, transformer.Models, transformer.Optim,
transformer.Translator, transformer.Beam]
