Skip to content

Commit

Permalink
[Enhancement] Support input gt seg map is not 2D (#2739)
Browse files Browse the repository at this point in the history
Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

fix #2593

## Modification

1. Only when gt seg map is 2D, extend its shape to 3D PixelData 
2. If seg map is not 2D, we raised warning for users.

---------

Co-authored-by: xiexinch <[email protected]>
  • Loading branch information
MeowZheng and xiexinch committed Mar 14, 2023
1 parent 684d79f commit 6ba4696
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
15 changes: 12 additions & 3 deletions mmseg/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
Expand Down Expand Up @@ -72,9 +74,16 @@ def transform(self, results: dict) -> dict:

data_sample = SegDataSample()
if 'gt_seg_map' in results:
gt_sem_seg_data = dict(
data=to_tensor(results['gt_seg_map'][None,
...].astype(np.int64)))
if results['gt_seg_map'].shape == 2:
data = to_tensor(results['gt_seg_map'][None,
...].astype(np.int64))
else:
warnings.warn('Please pay attention your ground truth '
'segmentation map, usually the segentation '
'map is 2D, but got '
f'{results["gt_seg_map"].shape}')
data = to_tensor(results['gt_seg_map'].astype(np.int64))
gt_sem_seg_data = dict(data=data)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

if 'gt_edge_map' in results:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_datasets/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def test_transform(self):
BaseDataElement)
self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape)
results = copy.deepcopy(self.results)
results['gt_seg_map'] = np.random.rand(3, 300, 400)
results = transform(results)
self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape)

def test_repr(self):
transform = PackSegInputs(meta_keys=self.meta_keys)
Expand Down

0 comments on commit 6ba4696

Please sign in to comment.