
Commit

Merge pull request #468 from yoshitomo-matsubara/dev
Use warning
yoshitomo-matsubara committed May 25, 2024
2 parents 1fe3088 + bfc175a commit dd17646
Showing 4 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions torchdistill/common/main_util.py
@@ -296,7 +296,7 @@ def load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, str
             logger.info('Loading model parameters only')
             model.load_state_dict(ckpt, strict=strict)
         else:
-            logger.info('No model parameters found')
+            logger.warning('No model parameters found')

     if optimizer is not None:
         if 'optimizer' in ckpt:
@@ -306,7 +306,7 @@ def load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, str
             logger.info('Loading optimizer parameters only')
             optimizer.load_state_dict(ckpt)
         else:
-            logger.info('No optimizer parameters found')
+            logger.warning('No optimizer parameters found')

     if lr_scheduler is not None:
         if 'lr_scheduler' in ckpt:
@@ -316,7 +316,7 @@ def load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, str
             logger.info('Loading scheduler parameters only')
             lr_scheduler.load_state_dict(ckpt)
         else:
-            logger.info('No scheduler parameters found')
+            logger.warning('No scheduler parameters found')
     return ckpt.get('best_value', 0.0), ckpt.get('args', None)
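The change above switches the "nothing found" messages in load_ckpt from the INFO to the WARNING level. A minimal, self-contained sketch (using only Python's standard logging defaults, not torchdistill's own logger setup) of why that matters: under the default WARNING level, info messages are dropped while warnings still reach the console.

import logging

logging.basicConfig()  # no level given, so the root logger stays at WARNING
logger = logging.getLogger('torchdistill.common.main_util')

logger.info('No model parameters found')     # suppressed under the default level
logger.warning('No model parameters found')  # emitted, e.g. WARNING:torchdistill.common.main_util:No model parameters found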
10 changes: 6 additions & 4 deletions torchdistill/common/module_util.py
@@ -105,15 +105,17 @@ def get_module(root_module, module_path):
                     if isinstance(module, Sequential) and module_name.lstrip('-').isnumeric():
                         module = module[int(module_name)]
                     else:
-                        logger.info('`{}` of `{}` could not be reached in `{}`'.format(module_name, module_path,
-                                                                                       type(root_module).__name__))
+                        logger.warning('`{}` of `{}` could not be reached in `{}`'.format(
+                            module_name, module_path, type(root_module).__name__)
+                        )
                 else:
                     module = getattr(module, module_name)
             elif isinstance(module, (Sequential, ModuleList)) and module_name.lstrip('-').isnumeric():
                 module = module[int(module_name)]
             else:
-                logger.info('`{}` of `{}` could not be reached in `{}`'.format(module_name, module_path,
-                                                                               type(root_module).__name__))
+                logger.warning('`{}` of `{}` could not be reached in `{}`'.format(
+                    module_name, module_path, type(root_module).__name__)
+                )
                 return None
         else:
             module = getattr(module, module_name)
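As in main_util.py, get_module now logs a warning instead of an info message when a dotted module path cannot be resolved. A hedged usage sketch (the model and the paths below are made up for illustration; per the hunk above, an unreachable path in the final else branch logs the message and returns None):

from torch import nn
from torchdistill.common.module_util import get_module

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(get_module(model, '2'))          # resolves the second Linear layer by index
print(get_module(model, 'fc.weight'))  # unreachable path: now logged as a warning, returns None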
1 change: 0 additions & 1 deletion torchdistill/models/custom/bottleneck/detection/rcnn.py
@@ -49,7 +49,6 @@ def custom_maskrcnn_resnet_fpn(backbone, weights=None, progress=True,
     mask_roi_pool = None if num_feature_maps == 4 \
         else MultiScaleRoIAlign(featmap_names=[str(i) for i in range(num_feature_maps)],
                                 output_size=14, sampling_ratio=2)
-    print(kwargs)
     model = MaskRCNN(backbone_model, num_classes, box_roi_pool=box_roi_pool, mask_roi_pool=mask_roi_pool, **kwargs)
    if weights is not None:
        state_dict = \
1 change: 0 additions & 1 deletion torchdistill/models/registry.py
@@ -2,7 +2,6 @@

 from ..common import misc_util

-MODEL_DICT = dict()
 MODEL_DICT = dict()
 ADAPTATION_MODULE_DICT = dict()
 AUXILIARY_MODEL_WRAPPER_DICT = dict()
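The deleted line was a duplicate initialization of MODEL_DICT. A generic registry-pattern sketch (register_model below is a hypothetical decorator for illustration, not necessarily torchdistill's exact API) showing why the dict should be created exactly once: re-running MODEL_DICT = dict() after registrations have happened would silently discard them.

MODEL_DICT = dict()

def register_model(cls):
    # hypothetical decorator: map the class name to the class itself
    MODEL_DICT[cls.__name__] = cls
    return cls

@register_model
class MyModel:
    pass

# MODEL_DICT = dict()  # re-initializing here would drop the MyModel entry
print(MODEL_DICT)       # {'MyModel': <class '__main__.MyModel'>}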
