From ef5e366f3e070b4af4718bcc0cdb044bb5579a25 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 4 Sep 2022 14:09:04 +0200 Subject: [PATCH 1/6] Update DetectMultiBackend for tuple outputs 2 Signed-off-by: Glenn Jocher --- models/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 5c82b18f102c..b65fbf664799 100644 --- a/models/common.py +++ b/models/common.py @@ -520,7 +520,7 @@ def forward(self, im, augment=False, visualize=False, val=False): y = (y.astype(np.float32) - zero_point) * scale # re-scale y[..., :4] *= [w, h, w, h] # xywh normalized to pixels - if isinstance(y, (list, tuple)): + if isinstance(y, (list, tuple)) and len(y) == 1: y = y[0] if isinstance(y, np.ndarray): y = torch.from_numpy(y).to(self.device) From d373e0b9ed760910859f7a430bb72c8f37a3abb6 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Sun, 4 Sep 2022 14:55:48 +0200 Subject: [PATCH 2/6] Update --- models/common.py | 12 +++++++----- utils/general.py | 3 +++ val.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/models/common.py b/models/common.py index b65fbf664799..2005b3413093 100644 --- a/models/common.py +++ b/models/common.py @@ -520,11 +520,13 @@ def forward(self, im, augment=False, visualize=False, val=False): y = (y.astype(np.float32) - zero_point) * scale # re-scale y[..., :4] *= [w, h, w, h] # xywh normalized to pixels - if isinstance(y, (list, tuple)) and len(y) == 1: - y = y[0] - if isinstance(y, np.ndarray): - y = torch.from_numpy(y).to(self.device) - return (y, []) if val else y + if isinstance(y, (list, tuple)): + return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y] + else: + return self.from_numpy(y) + + def from_numpy(self, x): + return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once diff --git a/utils/general.py b/utils/general.py index 25a1a1456009..cae63fd9dd21 100755 --- a/utils/general.py +++ b/utils/general.py @@ -813,6 +813,9 @@ def non_max_suppression(prediction, list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ + if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + bs = prediction.shape[0] # batch size nc = prediction.shape[2] - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates diff --git a/val.py b/val.py index 58b9c9e1bec0..0566f04bf9f8 100644 --- a/val.py +++ b/val.py @@ -204,7 +204,7 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs + out, train_out = model(im) if training else model(im, augment=augment, val=True), None # inference, loss outputs # Loss if compute_loss: From d6b9a9f223028dbaca08e8df53a66262e97b5237 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Sep 2022 12:56:11 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- val.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/val.py b/val.py index 0566f04bf9f8..5a5b6533624a 100644 --- a/val.py +++ b/val.py @@ -204,7 +204,8 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if training else model(im, augment=augment, val=True), None # inference, loss outputs + out, train_out = model(im) if training else model(im, augment=augment, + val=True), None # inference, loss outputs # Loss if compute_loss: From ca9a530497c07c23fd4b872c91f0c95dea7308a3 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Sun, 4 Sep 2022 15:05:58 +0200 Subject: [PATCH 4/6] Update --- models/common.py | 2 +- val.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/common.py b/models/common.py index 2005b3413093..7ac3a4a29672 100644 --- a/models/common.py +++ b/models/common.py @@ -457,7 +457,7 @@ def wrap_frozen_graph(gd, inputs, outputs): self.__dict__.update(locals()) # assign all variables to self - def forward(self, im, augment=False, visualize=False, val=False): + def forward(self, im, augment=False, visualize=False): # YOLOv5 MultiBackend inference b, ch, h, w = im.shape # batch, channel, height, width if self.fp16 and im.dtype != torch.float16: diff --git a/val.py b/val.py index 0566f04bf9f8..614a194cc14f 100644 --- a/val.py +++ b/val.py @@ -124,7 +124,7 @@ def run( compute_loss=None, ): # Initialize/load model and set device - training = model is not None + training = compute_loss is not None if training: # called by train.py device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model half &= device.type != 'cpu' # half precision only supported on CUDA @@ -204,7 +204,7 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if training else model(im, augment=augment, val=True), None # inference, loss outputs + out, train_out = model(im) if training else model(im, augment=augment), None # Loss if compute_loss: From 29d8ccfb8f97884d510d555e36505f5855281718 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Sun, 4 Sep 2022 15:14:12 +0200 Subject: [PATCH 5/6] Update --- val.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/val.py b/val.py index 4863c5571896..534f12e28b6a 100644 --- a/val.py +++ b/val.py @@ -124,7 +124,7 @@ def run( compute_loss=None, ): # Initialize/load model and set device - training = compute_loss is not None + training = model is not None if training: # called by train.py device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model half &= device.type != 'cpu' # half precision only supported on CUDA @@ -204,7 +204,7 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if training else model(im, augment=augment), None # inference, loss outputs + out, train_out = model(im) if compute_loss else model(im, augment=augment), None # inference, loss outputs # Loss if compute_loss: From b255ac0b7580c94ff40746617af35e688d16d975 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Sun, 4 Sep 2022 15:27:19 +0200 Subject: [PATCH 6/6] Update --- val.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/val.py b/val.py index 534f12e28b6a..5427ee7b3619 100644 --- a/val.py +++ b/val.py @@ -204,11 +204,11 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if compute_loss else model(im, augment=augment), None # inference, loss outputs + out, train_out = model(im) if compute_loss else (model(im, augment=augment), None) # Loss if compute_loss: - loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls + loss += compute_loss(train_out, targets)[1] # box, obj, cls # NMS targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels