Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change load_pretrained function to accept models from MoBY #184

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Change load pretrained to accept models from MoBY
Added a check to remove encoder prefixes as done in https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
Added a head check to reinit if the head key isn't present
  • Loading branch information
Giles-Billenness committed Mar 16, 2022
commit 821c0161079acb11e36ae7d5ac03b1a35d4aa205
45 changes: 27 additions & 18 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def load_pretrained(config, model, logger):
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
state_dict = checkpoint['model']

# for MoBY, preplace prefixes
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

# delete relative_position_index since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
Expand Down Expand Up @@ -105,24 +109,29 @@ def load_pretrained(config, model, logger):
state_dict[k] = absolute_pos_embed_pretrained_resized

# check classifier, if not match, then re-init classifier to zero
head_bias_pretrained = state_dict['head.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.head.bias.shape[0]
if (Nc1 != Nc2):
if Nc1 == 21841 and Nc2 == 1000:
logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
map22kto1k_path = f'data/map22kto1k.txt'
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
del state_dict['head.weight']
del state_dict['head.bias']
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")
if ('head.bias' in state_dict):
head_bias_pretrained = state_dict['head.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.head.bias.shape[0]
if (Nc1 != Nc2):
if Nc1 == 21841 and Nc2 == 1000:
logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
map22kto1k_path = f'data/map22kto1k.txt'
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
del state_dict['head.weight']
del state_dict['head.bias']
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")

msg = model.load_state_dict(state_dict, strict=False)
logger.warning(msg)
Expand Down