From ad1c871bea02f3da5b4ed0955bb17983e58493ca Mon Sep 17 00:00:00 2001 From: alibool <65023386+alibool@users.noreply.github.com> Date: Mon, 14 Mar 2022 20:33:13 +0800 Subject: [PATCH] fix indexerror for retinanet in torch1.x `keep_ix / fg_probs.shape[1]` returns a float tensor which can not be used as index with torch 1.x. --- models/retina_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/retina_net.py b/models/retina_net.py index 31140b4..4eff65e 100644 --- a/models/retina_net.py +++ b/models/retina_net.py @@ -209,7 +209,7 @@ def refine_detections(anchors, probs, deltas, batch_ixs, cf): flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. - keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) + keep_arr = torch.cat(((keep_ix // fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again.