Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
moskomule committed Dec 10, 2021
1 parent fead152 commit 0039d72
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 0 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ Welcome to `homura`'s documentation!

homura.metrics
homura.modules

homura.utils
homura.nlp
homura.vision

Indices and tables
Expand Down
3 changes: 2 additions & 1 deletion homura/modules/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
return self._original_model.parameters(recurse)

def requires_grad_(self, requires_grad: bool = True) -> nn.Module:
return self._original_model.requires_grad_(requires_grad)
self._original_model.requires_grad_(requires_grad)
return self

@torch.no_grad()
def _update(self):
Expand Down
7 changes: 6 additions & 1 deletion homura/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __init__(self,
if is_distributed():
if self._use_sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
(self.logger.info if is_master() else self.logger.debug)("BNs of model are converted to nn.SyncBatchNorm")
(self.logger.info if is_master() else self.logger.debug)(
"BNs of model are converted to nn.SyncBatchNorm")

rank = get_local_rank()
torch.cuda.set_device(rank)
Expand Down Expand Up @@ -385,6 +386,7 @@ def run(self,
class ProxyLoader(object):
def __init__(self, loader):
self.loader = loader
self._epoch = 0

def __len__(self):
return val_intervals
Expand All @@ -397,6 +399,9 @@ def __iter__(self):
return # from python 3.7, this is valid
yield data
counter += 1
self._epoch += 1
if hasattr(self.loader.sampler, 'set_epoch'):
self.loader.sampler.set_epoch(self._epoch)

train_loader = ProxyLoader(train_loader)
if not isinstance(val_loaders, dict) and (isinstance(val_loaders, Iterable) or
Expand Down

0 comments on commit 0039d72

Please sign in to comment.