Skip to content

Commit

Permalink
fix the bugs to make natvie experiments run normally
Browse files Browse the repository at this point in the history
  • Loading branch information
ohyeat committed Jan 28, 2020
1 parent 19e897f commit c23e676
Showing 1 changed file with 32 additions and 65 deletions.
97 changes: 32 additions & 65 deletions det/maskrcnn_benchmark/utils/c2_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
from maskrcnn_benchmark.utils.registry import Registry


def _rename_basic_resnet_weights(layer_keys):
def _rename_basic_resnet_weights(layer_keys, no_bn_fuse=False):
layer_keys = [k.replace("_", ".") for k in layer_keys]
layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
layer_keys = [k.replace(".beta", "._beta") for k in layer_keys]
layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
layer_keys = [k.replace("._beta", ".bias") for k in layer_keys]
layer_keys = [k.replace(".gamma", ".weight") for k in layer_keys]
layer_keys = [k.replace("running.mean", "running_mean") for k in layer_keys]
layer_keys = [k.replace("running.var", "running_var") for k in layer_keys]
# layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
if not no_bn_fuse:
layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
else:
layer_keys = [k.replace(".beta", "._beta") for k in layer_keys]
layer_keys = [k.replace("._beta", ".bias") for k in layer_keys]
layer_keys = [k.replace(".gamma", ".weight") for k in layer_keys]
layer_keys = [k.replace("running.mean", "running_mean") for k in layer_keys]
layer_keys = [k.replace("running.var", "running_var") for k in layer_keys]

layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys]
layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys]
Expand All @@ -34,7 +37,8 @@ def _rename_basic_resnet_weights(layer_keys):
layer_keys]

# Affine-Channel -> BatchNorm enaming
# layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]
if not no_bn_fuse:
layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]

# Make torchvision-compatible
layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys]
Expand Down Expand Up @@ -66,60 +70,9 @@ def _rename_basic_resnet_weights(layer_keys):
for k in layer_keys]
layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias") \
for k in layer_keys]

return layer_keys

# def _rename_basic_resnet_weights(layer_keys):
# layer_keys = [k.replace("_", ".") for k in layer_keys]
# layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
# layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
# layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
# layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
# layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
# layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys]
# layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys]
# layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys]

# # RPN / Faster RCNN
# layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys]
# layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys]
# layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys]
# layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys]

# # Affine-Channel -> BatchNorm enaming
# layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]

# # Make torchvision-compatible
# layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys]

# layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys]
# layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys]
# layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys]
# layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys]

# layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
# layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys]
# layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
# layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys]
# layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
# layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys]

# layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
# layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]

# # GroupNorm
# layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys]
# layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys]
# layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys]
# layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys]
# layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys]
# layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys]
# layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight") \
# for k in layer_keys]
# layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias") \
# for k in layer_keys]

# return layer_keys

def _rename_fpn_weights(layer_keys, stage_names):
for mapped_idx, stage_name in enumerate(stage_names, 1):
suffix = ""
Expand All @@ -140,7 +93,7 @@ def _rename_fpn_weights(layer_keys, stage_names):
return layer_keys


def _rename_weights_for_resnet(weights, stage_names):
def _rename_weights_for_resnet(weights, stage_names, no_bn_fuse=False):
original_keys = sorted(weights.keys())
layer_keys = sorted(weights.keys())

Expand All @@ -149,7 +102,7 @@ def _rename_weights_for_resnet(weights, stage_names):
layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys]

# performs basic renaming: _ -> . , etc
layer_keys = _rename_basic_resnet_weights(layer_keys)
layer_keys = _rename_basic_resnet_weights(layer_keys, no_bn_fuse)

# FPN
layer_keys = _rename_fpn_weights(layer_keys, stage_names)
Expand Down Expand Up @@ -244,23 +197,37 @@ def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-50-FPN-SYNCBN")
@C2_FORMAT_LOADER.register("R-101-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-152-FPN")
def load_resnet_c2_format(cfg, f):
state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "").replace("-SYNCBN", "")
arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "")
arch = arch.replace("-RETINANET", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages)
state_dict = _rename_weights_for_resnet(state_dict, stages, False)
# ***********************************
# for deformable convolutional layer
state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
# ***********************************
return dict(model=state_dict)


@C2_FORMAT_LOADER.register("R-50-FPN-SYNCBN")
@C2_FORMAT_LOADER.register("R-50-FPN-MABN")
def load_resnet_c2_format_bn_no_fuse(cfg, f):
state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "").replace("-SYNCBN", "").replace("-MABN", "")
arch = arch.replace("-RETINANET", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages, True)
# ***********************************
# for deformable convolutional layer
state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
# ***********************************
return dict(model=state_dict)

def load_c2_format(cfg, f):
return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)

0 comments on commit c23e676

Please sign in to comment.