[Fix] Support format_result and fix prefix param in cityscape metric, and rename CitysMetric to CityscapesMetric (#2660)

as title
MeowZheng committed Mar 7, 2023
1 parent 6c3599b commit a8aafdd
Showing 4 changed files with 94 additions and 74 deletions.
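
The diff below replaces the old suffix / citys_metrics / to_label_id arguments with output_dir, format_only, and keep_results. A minimal sketch of an evaluator config using the renamed metric (the surrounding config wiring is an assumption; only the parameter names come from the diff):

    # Hypothetical evaluator config; the values are placeholders.
    val_evaluator = dict(
        type='CityscapesMetric',
        output_dir='work_dirs/format_cityscapes',  # where formatted prediction PNGs are written
        format_only=False,   # True -> only dump formatted results, skip evaluation
        keep_results=False)  # must be True when format_only is True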
4 changes: 2 additions & 2 deletions mmseg/evaluation/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .metrics import CitysMetric, IoUMetric
from .metrics import CityscapesMetric, IoUMetric

__all__ = ['IoUMetric', 'CitysMetric']
__all__ = ['IoUMetric', 'CityscapesMetric']
4 changes: 2 additions & 2 deletions mmseg/evaluation/metrics/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CitysMetric
from .citys_metric import CityscapesMetric
from .iou_metric import IoUMetric

__all__ = ['IoUMetric', 'CitysMetric']
__all__ = ['IoUMetric', 'CityscapesMetric']
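
After the rename only the new class name is exported, so downstream imports need updating accordingly:

    from mmseg.evaluation import CityscapesMetric  # new name introduced by this commit
    # from mmseg.evaluation import CitysMetric     # old name, no longer exported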
116 changes: 64 additions & 52 deletions mmseg/evaluation/metrics/citys_metric.py
@@ -1,30 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence
import shutil
from collections import OrderedDict
from typing import Dict, Optional, Sequence

try:

import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
import cityscapesscripts.helpers.labels as CSLabels
except ImportError:
CSLabels = None
CSEval = None

import numpy as np
from mmengine.dist import is_main_process, master_only
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist, scandir
from mmengine.utils import mkdir_or_exist
from PIL import Image

from mmseg.registry import METRICS


@METRICS.register_module()
class CitysMetric(BaseMetric):
class CityscapesMetric(BaseMetric):
"""Cityscapes evaluation metric.
Args:
output_dir (str): The directory for output prediction.
ignore_index (int): Index that will be ignored in evaluation.
Default: 255.
citys_metrics (list[str] | str): Metrics to be evaluated,
Default: ['cityscapes'].
to_label_id (bool): whether convert output to label_id for
submission. Default: True.
suffix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx", the png files will be
named "somepath/xxx.png". Default: '.format_cityscapes'.
format_only (bool): Only format the results for submission without
performing evaluation. It is useful when you want to format the
results into a specific format and submit them to the test server.
Defaults to False.
keep_results (bool): Whether to keep the results. When ``format_only``
is True, ``keep_results`` must be True. Defaults to False.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
@@ -35,19 +46,34 @@ class CitysMetric(BaseMetric):
"""

def __init__(self,
output_dir: str,
ignore_index: int = 255,
citys_metrics: List[str] = ['cityscapes'],
to_label_id: bool = True,
suffix: str = '.format_cityscapes',
format_only: bool = False,
keep_results: bool = False,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)

if CSEval is None:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
self.output_dir = output_dir
self.ignore_index = ignore_index
self.metrics = citys_metrics
assert self.metrics[0] == 'cityscapes'
self.to_label_id = to_label_id
self.suffix = suffix

self.format_only = format_only
if format_only:
assert keep_results, (
'When format_only is True, the results must be kept, please '
f'set keep_results to True, but got {keep_results}')
self.keep_results = keep_results
self.prefix = prefix
if is_main_process():
mkdir_or_exist(self.output_dir)

@master_only
def __del__(self) -> None:
"""Clean up."""
if not self.keep_results:
shutil.rmtree(self.output_dir)
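
As the new constructor enforces, format_only=True is only valid together with keep_results=True, because __del__ otherwise removes the formatted output directory. A small illustration (class name and arguments are taken from this diff; the output path is a placeholder):

    from mmseg.evaluation import CityscapesMetric

    # Raises AssertionError: results must be kept when only formatting.
    CityscapesMetric(output_dir='tmp', format_only=True, keep_results=False)

    # Valid: formatted predictions under 'tmp' are kept for test-server submission.
    metric = CityscapesMetric(output_dir='tmp', format_only=True, keep_results=True)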

def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
@@ -59,26 +85,23 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
mkdir_or_exist(self.suffix)
mkdir_or_exist(self.output_dir)

for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
# results2img
if self.to_label_id:
pred_label = self._convert_to_label_id(pred_label)
# when evaluating with official cityscapesscripts,
# labelIds should be used
pred_label = self._convert_to_label_id(pred_label)
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
png_filename = osp.join(self.suffix, f'{basename}.png')
png_filename = osp.abspath(
osp.join(self.output_dir, f'{basename}.png'))
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
for label_id, label in CSLabels.id2label.items():
palette[label_id] = label.color
output.putpalette(palette)
output.save(png_filename)

ann_dir = osp.join(data_samples[0]['seg_map_path'].split('val')[0],
'val')
self.results.append(ann_dir)
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
gt_filename = data_sample['seg_map_path'].replace(
'labelTrainIds.png', 'labelIds.png')
self.results.append((png_filename, gt_filename))

def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
@@ -90,38 +113,28 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
dict[str: float]: Cityscapes evaluation results.
"""
logger: MMLogger = MMLogger.get_current_instance()
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
except ImportError:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
msg = 'Evaluating in Cityscapes style'
if self.format_only:
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
return OrderedDict()

msg = 'Evaluating in Cityscapes style'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)

result_dir = self.suffix

eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
print_log(
f'Evaluating results under {self.output_dir} ...', logger=logger)

CSEval.args.evalInstLevelScore = True
CSEval.args.predictionPath = osp.abspath(result_dir)
CSEval.args.predictionPath = osp.abspath(self.output_dir)
CSEval.args.evalPixelAccuracy = True
CSEval.args.JSONOutput = False

seg_map_list = []
pred_list = []
ann_dir = results[0]
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
for seg_map in scandir(ann_dir, 'gtFine_labelIds.png', recursive=True):
seg_map_list.append(osp.join(ann_dir, seg_map))
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
pred_list, gt_list = zip(*results)
metric = dict()
eval_results.update(
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
metric['averageScoreCategories'] = eval_results[
'averageScoreCategories']
metric['averageScoreInstCategories'] = eval_results[
@@ -133,7 +146,6 @@ def _convert_to_label_id(result):
"""Convert trainId to id for cityscapes."""
if isinstance(result, str):
result = np.load(result)
import cityscapesscripts.helpers.labels as CSLabels
result_copy = result.copy()
for trainId, label in CSLabels.trainId2label.items():
result_copy[result == trainId] = label.id
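
The conversion relies on the trainId2label mapping from cityscapesscripts.helpers.labels, since the official evaluator compares raw labelIds rather than trainIds. A standalone sketch of the same idea:

    import numpy as np
    import cityscapesscripts.helpers.labels as CSLabels

    def train_ids_to_label_ids(pred: np.ndarray) -> np.ndarray:
        """Map Cityscapes trainIds (0-18, 255) back to raw labelIds."""
        out = pred.copy()
        for train_id, label in CSLabels.trainId2label.items():
            out[pred == train_id] = label.id
        return out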
44 changes: 26 additions & 18 deletions tests/test_evaluation/test_metrics/test_citys_metric.py
@@ -1,15 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from unittest import TestCase

import numpy as np
import pytest
import torch
from mmengine.structures import BaseDataElement, PixelData

from mmseg.evaluation import CitysMetric
from mmseg.evaluation import CityscapesMetric
from mmseg.structures import SegDataSample


class TestCitysMetric(TestCase):
class TestCityscapesMetric(TestCase):

def _demo_mm_inputs(self,
batch_size=1,
@@ -42,9 +44,8 @@ def _demo_mm_inputs(self,
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
mm_inputs['data_sample'] = data_sample.to_dict()
mm_inputs['data_sample']['seg_map_path'] = \
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
mm_inputs['data_sample'][
'seg_map_path'] = 'tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa

mm_inputs['seg_map_path'] = mm_inputs['data_sample'][
'seg_map_path']
@@ -86,9 +87,8 @@ def _demo_mm_model_output(self,
for pred in batch_datasampes:
if isinstance(pred, BaseDataElement):
test_data = pred.to_dict()
test_data['img_path'] = \
'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
frankfurt/frankfurt_000000_000294_leftImg8bit.png'
test_data[
'img_path'] = 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' # noqa

_predictions.append(test_data)
else:
@@ -104,15 +104,23 @@ def test_evaluate(self):
dict(**data, **result)
for data, result in zip(data_batch, predictions)
]
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
# test to_label_id = True
iou_metric = CitysMetric(
citys_metrics=['cityscapes'], to_label_id=True)
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
# test keep_results should be True when format_only is True
with pytest.raises(AssertionError):
CityscapesMetric(
output_dir='tmp', format_only=True, keep_results=False)

# test evaluate with cityscape metric
metric = CityscapesMetric(output_dir='tmp')
metric.process(data_batch, data_samples)
res = metric.evaluate(2)
self.assertIsInstance(res, dict)

# test format_only
metric = CityscapesMetric(
output_dir='tmp', format_only=True, keep_results=True)
metric.process(data_batch, data_samples)
metric.evaluate(2)
assert osp.exists('tmp')
assert osp.isfile('tmp/frankfurt_000000_000294_leftImg8bit.png')
import shutil
shutil.rmtree('.format_cityscapes')
shutil.rmtree('tmp')
