Manually set autograd params for reduced memory
ryanleary committed Jun 10, 2017
1 parent 1e8e21d commit 8b67c9c
Showing 4 changed files with 13 additions and 13 deletions.
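The change is mechanical across all four files: every Variable wrapper now states its autograd behavior explicitly instead of relying on defaults. A minimal sketch of the two flags involved, assuming the pre-0.4 PyTorch Variable API this commit targets (the tensor shape is illustrative):

import torch
from torch.autograd import Variable

x = torch.randn(4, 3)

# Training-time wrapper: this leaf is excluded from gradient
# computation, so autograd never allocates a .grad buffer for it.
train_input = Variable(x, requires_grad=False)

# Inference-time wrapper: volatile=True disables graph construction
# for everything computed from it, which is what saves memory in eval.
eval_input = Variable(x, volatile=True)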
8 changes: 4 additions & 4 deletions benchmark.py
@@ -55,15 +55,15 @@ def iteration(input_data):
target_size = torch.IntTensor(batch_size).fill_(int((seconds * 100) / 2))
input_percentages = torch.IntTensor(batch_size).fill_(1)

- inputs = Variable(input_data)
- target_sizes = Variable(target_size)
- targets = Variable(target)
+ inputs = Variable(input_data, requires_grad=False)
+ target_sizes = Variable(target_size, requires_grad=False)
+ targets = Variable(target, requires_grad=False)
start = time.time()
out = model(inputs)
out = out.transpose(0, 1) # TxNxH

seq_length = out.size(0)
- sizes = Variable(input_percentages.mul_(int(seq_length)).int())
+ sizes = Variable(input_percentages.mul_(int(seq_length)).int(), requires_grad=False)
loss = criterion(out, targets, sizes, target_sizes)
loss = loss / inputs.size(0) # average the loss by minibatch
# compute gradient
2 changes: 1 addition & 1 deletion predict.py
@@ -26,7 +26,7 @@
parser = SpectrogramParser(audio_conf, normalize=True)
spect = parser.parse_audio(args.audio_path).contiguous()
spect = spect.view(1, 1, spect.size(0), spect.size(1))
- out = model(Variable(spect))
+ out = model(Variable(spect, volatile=True))
out = out.transpose(0, 1) # TxNxH
decoded_output = decoder.decode(out.data)
print(decoded_output[0])
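For the inference scripts, volatile=True is the stronger setting: it propagates through every downstream op, so no autograd history is retained anywhere in the forward pass. A small self-contained check of that propagation, again assuming PyTorch 0.3 or earlier (the linear layer is a hypothetical stand-in for the real network):

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Linear(10, 5)  # hypothetical stand-in for the acoustic model
out = model(Variable(torch.randn(1, 10), volatile=True))

print(out.volatile)       # True: volatility propagated through the op
print(out.requires_grad)  # False: no graph was built, so no gradients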
4 changes: 2 additions & 2 deletions test.py
@@ -34,7 +34,7 @@
for i, (data) in enumerate(test_loader):
inputs, targets, input_percentages, target_sizes = data

- inputs = Variable(inputs)
+ inputs = Variable(inputs, volatile=True)

# unflatten targets
split_targets = []
@@ -49,7 +49,7 @@
out = model(inputs)
out = out.transpose(0, 1) # TxNxH
seq_length = out.size(0)
- sizes = Variable(input_percentages.mul_(int(seq_length)).int())
+ sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True)

decoded_output = decoder.decode(out.data, sizes)
target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
12 changes: 6 additions & 6 deletions train.py
@@ -177,9 +177,9 @@ def main():
inputs, targets, input_percentages, target_sizes = data
# measure data loading time
data_time.update(time.time() - end)
- inputs = Variable(inputs)
- target_sizes = Variable(target_sizes)
- targets = Variable(targets)
+ inputs = Variable(inputs, requires_grad=False)
+ target_sizes = Variable(target_sizes, requires_grad=False)
+ targets = Variable(targets, requires_grad=False)

if args.cuda:
inputs = inputs.cuda()
@@ -188,7 +188,7 @@
out = out.transpose(0, 1) # TxNxH

seq_length = out.size(0)
- sizes = Variable(input_percentages.mul_(int(seq_length)).int())
+ sizes = Variable(input_percentages.mul_(int(seq_length)).int(), requires_grad=False)

loss = criterion(out, targets, sizes, target_sizes)
loss = loss / inputs.size(0) # average the loss by minibatch
@@ -243,7 +243,7 @@ def main():
for i, (data) in enumerate(test_loader): # test
inputs, targets, input_percentages, target_sizes = data

- inputs = Variable(inputs)
+ inputs = Variable(inputs, volatile=True)

# unflatten targets
split_targets = []
@@ -258,7 +258,7 @@
out = model(inputs)
out = out.transpose(0, 1) # TxNxH
seq_length = out.size(0)
- sizes = Variable(input_percentages.mul_(int(seq_length)).int())
+ sizes = Variable(input_percentages.mul_(int(seq_length)).int(), volatile=True)

decoded_output = decoder.decode(out.data, sizes)
target_strings = decoder.process_strings(decoder.convert_to_strings(split_targets))
