-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bfc1e12
commit cfef16b
Showing
1 changed file
with
117 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
''' | ||
Utilities to match ground truth boxes to anchor boxes. | ||
Copyright (C) 2018 Pierluigi Ferrari | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
''' | ||
|
||
from __future__ import division | ||
import numpy as np | ||
|
||
def match_bipartite_greedy(weight_matrix):
    '''
    Returns a bipartite matching according to the given weight matrix.

    The algorithm works as follows:

    Let the first axis of `weight_matrix` represent ground truth boxes
    and the second axis anchor boxes.
    The ground truth box that has the greatest similarity with any
    anchor box will be matched first, then out of the remaining ground
    truth boxes, the ground truth box that has the greatest similarity
    with any of the remaining anchor boxes will be matched second, and
    so on. That is, the ground truth boxes will be matched in descending
    order by maximum similarity with any of the respectively remaining
    anchor boxes.

    Ground truth boxes and anchor boxes that have already been matched are
    masked with negative infinity, so they can never be selected again, even
    if all remaining weights are zero or negative. Ties are resolved in favor
    of the lowest index.

    The runtime complexity is O(m^2 * n), where `m` is the number of
    ground truth boxes and `n` is the number of anchor boxes.

    Arguments:
        weight_matrix (array): A 2D Numpy array that represents the weight matrix
            for the matching process. If `(m,n)` is the shape of the weight matrix,
            it must be `m <= n`. The weights can be integers or floating point
            numbers. The matching process will maximize, i.e. larger weights are
            preferred over smaller weights. The array is not modified.

    Returns:
        A 1D Numpy integer array of length `weight_matrix.shape[0]` that represents
        the matched index along the second axis of `weight_matrix` for each index
        along the first axis. If no anchor box is left to match (e.g. because
        `m > n`), the entries of the ground truth boxes that could not be matched
        keep their initial value of zero.
    '''

    # Work on a floating point copy: the caller's array stays untouched and
    # `-inf` is available as a mask value even for integer weight matrices.
    weights = np.array(weight_matrix, dtype=float)
    num_ground_truth_boxes = weights.shape[0]
    all_gt_indices = np.arange(num_ground_truth_boxes)  # Only relevant for fancy-indexing below.

    # This 1D array will contain for each ground truth box the index of
    # the matched anchor box. The builtin `int` replaces the `np.int` alias,
    # which no longer exists in NumPy >= 1.24.
    matches = np.zeros(num_ground_truth_boxes, dtype=int)

    # In each iteration of the loop below, exactly one ground truth box
    # will be matched to one anchor box.
    for _ in range(num_ground_truth_boxes):

        # Find the maximal anchor-ground truth pair in two steps: First, reduce
        # over the anchor boxes and then reduce over the ground truth boxes.
        anchor_indices = np.argmax(weights, axis=1)  # Reduce along the anchor box axis.
        overlaps = weights[all_gt_indices, anchor_indices]
        ground_truth_index = np.argmax(overlaps)  # Reduce along the ground truth box axis.

        # If the best remaining weight is the mask value, there is nothing
        # left to match; stop instead of overwriting an existing match.
        if overlaps[ground_truth_index] == -np.inf:
            break

        anchor_index = anchor_indices[ground_truth_index]
        matches[ground_truth_index] = anchor_index  # Set the match.

        # Mask the row of the matched ground truth box and the column of the
        # matched anchor box with negative infinity. Unlike masking with zero,
        # this guarantees that those boxes are never selected again, because
        # every remaining finite weight compares greater.
        weights[ground_truth_index] = -np.inf
        weights[:, anchor_index] = -np.inf

    return matches
|
||
def match_multi(weight_matrix, threshold):
    '''
    Matches every element along the second axis of `weight_matrix` to its best
    match along the first axis, keeping only those pairs whose weight is greater
    than or equal to `threshold`.

    For each column, the row with the largest weight is selected (the lowest row
    index wins in case of a tie). Columns whose largest weight is below
    `threshold` produce no match. To exclude an element from the matching, set
    the row or column representing it to a value below `threshold`.

    Arguments:
        weight_matrix (array): A 2D Numpy array of shape `(m,n)` that holds the
            weights for the matching process. The weights can be integers or
            floating point numbers. Larger weights are preferred over smaller
            weights. The array is not modified.
        threshold (float): The lower bound that the weight of a pair of elements
            must meet in order to produce a match.

    Returns:
        Two 1D Numpy arrays of equal length that represent the matched indices. The
        first array contains the indices along the first axis of `weight_matrix`, the
        second array contains the corresponding indices along the second axis in
        ascending order.
    '''

    column_ids = np.arange(weight_matrix.shape[1])  # One index per column, for fancy-indexing.

    # Per column: the row holding the largest weight, and that weight itself.
    best_rows = np.argmax(weight_matrix, axis=0)  # Array of shape (weight_matrix.shape[1],)
    best_weights = weight_matrix[best_rows, column_ids]  # Array of shape (weight_matrix.shape[1],)

    # Keep only the columns whose best weight meets the threshold.
    kept_columns = np.flatnonzero(best_weights >= threshold)

    return best_rows[kept_columns], kept_columns