Skip to content

Commit

Permalink
Add softmax during evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanleary committed Jun 28, 2017
1 parent bf93286 commit 7837feb
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

supported_rnns = {
'lstm': nn.LSTM,
Expand Down Expand Up @@ -35,6 +36,12 @@ def __repr__(self):
tmpstr += ')'
return tmpstr

class BatchSoftmax(nn.Module):
    """Log-softmax over the label dimension, transposing to batch-first.

    Input is expected time-major, shape (seq_len, batch, num_labels) —
    TODO confirm against the caller; the surrounding code transposes for
    multi-GPU concat, which suggests this layout. Output is
    (batch, seq_len, num_labels) with log-probabilities along the last dim.
    """

    def forward(self, input_):
        # Transpose to batch-first, then apply log-softmax over the label
        # axis in one vectorized call. This is equivalent to the per-sample
        # loop (for a 2-D slice, softmax's implicit dim is the last dim) but
        # avoids the deprecated implicit-dim call and the Python-level loop.
        return F.log_softmax(input_.transpose(0, 1), dim=-1)

class BatchRNN(nn.Module):
def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
Expand Down Expand Up @@ -70,6 +77,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
self._rnn_type = rnn_type
self._audio_conf = audio_conf or {}
self._labels = labels
self._softmax = BatchSoftmax()

sample_rate = self._audio_conf.get("sample_rate", 16000)
window_size = self._audio_conf.get("window_size", 0.02)
Expand Down Expand Up @@ -116,7 +124,8 @@ def forward(self, x):
x = self.rnns(x)

x = self.fc(x)
x = x.transpose(0, 1) # Transpose for multi-gpu concat
if not self.training:
x = self._softmax(x)
return x

@classmethod
Expand Down

0 comments on commit 7837feb

Please sign in to comment.