diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index a2a8ab0cb0..5bbe66634e 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -67,7 +67,7 @@ def __call__(self, results): return results -def inference_segmentor(model, img): +def inference_segmentor(model, imgs): """Inference image(s) with the segmentor. Args: @@ -84,9 +84,13 @@ def inference_segmentor(model, img): test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data - data = dict(img=img) - data = test_pipeline(data) - data = collate([data], samples_per_gpu=1) + data = [] + imgs = imgs if isinstance(imgs, list) else [imgs] + for img in imgs: + img_data = dict(img=img) + img_data = test_pipeline(img_data) + data.append(img_data) + data = collate(data, samples_per_gpu=len(imgs)) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0]