From c7ac97d5dba5715785a9bde012f9e4360e0b2c37 Mon Sep 17 00:00:00 2001 From: haruishi <7902640+haruishi43@users.noreply.github.com> Date: Thu, 14 Dec 2023 11:45:36 +0900 Subject: [PATCH] [Feature] add -with-labels arg to inferencer for visualization without labels (#3466) Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation It is difficult to visualize without "labels" when using the inferencer. - While using the `MMSegInferencer`, the visualized prediction contains labels on the mask, but it is difficult to pass `withLabels=False` without rewriting the config (which is harder to do when you initialize the inferencer with a model name rather than the config). - I thought it would be easier to just pass `withLabels=False` to `inferencer.__call__()` since you can also pass `opacity` and other parameters anyway. ## Modification Please briefly describe what modification is made in this PR. - Added `with_labels` to `visualize_kwargs` inside `MMSegInferencer`. - Modified to `visualize()` function. ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstring or example tutorials. --------- Co-authored-by: xiexinch --- demo/image_demo.py | 6 ++++++ demo/image_demo_with_inferencer.py | 11 ++++++++++- mmseg/apis/inference.py | 6 +++--- mmseg/apis/mmseg_inferencer.py | 9 ++++++--- mmseg/visualization/local_visualizer.py | 14 +++++++------- 5 files changed, 32 insertions(+), 14 deletions(-) diff --git a/demo/image_demo.py b/demo/image_demo.py index 231aacb9dd..ebc34c80b2 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -19,6 +19,11 @@ def main(): type=float, default=0.5, help='Opacity of painted segmentation map. In (0, 1] range.') + parser.add_argument( + '--with-labels', + action='store_true', + default=False, + help='Whether to display the class labels.') parser.add_argument( '--title', default='result', help='The image identifier.') args = parser.parse_args() @@ -36,6 +41,7 @@ def main(): result, title=args.title, opacity=args.opacity, + with_labels=args.with_labels, draw_gt=False, show=False if args.out_file is not None else True, out_file=args.out_file) diff --git a/demo/image_demo_with_inferencer.py b/demo/image_demo_with_inferencer.py index 26bf0f257c..d1fa9deb9e 100644 --- a/demo/image_demo_with_inferencer.py +++ b/demo/image_demo_with_inferencer.py @@ -27,6 +27,11 @@ def main(): type=float, default=0.5, help='Opacity of painted segmentation map. In (0, 1] range.') + parser.add_argument( + '--with-labels', + action='store_true', + default=False, + help='Whether to display the class labels.') args = parser.parse_args() # build the model from a config file and a checkpoint file @@ -38,7 +43,11 @@ def main(): # test a single image mmseg_inferencer( - args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity) + args.img, + show=args.show, + out_dir=args.out_dir, + opacity=args.opacity, + with_labels=args.with_labels) if __name__ == '__main__': diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 57fc5d23dc..aab11d14f4 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -127,7 +127,7 @@ def show_result_pyplot(model: BaseSegmentor, draw_pred: bool = True, wait_time: float = 0, show: bool = True, - withLabels: Optional[bool] = True, + with_labels: Optional[bool] = True, save_dir=None, out_file=None): """Visualize the segmentation results on the image. @@ -147,7 +147,7 @@ def show_result_pyplot(model: BaseSegmentor, that means "forever". Defaults to 0. show (bool): Whether to display the drawn image. Default to True. - withLabels(bool, optional): Add semantic labels in visualization + with_labels(bool, optional): Add semantic labels in visualization result, Default to True. save_dir (str, optional): Save file dir for all storage backends. If it is None, the backend storage will not save any data. @@ -183,7 +183,7 @@ def show_result_pyplot(model: BaseSegmentor, wait_time=wait_time, out_file=out_file, show=show, - withLabels=withLabels) + with_labels=with_labels) vis_img = visualizer.get_image() return vis_img diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py index 095639a80f..02a198b516 100644 --- a/mmseg/apis/mmseg_inferencer.py +++ b/mmseg/apis/mmseg_inferencer.py @@ -60,7 +60,8 @@ class MMSegInferencer(BaseInferencer): preprocess_kwargs: set = set() forward_kwargs: set = {'mode', 'out_dir'} visualize_kwargs: set = { - 'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis' + 'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis', + 'with_labels' } postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'} @@ -201,7 +202,8 @@ def visualize(self, show: bool = False, wait_time: int = 0, img_out_dir: str = '', - opacity: float = 0.8) -> List[np.ndarray]: + opacity: float = 0.8, + with_labels: Optional[bool] = True) -> List[np.ndarray]: """Visualize predictions. Args: @@ -254,7 +256,8 @@ def visualize(self, wait_time=wait_time, draw_gt=False, draw_pred=True, - out_file=out_file) + out_file=out_file, + with_labels=with_labels) if return_vis: results.append(self.visualizer.get_image()) self.num_visualized_imgs += 1 diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 3096e3183b..ee3d652c7b 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -103,7 +103,7 @@ def _draw_sem_seg(self, sem_seg: PixelData, classes: Optional[List], palette: Optional[List], - withLabels: Optional[bool] = True) -> np.ndarray: + with_labels: Optional[bool] = True) -> np.ndarray: """Draw semantic seg of GT or prediction. Args: @@ -119,7 +119,7 @@ def _draw_sem_seg(self, palette (list, optional): Input palette for result rendering, which is a list of color palette responding to the classes. Defaults to None. - withLabels(bool, optional): Add semantic labels in visualization + with_labels(bool, optional): Add semantic labels in visualization result, Default to True. Returns: @@ -139,7 +139,7 @@ def _draw_sem_seg(self, for label, color in zip(labels, colors): mask[sem_seg[0] == label, :] = color - if withLabels: + if with_labels: font = cv2.FONT_HERSHEY_SIMPLEX # (0,1] to change the size of the text relative to the image scale = 0.05 @@ -265,7 +265,7 @@ def add_datasample( # TODO: Supported in mmengine's Viusalizer. out_file: Optional[str] = None, step: int = 0, - withLabels: Optional[bool] = True) -> None: + with_labels: Optional[bool] = True) -> None: """Draw datasample and save to all backends. - If GT and prediction are plotted at the same time, they are @@ -291,7 +291,7 @@ def add_datasample( wait_time (float): The interval of show (s). Defaults to 0. out_file (str): Path to output file. Defaults to None. step (int): Global step value to record. Defaults to 0. - withLabels(bool, optional): Add semantic labels in visualization + with_labels(bool, optional): Add semantic labels in visualization result, Defaults to True. """ classes = self.dataset_meta.get('classes', None) @@ -307,7 +307,7 @@ def add_datasample( 'visualizing semantic ' \ 'segmentation results.' gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg, - classes, palette, withLabels) + classes, palette, with_labels) if 'gt_depth_map' in data_sample: gt_img_data = gt_img_data if gt_img_data is not None else image @@ -325,7 +325,7 @@ def add_datasample( pred_img_data = self._draw_sem_seg(image, data_sample.pred_sem_seg, classes, palette, - withLabels) + with_labels) if 'pred_depth_map' in data_sample: pred_img_data = pred_img_data if pred_img_data is not None \