From ea7b42475182fb98ee931dba98c44744544b5e16 Mon Sep 17 00:00:00 2001 From: "sean.narenthiran" Date: Mon, 26 Aug 2019 09:07:21 -0500 Subject: [PATCH] Fixed bug in testing script, swap to BoolTensor for PyTorch 1.2 support --- model.py | 2 +- test.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/model.py b/model.py index c9f7e49c..f90da3aa 100644 --- a/model.py +++ b/model.py @@ -57,7 +57,7 @@ def forward(self, x, lengths): """ for module in self.seq_module: x = module(x) - mask = torch.ByteTensor(x.size()).fill_(0) + mask = torch.BoolTensor(x.size()).fill_(0) if x.is_cuda: mask = mask.cuda() for i, length in enumerate(lengths): diff --git a/test.py b/test.py index ac4ad208..b59d24d8 100644 --- a/test.py +++ b/test.py @@ -39,14 +39,10 @@ def evaluate(test_loader, device, model, decoder, target_decoder, save_output=Fa out, output_sizes = model(inputs, input_sizes) - if save_output: - # add output to data array, and continue - output_data.append((out.cpu().numpy(), output_sizes.numpy())) - decoded_output, _ = decoder.decode(out, output_sizes) target_strings = target_decoder.convert_to_strings(split_targets) - if args.save_output is not None: + if save_output: # add output to data array, and continue output_data.append((out.cpu().numpy(), output_sizes.numpy(), target_strings)) for x in range(len(target_strings)):