From ecd1ecb6ba1cca0372d3c3289108e19733012acb Mon Sep 17 00:00:00 2001 From: whooray Date: Wed, 14 Sep 2022 01:13:43 +0900 Subject: [PATCH] [Fix] Fix mmseg.api.inference inference_segmentor (#1849) * [Fix] Fix mmseg.api.inference inference_segmentor Motivation Fix inference_segmentor not working with multiple images path or images. List[str/ndarray] Modification - process images if instance is list * fix typo * Update mmseg/apis/inference.py Co-authored-by: Hakjin Lee Co-authored-by: Hakjin Lee --- mmseg/apis/inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index a2a8ab0cb0..5bbe66634e 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -67,7 +67,7 @@ def __call__(self, results): return results -def inference_segmentor(model, img): +def inference_segmentor(model, imgs): """Inference image(s) with the segmentor. Args: @@ -84,9 +84,13 @@ def inference_segmentor(model, img): test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data - data = dict(img=img) - data = test_pipeline(data) - data = collate([data], samples_per_gpu=1) + data = [] + imgs = imgs if isinstance(imgs, list) else [imgs] + for img in imgs: + img_data = dict(img=img) + img_data = test_pipeline(img_data) + data.append(img_data) + data = collate(data, samples_per_gpu=len(imgs)) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0]