Skip to content
This repository has been archived by the owner on Feb 12, 2022. It is now read-only.

Commit

Permalink
Generate was broken for QRNN as QRNN didn't flush their previously st…
Browse files Browse the repository at this point in the history
…ored X, resulting in diff batch sizes, hence angry concat dimensions
  • Loading branch information
Smerity committed Nov 26, 2017
1 parent 66107f8 commit 9c62358
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# Model parameters.
parser.add_argument('--data', type=str, default='./data/penn',
help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
help='type of recurrent net (LSTM, QRNN)')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
Expand Down Expand Up @@ -47,6 +49,8 @@
with open(args.checkpoint, 'rb') as f:
model = torch.load(f)
model.eval()
if args.model == 'QRNN':
model.reset()

if args.cuda:
model.cuda()
Expand Down

0 comments on commit 9c62358

Please sign in to comment.