diff --git a/models/common.py b/models/common.py index 5c82b18f102c..7ac3a4a29672 100644 --- a/models/common.py +++ b/models/common.py @@ -457,7 +457,7 @@ def wrap_frozen_graph(gd, inputs, outputs): self.__dict__.update(locals()) # assign all variables to self - def forward(self, im, augment=False, visualize=False, val=False): + def forward(self, im, augment=False, visualize=False): # YOLOv5 MultiBackend inference b, ch, h, w = im.shape # batch, channel, height, width if self.fp16 and im.dtype != torch.float16: @@ -521,10 +521,12 @@ def forward(self, im, augment=False, visualize=False, val=False): y[..., :4] *= [w, h, w, h] # xywh normalized to pixels if isinstance(y, (list, tuple)): - y = y[0] - if isinstance(y, np.ndarray): - y = torch.from_numpy(y).to(self.device) - return (y, []) if val else y + return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y] + else: + return self.from_numpy(y) + + def from_numpy(self, x): + return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once diff --git a/utils/general.py b/utils/general.py index 25a1a1456009..cae63fd9dd21 100755 --- a/utils/general.py +++ b/utils/general.py @@ -813,6 +813,9 @@ def non_max_suppression(prediction, list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ + if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation mode, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + bs = prediction.shape[0] # batch size nc = prediction.shape[2] - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates diff --git a/val.py b/val.py index 58b9c9e1bec0..5427ee7b3619 100644 --- a/val.py +++ b/val.py @@ -204,11 +204,11 @@ def run( # Inference with dt[1]: - out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs + out, train_out = model(im) if compute_loss 
else (model(im, augment=augment), None) # Loss if compute_loss: - loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls + loss += compute_loss(train_out, targets)[1] # box, obj, cls # NMS targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels