Skip to content

Commit

Permalink
Refactoring of Mask-RCNN to put all mask prediction code in third stage.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 192421843
  • Loading branch information
pkulzc committed Apr 13, 2018
1 parent 227f41e commit b47ca97
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,6 @@ def __init__(self,

if self._number_of_stages <= 0 or self._number_of_stages > 3:
raise ValueError('Number of stages should be a value in {1, 2, 3}.')
if self._is_training and self._number_of_stages == 3:
self._number_of_stages = 2

@property
def first_stage_feature_extractor_scope(self):
Expand Down Expand Up @@ -739,9 +737,6 @@ def _predict_second_stage(self, rpn_box_encodings,
of the image.
6) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal.
7) mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
"""
image_shape_2d = self._image_batch_shape_2d(image_shape)
proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
Expand All @@ -757,15 +752,11 @@ def _predict_second_stage(self, rpn_box_encodings,
flattened_proposal_feature_maps,
scope=self.second_stage_feature_extractor_scope))

predict_auxiliary_outputs = False
if self._number_of_stages == 2:
predict_auxiliary_outputs = True
box_predictions = self._mask_rcnn_box_predictor.predict(
[box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=True,
predict_auxiliary_outputs=predict_auxiliary_outputs)
predict_boxes_and_classes=True)

refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS],
Expand All @@ -786,18 +777,16 @@ def _predict_second_stage(self, rpn_box_encodings,
'box_classifier_features': box_classifier_features,
'proposal_boxes_normalized': proposal_boxes_normalized,
}
if box_predictor.MASK_PREDICTIONS in box_predictions:
mask_predictions = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
prediction_dict['mask_predictions'] = mask_predictions

return prediction_dict

def _predict_third_stage(self, prediction_dict, image_shapes):
"""Predicts non-box, non-class outputs using refined detections.
This happens after calling the post-processing stage, such that masks
are only calculated for the top scored boxes.
For training, masks as predicted directly on the box_classifier_features,
which are region-features from the initial anchor boxes.
For inference, this happens after calling the post-processing stage, such
that masks are only calculated for the top scored boxes.
Args:
prediction_dict: a dictionary holding "raw" prediction tensors:
Expand All @@ -819,47 +808,62 @@ def _predict_third_stage(self, prediction_dict, image_shapes):
4) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes in absolute coordinates.
5) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal.
image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing
shapes of images in the batch.
Returns:
prediction_dict: a dictionary that in addition to the input predictions
does hold the following predictions as well:
1) mask_predictions: (optional) a 4-D tensor with shape
1) mask_predictions: a 4-D tensor with shape
[batch_size, max_detection, mask_height, mask_width] containing
instance mask predictions.
"""
detections_dict = self._postprocess_box_classifier(
prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'],
prediction_dict['proposal_boxes'],
prediction_dict['num_proposals'],
image_shapes)
prediction_dict.update(detections_dict)
detection_boxes = detections_dict[
fields.DetectionResultFields.detection_boxes]
detection_classes = detections_dict[
fields.DetectionResultFields.detection_classes]
rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
batch_size = tf.shape(detection_boxes)[0]
max_detection = tf.shape(detection_boxes)[1]
flattened_detected_feature_maps = (
self._compute_second_stage_input_feature_maps(
rpn_features_to_crop, detection_boxes))
detected_box_classifier_features = (
self._feature_extractor.extract_box_classifier_features(
flattened_detected_feature_maps,
scope=self.second_stage_feature_extractor_scope))
box_predictions = self._mask_rcnn_box_predictor.predict(
[detected_box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)
if self._is_training:
curr_box_classifier_features = prediction_dict['box_classifier_features']
detection_classes = prediction_dict['class_predictions_with_background']
box_predictions = self._mask_rcnn_box_predictor.predict(
[curr_box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)
prediction_dict['mask_predictions'] = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
else:
detections_dict = self._postprocess_box_classifier(
prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'],
prediction_dict['proposal_boxes'],
prediction_dict['num_proposals'],
image_shapes)
prediction_dict.update(detections_dict)
detection_boxes = detections_dict[
fields.DetectionResultFields.detection_boxes]
detection_classes = detections_dict[
fields.DetectionResultFields.detection_classes]
rpn_features_to_crop = prediction_dict['rpn_features_to_crop']
batch_size = tf.shape(detection_boxes)[0]
max_detection = tf.shape(detection_boxes)[1]
flattened_detected_feature_maps = (
self._compute_second_stage_input_feature_maps(
rpn_features_to_crop, detection_boxes))
curr_box_classifier_features = (
self._feature_extractor.extract_box_classifier_features(
flattened_detected_feature_maps,
scope=self.second_stage_feature_extractor_scope))

box_predictions = self._mask_rcnn_box_predictor.predict(
[curr_box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)

if box_predictor.MASK_PREDICTIONS in box_predictions:
detection_masks = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)

_, num_classes, mask_height, mask_width = (
detection_masks.get_shape().as_list())
_, max_detection = detection_classes.get_shape().as_list()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks(
with test_graph.as_default():
model = self._build_model(
is_training=True,
number_of_stages=2,
number_of_stages=3,
second_stage_batch_size=7,
predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
Expand Down

0 comments on commit b47ca97

Please sign in to comment.