Low instance segmentation AP result on Cityscapes? #611
If starting from a model pre-trained on COCO, I can get around 35.0 (mask) mAP with ResNet-50, still 1~2 mAP below the numbers reported in the Mask R-CNN paper. Did you change the image size for training and inference to the following?
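For reference, maskrcnn-benchmark picks the actual image size by scaling the shorter side to `MIN_SIZE_TEST` while capping the longer side at `MAX_SIZE_TEST`. Here is a small self-contained sketch of that resize rule (the function name and exact rounding are illustrative, mirroring the library's `Resize` transform); note how the COCO-default sizes (800/1333) would shrink a 2048x1024 Cityscapes frame:

```python
def get_resized_size(w, h, min_size, max_size):
    """Scale the shorter side to min_size, capping the longer side at
    max_size (a sketch of maskrcnn-benchmark's Resize behavior)."""
    size = min_size
    if max_size is not None:
        min_orig = float(min((w, h)))
        max_orig = float(max((w, h)))
        # If scaling the short side would push the long side past
        # max_size, shrink the target size instead.
        if max_orig / min_orig * size > max_size:
            size = int(round(max_size * min_orig / max_orig))
    if (w <= h and w == size) or (h <= w and h == size):
        return (w, h)
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return (ow, oh)

# A 2048x1024 Cityscapes frame at Cityscapes-style test sizes is kept intact:
print(get_resized_size(2048, 1024, 1024, 2048))  # -> (2048, 1024)
# The same frame at COCO-default test sizes is shrunk considerably:
print(get_resized_size(2048, 1024, 800, 1333))   # -> (1332, 666)
```

So evaluating a Cityscapes-trained model with the COCO-default `INPUT` sizes alone can cost a noticeable amount of AP.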
Thank you. I will retry.
With the correct test image size, I get 33.5 AP, still 1.9 below the official result.
hi @chenchr
Hi @zimenglan-sysu-512. Is the model trained from ImageNet weights or from COCO?
@chenchr
hi @chenchr here is my script to do model surgery on a model trained on COCO:

```python
import torch
from torch import nn

# Conv2d here is maskrcnn_benchmark.layers.Conv2d
from maskrcnn_benchmark.layers import Conv2d


def clip_weights_from_pretrain_of_coco_to_cityscapes(f, out_file):
    """Adapt the COCO-trained predictor heads of a checkpoint to the
    Cityscapes categories and save the surgered state dict."""
    from maskrcnn_benchmark.config.paths_catalog import COCO_CATEGORIES
    from maskrcnn_benchmark.config.paths_catalog import CITYSCAPES_FINE_CATEGORIES
    coco_cats = COCO_CATEGORIES
    cityscapes_cats = CITYSCAPES_FINE_CATEGORIES
    coco_cats_to_inds = dict(zip(coco_cats, range(len(coco_cats))))
    cityscapes_cats_to_inds = dict(
        zip(cityscapes_cats, range(len(cityscapes_cats)))
    )

    checkpoint = torch.load(f)
    m = checkpoint['model']

    weight_names = {
        "cls_score": "module.roi_heads.box.predictor.cls_score.weight",
        "bbox_pred": "module.roi_heads.box.predictor.bbox_pred.weight",
        "mask_fcn_logits": "module.roi_heads.mask.predictor.mask_fcn_logits.weight",
    }
    bias_names = {
        "cls_score": "module.roi_heads.box.predictor.cls_score.bias",
        "bbox_pred": "module.roi_heads.box.predictor.bbox_pred.bias",
        "mask_fcn_logits": "module.roi_heads.mask.predictor.mask_fcn_logits.bias",
    }

    # Freshly initialized heads sized for the Cityscapes categories.
    representation_size = m[weight_names["cls_score"]].size(1)
    cls_score = nn.Linear(representation_size, len(cityscapes_cats))
    nn.init.normal_(cls_score.weight, std=0.01)
    nn.init.constant_(cls_score.bias, 0)

    representation_size = m[weight_names["bbox_pred"]].size(1)
    class_agnostic = m[weight_names["bbox_pred"]].size(0) != len(coco_cats) * 4
    num_bbox_reg_classes = 2 if class_agnostic else len(cityscapes_cats)
    bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
    nn.init.normal_(bbox_pred.weight, std=0.001)
    nn.init.constant_(bbox_pred.bias, 0)

    dim_reduced = m[weight_names["mask_fcn_logits"]].size(1)
    mask_fcn_logits = Conv2d(dim_reduced, len(cityscapes_cats), 1, 1, 0)
    nn.init.constant_(mask_fcn_logits.bias, 0)
    nn.init.kaiming_normal_(
        mask_fcn_logits.weight, mode="fan_out", nonlinearity="relu"
    )

    # Copy the COCO-trained rows for categories shared by both datasets;
    # Cityscapes-only categories keep their fresh initialization.
    def _copy_weight(src_weight, dst_weight):
        for ix, cat in enumerate(cityscapes_cats):
            if cat not in coco_cats:
                continue
            jx = coco_cats_to_inds[cat]
            dst_weight[ix] = src_weight[jx]
        return dst_weight

    def _copy_bias(src_bias, dst_bias, class_agnostic=False):
        if class_agnostic:
            return dst_bias
        return _copy_weight(src_bias, dst_bias)

    m[weight_names["cls_score"]] = _copy_weight(
        m[weight_names["cls_score"]], cls_score.weight
    )
    m[weight_names["bbox_pred"]] = _copy_weight(
        m[weight_names["bbox_pred"]], bbox_pred.weight
    )
    m[weight_names["mask_fcn_logits"]] = _copy_weight(
        m[weight_names["mask_fcn_logits"]], mask_fcn_logits.weight
    )

    m[bias_names["cls_score"]] = _copy_bias(
        m[bias_names["cls_score"]], cls_score.bias
    )
    m[bias_names["bbox_pred"]] = _copy_bias(
        m[bias_names["bbox_pred"]], bbox_pred.bias, class_agnostic
    )
    m[bias_names["mask_fcn_logits"]] = _copy_bias(
        m[bias_names["mask_fcn_logits"]], mask_fcn_logits.bias
    )

    print("f: {}\nout_file: {}".format(f, out_file))
    torch.save(m, out_file)
```
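The core of the surgery above is just index matching between the two category lists: rows of the predictor weights for classes present in both datasets are reused, while new classes keep their fresh initialization. A minimal standalone sketch of that idea, using short hypothetical category lists (not the real catalogs) and strings in place of weight rows:

```python
# Hypothetical category lists for illustration only:
coco_cats = ["__background__", "person", "car", "bicycle", "dog"]
cityscapes_cats = ["__background__", "person", "rider", "car", "bicycle"]

coco_inds = {c: i for i, c in enumerate(coco_cats)}

def copy_rows(src, dst):
    # For every Cityscapes class that also exists in COCO, reuse the
    # COCO-trained row; classes like "rider" keep their fresh init.
    for ix, cat in enumerate(cityscapes_cats):
        if cat in coco_inds:
            dst[ix] = src[coco_inds[cat]]
    return dst

src = ["coco_row_{}".format(i) for i in range(len(coco_cats))]
dst = ["new_row_{}".format(i) for i in range(len(cityscapes_cats))]
print(copy_rows(src, dst))
# -> ['coco_row_0', 'coco_row_1', 'new_row_2', 'coco_row_2', 'coco_row_3']
```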
Considering the Mask R-CNN paper and the keypoint_rcnn configs in Detectron, the following settings seem suitable for reproduction.
hi @shinya7y |
Hi @zimenglan-sysu-512 |
@chengyangfu @zimenglan-sysu-512 @shinya7y Thank you all.
@zimenglan-sysu-512 would you mind sending a PR improving the documentation / adding the scripts that you showed in this thread into
@zimenglan-sysu-512 Thank you. @chenchr Please try the following solver setting.
(1) The batch size should be 8 for reproduction.
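If you cannot fit batch size 8 as-is, the linear scaling rule is the usual adjustment: scale the learning rate proportionally to the batch size and the iteration counts inversely, keeping the number of epochs fixed. A sketch (the function name is mine; the base schedule of lr 0.01, 24k iterations with a step at 18k, at batch size 8, follows the Mask R-CNN paper's Cityscapes setup):

```python
def scale_schedule(base_lr, max_iter, steps, base_bs, new_bs):
    """Apply the linear scaling rule: LR grows with batch size,
    iteration counts shrink by the same factor."""
    factor = new_bs / base_bs
    return (
        base_lr * factor,                       # scaled BASE_LR
        int(max_iter / factor),                 # scaled MAX_ITER
        tuple(int(s / factor) for s in steps),  # scaled STEPS
    )

# Paper schedule at batch size 8, rescaled for batch size 16:
print(scale_schedule(0.01, 24000, (18000,), 8, 16))
# -> (0.02, 12000, (9000,))
```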
I believe this issue can now be closed thanks to @zimenglan-sysu-512's PR #697
Hi @shinya7y, thanks a lot! I set them under SOLVER: and I set MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN to 2000. Is that right? I'm training now and waiting for a good result.
Hi, |
Hello, did anyone train or finetune Mask R-CNN on Cityscapes?
I just found that the result cannot match the one reported in the Mask R-CNN paper.
For ResNet-50-FPN trained from ImageNet pretrained weights, I got 22.55 segmentation AP on the validation set.
For the model finetuned from COCO pretrained weights, I got 25.46 AP.
However, these two results reported in the paper are 31.5 and 36.4.
There is still a large gap.
Up to now, the only difference I have found is that I train the model with a batch size of 16 on 4 GPUs, while the Mask R-CNN paper uses a batch size of 8 on 8 GPUs.
But I don't think that is the cause, as a larger batch size usually gives better results.
I am going to train with batch size 8 tomorrow; any suggestions?
Thanks!