
Low instance segmentation AP result on Cityscapes? #611

Closed
chenchr opened this issue Mar 27, 2019 · 19 comments

@chenchr

chenchr commented Mar 27, 2019

Hello, has anyone trained or fine-tuned Mask R-CNN on Cityscapes?
I find that my results do not match the ones reported in the Mask R-CNN paper.
For ResNet-50-FPN trained from ImageNet pre-trained weights, I got 22.55 segmentation AP on the validation set.
For the model fine-tuned from COCO pre-trained weights, I got 25.46 AP.
However, the corresponding results reported in the paper are 31.5 and 36.4, so there is still a large gap.
So far, the only difference I have found is that I train the model with batch size 16 on 4 GPUs, while the Mask R-CNN paper uses batch size 8 on 8 GPUs.
But I don't think that is the cause, since a larger batch size usually gives better results.
I am going to train with batch size 8 tomorrow. Any suggestions?
Thanks!

@chengyangfu
Contributor

chengyangfu commented Mar 27, 2019

Starting from the model pre-trained on COCO, I can get around 35.0 mask mAP with ResNet-50, still 1~2 mAP less than the numbers reported in the Mask R-CNN paper.

Did you change the image size for training and inference to the following?

INPUT:
      MIN_SIZE_TRAIN: (800, 1024)
      MAX_SIZE_TRAIN: 2048
      MIN_SIZE_TEST: 1024
      MAX_SIZE_TEST: 2048

@chenchr
Author

chenchr commented Mar 28, 2019

Starting from the model pre-trained on COCO, I can get around 35.0 mask mAP with ResNet-50, still 1~2 mAP less than the numbers reported in the Mask R-CNN paper.

Did you change the image size for training and inference to the following?

INPUT:
      MIN_SIZE_TRAIN: (800, 1024)
      MAX_SIZE_TRAIN: 2048
      MIN_SIZE_TEST: 1024
      MAX_SIZE_TEST: 2048

Thank you. I will retry.
Another odd thing is that I get a higher AP of 27.76 after training with batch size 8, while with batch size 16 I only get 25.46 AP.

@chenchr
Author

chenchr commented Mar 28, 2019

With the correct test image size, I get 33.5 AP, just 1.9 below the official result.

@zimenglan-sysu-512
Contributor

zimenglan-sysu-512 commented Mar 29, 2019

hi @chenchr
Here is my input setting. I use a single GPU or 4 GPUs to train on Cityscapes and get the same result either way: 35.7 mask AP, still a small gap below the number (36.4) reported in the Mask R-CNN paper.

  MIN_SIZE_TRAIN: (800, 1024)
  MAX_SIZE_TRAIN: 2048
  MIN_SIZE_TEST: 1024
  MAX_SIZE_TEST: 2048

@chenchr
Author

chenchr commented Mar 29, 2019

@zimenglan-sysu-512 Hi. Is the model trained from ImageNet weights or from COCO?

@zimenglan-sysu-512
Contributor

@chenchr
I fine-tune the model from COCO.

@zimenglan-sysu-512
Contributor

hi @chenchr

Here is my script to do model surgery on a model trained on COCO:

import torch
from torch import nn

from maskrcnn_benchmark.layers import Conv2d


def clip_weights_from_pretrain_of_coco_to_cityscapes(f, out_file):
    """Remap the class-specific heads (cls_score, bbox_pred, mask_fcn_logits)
    of a COCO-trained checkpoint to the Cityscapes categories: weights of
    categories shared by both datasets are copied over, the rest are
    re-initialized."""
    from maskrcnn_benchmark.config.paths_catalog import COCO_CATEGORIES
    from maskrcnn_benchmark.config.paths_catalog import CITYSCAPES_FINE_CATEGORIES
    coco_cats = COCO_CATEGORIES
    cityscapes_cats = CITYSCAPES_FINE_CATEGORIES
    coco_cats_to_inds = dict(zip(coco_cats, range(len(coco_cats))))

    checkpoint = torch.load(f)
    m = checkpoint['model']

    # Keys assume the checkpoint was saved from a (Distributed)DataParallel
    # model, hence the "module." prefix.
    weight_names = {
        "cls_score": "module.roi_heads.box.predictor.cls_score.weight",
        "bbox_pred": "module.roi_heads.box.predictor.bbox_pred.weight",
        "mask_fcn_logits": "module.roi_heads.mask.predictor.mask_fcn_logits.weight",
    }
    bias_names = {
        "cls_score": "module.roi_heads.box.predictor.cls_score.bias",
        "bbox_pred": "module.roi_heads.box.predictor.bbox_pred.bias",
        "mask_fcn_logits": "module.roi_heads.mask.predictor.mask_fcn_logits.bias",
    }

    # Build new heads sized for the Cityscapes categories, initialized the
    # same way as in the model builders.
    representation_size = m[weight_names["cls_score"]].size(1)
    cls_score = nn.Linear(representation_size, len(cityscapes_cats))
    nn.init.normal_(cls_score.weight, std=0.01)
    nn.init.constant_(cls_score.bias, 0)

    representation_size = m[weight_names["bbox_pred"]].size(1)
    class_agnostic = m[weight_names["bbox_pred"]].size(0) != len(coco_cats) * 4
    num_bbox_reg_classes = 2 if class_agnostic else len(cityscapes_cats)
    bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
    nn.init.normal_(bbox_pred.weight, std=0.001)
    nn.init.constant_(bbox_pred.bias, 0)

    dim_reduced = m[weight_names["mask_fcn_logits"]].size(1)
    mask_fcn_logits = Conv2d(dim_reduced, len(cityscapes_cats), 1, 1, 0)
    nn.init.constant_(mask_fcn_logits.bias, 0)
    nn.init.kaiming_normal_(
        mask_fcn_logits.weight, mode="fan_out", nonlinearity="relu"
    )

    # Copy rows for categories that exist in both datasets; the remaining
    # rows keep their fresh initialization.
    def _copy_weight(src_weight, dst_weight):
        for ix, cat in enumerate(cityscapes_cats):
            if cat not in coco_cats:
                continue
            jx = coco_cats_to_inds[cat]
            dst_weight[ix] = src_weight[jx]
        return dst_weight

    def _copy_bias(src_bias, dst_bias, class_agnostic=False):
        if class_agnostic:
            return dst_bias
        return _copy_weight(src_bias, dst_bias)

    # Write into .data so the in-place copies bypass autograd's check on
    # leaf tensors that require grad.
    m[weight_names["cls_score"]] = _copy_weight(
        m[weight_names["cls_score"]], cls_score.weight.data
    )
    m[weight_names["bbox_pred"]] = _copy_weight(
        m[weight_names["bbox_pred"]], bbox_pred.weight.data
    )
    m[weight_names["mask_fcn_logits"]] = _copy_weight(
        m[weight_names["mask_fcn_logits"]], mask_fcn_logits.weight.data
    )

    m[bias_names["cls_score"]] = _copy_bias(
        m[bias_names["cls_score"]], cls_score.bias.data
    )
    m[bias_names["bbox_pred"]] = _copy_bias(
        m[bias_names["bbox_pred"]], bbox_pred.bias.data, class_agnostic
    )
    m[bias_names["mask_fcn_logits"]] = _copy_bias(
        m[bias_names["mask_fcn_logits"]], mask_fcn_logits.bias.data
    )

    # Save only the remapped state dict; it can be loaded via MODEL.WEIGHT.
    print("f: {}\nout_file: {}".format(f, out_file))
    torch.save(m, out_file)
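
For reference, a minimal sketch of how the script above might be invoked; the checkpoint and output paths are hypothetical, and the converted file can then be used as MODEL.WEIGHT in the Cityscapes config:

# Hypothetical paths; substitute your own COCO checkpoint and output location.
coco_checkpoint = "e2e_mask_rcnn_R_50_FPN_1x.pth"
converted = "mask_rcnn_R_50_FPN_coco2cityscapes.pth"
clip_weights_from_pretrain_of_coco_to_cityscapes(coco_checkpoint, converted)

# Then, in the Cityscapes config:
# MODEL:
#   WEIGHT: "mask_rcnn_R_50_FPN_coco2cityscapes.pth"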

@zimenglan-sysu-512
Contributor

related to #259 #378

@shinya7y

Considering the Mask R-CNN paper and the configs for keypoint_rcnn in Detectron, the following setting seems suitable for reproduction.

MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024)
MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024
MAX_SIZE_TEST: 2048
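
For clarity, a tuple in MIN_SIZE_TRAIN enables scale augmentation: for each training image one of the listed shorter-side sizes is picked at random, and the longer side is capped by MAX_SIZE_TRAIN. A rough sketch of this behavior (illustrative only, not the library's own code):

import random

def pick_train_size(image_w, image_h,
                    min_sizes=(800, 832, 864, 896, 928, 960, 992, 1024),
                    max_size=2048):
    # Randomly choose a target for the shorter side (scale augmentation).
    size = random.choice(min_sizes)
    short, long_side = min(image_w, image_h), max(image_w, image_h)
    # If scaling the shorter side to `size` would push the longer side
    # past max_size, shrink the target accordingly.
    if size / short * long_side > max_size:
        size = int(round(max_size * short / long_side))
    scale = size / short
    return int(round(image_w * scale)), int(round(image_h * scale))

# A 2048x1024 Cityscapes image with min size 1024 keeps its full 2048x1024 resolution.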

@zimenglan-sysu-512
Contributor

hi @shinya7y
I followed your suggestion for the input settings and trained the model again, and got 0.409 for bbox and 0.362 for mask.
Thank you.

@shinya7y

Hi @zimenglan-sysu-512
Glad to hear that! Would you please specify COCO_CATEGORIES and CITYSCAPES_FINE_CATEGORIES for ease of reproduction?

@chenchr
Author

chenchr commented Apr 18, 2019

@chengyangfu @zimenglan-sysu-512 @shinya7y Thank you all.
I changed the image size but only get 0.378 AP for box and 0.335 for segm.
Are there any other changes that should be made?
Thanks.

@zimenglan-sysu-512
Contributor

hi @shinya7y
COCO_CATEGORIES is the 81 classes of COCO, and CITYSCAPES_FINE_CATEGORIES is the 9 classes of Cityscapes.

hi @chenchr
Please use the code above to convert the COCO pre-trained model, use it to initialize your model, and then set the input as @shinya7y suggests.
Hope it helps.

@fmassa
Contributor

fmassa commented Apr 19, 2019

@zimenglan-sysu-512 would you mind sending a PR improving the documentation / adding the scripts that you showed in this thread into maskrcnn-benchmark?

@shinya7y

@zimenglan-sysu-512 Thank you.
I wonder whether the names of the background class are the same or not.
Transferring the COCO background weights to the Cityscapes background may slightly harm accuracy.
In any case, a PR will be very helpful.

@chenchr Please try to use the following solver setting.

SOLVER:
  BASE_LR: 0.01  # (1)
  IMS_PER_BATCH: 8  # (1)
  WEIGHT_DECAY: 0.0001
  STEPS: (3000,)  # (2)
  MAX_ITER: 4000  # (2)

(1) The batch size should be 8 for reproduction.
The "linear scaling rule" ( https://github.com/facebookresearch/Detectron/blob/master/GETTING_STARTED.md#2-multi-gpu-training ) is useful but not perfect ( https://arxiv.org/abs/1811.03600 ).
(2) The iteration settings should be changed when starting from the COCO pre-trained model ( https://arxiv.org/abs/1703.06870 ).
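
A small sketch of the linear scaling rule applied to this schedule; the single-GPU batch size of 2 below is only an example:

def scale_schedule(base_lr=0.01, steps=(3000,), max_iter=4000,
                   base_batch=8, new_batch=2):
    # Linear scaling rule: scale the learning rate by new_batch / base_batch
    # and stretch the schedule by the inverse factor, so training still sees
    # roughly the same number of images.
    factor = new_batch / base_batch
    return (base_lr * factor,
            tuple(int(s / factor) for s in steps),
            int(max_iter / factor))

# With IMS_PER_BATCH: 2 this gives BASE_LR: 0.0025, STEPS: (12000,), MAX_ITER: 16000.
print(scale_schedule())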

@zimenglan-sysu-512
Contributor

hi @shinya7y and @fmassa, I sent PR #697 for fine-tuning.

@fmassa
Contributor

fmassa commented Apr 20, 2019

I believe this issue can now be closed thanks to @zimenglan-sysu-512's PR #697.

@fmassa fmassa closed this as completed Apr 20, 2019
@14211019

14211019 commented May 31, 2019

@zimenglan-sysu-512 Thank you.
I wonder whether the names of the background class are the same or not.
Transferring the COCO background weights to the Cityscapes background may slightly harm accuracy.
In any case, a PR will be very helpful.

@chenchr Please try to use the following solver setting.

SOLVER:
  BASE_LR: 0.01  # (1)
  IMS_PER_BATCH: 8  # (1)
  WEIGHT_DECAY: 0.0001
  STEPS: (3000,)  # (2)
  MAX_ITER: 4000  # (2)

(1) The batch size should be 8 for reproduction.
The "linear scaling rule" ( https://github.com/facebookresearch/Detectron/blob/master/GETTING_STARTED.md#2-multi-gpu-training ) is useful but not perfect ( https://arxiv.org/abs/1811.03600 ).
(2) The iteration settings should be changed when starting from the COCO pre-trained model ( https://arxiv.org/abs/1703.06870 ).

hi, @shinya7y
I only have one GPU with batch size 2; could you tell me how to set the following parameters?

SOLVER:
  BASE_LR: 0.01  # (1)
  IMS_PER_BATCH: 8  # (1)
  WEIGHT_DECAY: 0.0001
  STEPS: (3000,)  # (2)
  MAX_ITER: 4000  # (2)

thanks a lot!

I set them to

SOLVER:
  # BASE_LR: 0.01  # batch=8
  # IMS_PER_BATCH: 8
  # STEPS: (3000,)  # batch=8
  # MAX_ITER: 4000  # batch=8
  BASE_LR: 0.0025
  WEIGHT_DECAY: 0.0001
  IMS_PER_BATCH: 2
  STEPS: (12000,)
  MAX_ITER: 16000
INPUT:
  MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024, 1024)
  MAX_SIZE_TRAIN: 2048
  MIN_SIZE_TEST: 1024
  MAX_SIZE_TEST: 2048

and I set MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN to 2000.

Is this right? I'm training now and waiting for the result.

@botcs
Contributor

botcs commented Sep 23, 2019

Hi,
I could reproduce the reported results in PR #1090.
