Skip to content

Commit

Permalink
fix channel last
Browse files Browse the repository at this point in the history
  • Loading branch information
moskomule committed Dec 18, 2021
1 parent 0039d72 commit 219aa5b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions homura/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def __init__(self,
self._use_channel_last = use_channel_last
if self._use_channel_last:
self.logger.warning("channel_last format is an experimental feature")
self.model.to(memory_format=torch.channels_last)
self.model = self.model.to(memory_format=torch.channels_last)
if report_accuracy_topk is not None:
if not isinstance(report_accuracy_topk, Iterable):
report_accuracy_topk = [report_accuracy_topk]
Expand Down Expand Up @@ -574,7 +574,7 @@ def iteration(self,

def data_preprocess(self,
data: tuple[Tensor, Tensor]
) -> (tuple[Tensor, Tensor], int):
) -> tuple[Tensor, Tensor]:
input, target = data
return (input.to(self.device, non_blocking=self._cuda_nonblocking,
memory_format=torch.channels_last if self._use_channel_last
Expand Down

0 comments on commit 219aa5b

Please sign in to comment.