Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix mmseg.api.inference inference_segmentor #1849

Merged
merged 5 commits into from
Sep 13, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[Fix] Fix mmseg.api.inference inference_segmentor
Motivation
Fix inference_segmentor not working with multiple images path or images. List[str/ndarray]

Modification
- process images if instance is list
  • Loading branch information
jinwonkim93 committed Aug 1, 2022
commit 40c8fb081eb1c32f32c1a561366b4a0bd47018b0
17 changes: 13 additions & 4 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, results):
return results


def inference_segmentor(model, img):
def inference_segmentor(model, imgs):
"""Inference image(s) with the segmentor.

Args:
Expand All @@ -84,14 +84,23 @@ def inference_segmentor(model, img):
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if isinstance(imgs, list):
data = []
for img in imgs:
img_data = dict(img=img)
img_data = test_pipeline(img_data)
data.append(img_data)
data = collate(data, samples_per_gpu=len(imgs))
else:
data = dict(img=imgs)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
jinwonkim93 marked this conversation as resolved.
Show resolved Hide resolved
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
data['img_metas'] = [i.data[0] for i in data['img_metas']]
breakpoint()

# forward the model
with torch.no_grad():
Expand Down