Skip to content

Commit

Permalink
chore: Turn warning into error
Browse files Browse the repository at this point in the history
Degenerate ground truth boxes will lead to NaN errors during the training, so the warning that warns about them should really be an error.
  • Loading branch information
pierluigiferrari committed Apr 9, 2018
1 parent cfef16b commit 10f867d
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions ssd_encoder_decoder/ssd_input_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from __future__ import division
import numpy as np
import warnings

from bounding_box_utils.bounding_box_utils import iou, convert_coordinates
from ssd_encoder_decoder.matching_utils import match_bipartite_greedy, match_multi
Expand Down Expand Up @@ -304,11 +303,9 @@ def __call__(self, ground_truth_labels, diagnostics=False):

# Check for degenerate ground truth bounding boxes before attempting any computations.
if np.any(labels[:,[xmax]] - labels[:,[xmin]] <= 0) or np.any(labels[:,[ymax]] - labels[:,[ymin]] <= 0):
warnings.warn("SSDInputEncoder detected degenerate ground truth bounding boxes for batch item {} with bounding boxes {}, ".format(i, labels) +
"i.e. bounding boxes where xmax <= xmin and/or ymax <= ymin. " +
"This means that your dataset either contains bad ground truth data or that you are passing ground truth in the wrong coordinate " +
"format. Note that SSDInputEncoder expects the box coordinates to be in the format (xmin, ymin, xmax, ymax). Degenerate ground truth " +
"bounding boxes may lead to NaN errors during the training.")
raise DegenerateBoxError("SSDInputEncoder detected degenerate ground truth bounding boxes for batch item {} with bounding boxes {}, ".format(i, labels) +
"i.e. bounding boxes where xmax <= xmin and/or ymax <= ymin. Degenerate ground truth " +
"bounding boxes will lead to NaN errors during the training.")

# Maybe normalize the box coordinates.
if self.normalize_coords:
Expand Down Expand Up @@ -584,3 +581,9 @@ def generate_encoding_template(self, batch_size, diagnostics=False):
return y_encoding_template, self.centers_diag, self.wh_list_diag, self.steps_diag, self.offsets_diag
else:
return y_encoding_template

class DegenerateBoxError(Exception):
'''
An exception class to be raised if degenerate boxes are being detected.
'''
pass

0 comments on commit 10f867d

Please sign in to comment.