Skip to content

Commit

Permalink
Merge pull request SeanNaren#124 from dmitriy-serdyuk/refactor-softmax
Browse files Browse the repository at this point in the history
Refactor model output to reflect its type
  • Loading branch information
Sean Naren committed Aug 1, 2017
2 parents 3ffa9a5 + bf22fb7 commit a3a04f6
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __repr__(self):
return tmpstr


class InferenceBatchSoftmax(nn.Module):
class InferenceBatchLogSoftmax(nn.Module):
def forward(self, input_):
if not self.training:
batch_size = input_.size()[0]
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layer
self.fc = nn.Sequential(
SequenceWise(fully_connected),
)
self.softmax = InferenceBatchSoftmax()
self.inference_log_softmax = InferenceBatchLogSoftmax()

def forward(self, x):
x = self.conv(x)
Expand All @@ -128,7 +128,8 @@ def forward(self, x):

x = self.fc(x)
x = x.transpose(0, 1)
x = self.softmax(x)
# identity in training mode, logsoftmax in eval mode
x = self.inference_log_softmax(x)
return x

@classmethod
Expand Down

0 comments on commit a3a04f6

Please sign in to comment.