Skip to content

Commit

Permalink
[Feature]: Support print per class AP. (#395)
Browse files Browse the repository at this point in the history
* [Feature]: Support print per class AP.

* update requirements
  • Loading branch information
RangiLyu committed Jan 26, 2022
1 parent 8cb9044 commit 3996f81
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
python -m pip install ninja opencv-python-headless onnx pytest-xdist codecov
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest
python -m pip install -r requirements.txt
- name: Setup
run: rm -rf .eggs && python setup.py develop
- name: Run unittests and generate coverage report
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge
3. Install requirements

```shell script
pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics
pip install -r requirements.txt
```

4. Setup NanoDet
Expand Down
1 change: 1 addition & 0 deletions nanodet/data/dataset/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_data_info(self, ann_path):
self.cat_ids = sorted(self.coco_api.getCatIds())
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cats = self.coco_api.loadCats(self.cat_ids)
self.class_names = [cat["name"] for cat in self.cats]
self.img_ids = sorted(self.coco_api.imgs.keys())
img_info = self.coco_api.loadImgs(self.img_ids)
return img_info
Expand Down
58 changes: 57 additions & 1 deletion nanodet/evaluator/coco_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import copy
import io
import itertools
import json
import logging
import os
import warnings

import numpy as np
from pycocotools.cocoeval import COCOeval
from tabulate import tabulate

logger = logging.getLogger("NanoDet")


def xyxy2xywh(bbox):
Expand All @@ -37,6 +45,7 @@ def xyxy2xywh(bbox):
class CocoDetectionEvaluator:
def __init__(self, dataset):
assert hasattr(dataset, "coco_api")
self.class_names = dataset.class_names
self.coco_api = dataset.coco_api
self.cat_ids = dataset.cat_ids
self.metric_names = ["mAP", "AP_50", "AP_75", "AP_small", "AP_m", "AP_l"]
Expand Down Expand Up @@ -85,7 +94,54 @@ def evaluate(self, results, save_dir, rank=-1):
)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()

# use logger to log coco eval results
redirect_string = io.StringIO()
with contextlib.redirect_stdout(redirect_string):
coco_eval.summarize()
logger.info("\n" + redirect_string.getvalue())

# print per class AP
headers = ["class", "AP50", "mAP"]
colums = 6
per_class_ap50s = []
per_class_maps = []
precisions = coco_eval.eval["precision"]
# dimension of precisions: [TxRxKxAxM]
# precision has dims (iou, recall, cls, area range, max dets)
assert len(self.class_names) == precisions.shape[2]

for idx, name in enumerate(self.class_names):
# area range index 0: all area ranges
# max dets index -1: typically 100 per image
precision_50 = precisions[0, :, idx, 0, -1]
precision_50 = precision_50[precision_50 > -1]
ap50 = np.mean(precision_50) if precision_50.size else float("nan")
per_class_ap50s.append(float(ap50 * 100))

precision = precisions[:, :, idx, 0, -1]
precision = precision[precision > -1]
ap = np.mean(precision) if precision.size else float("nan")
per_class_maps.append(float(ap * 100))

num_cols = min(colums, len(self.class_names) * len(headers))
flatten_results = []
for name, ap50, mAP in zip(self.class_names, per_class_ap50s, per_class_maps):
flatten_results += [name, ap50, mAP]

row_pair = itertools.zip_longest(
*[flatten_results[i::num_cols] for i in range(num_cols)]
)
table_headers = headers * (num_cols // len(headers))
table = tabulate(
row_pair,
tablefmt="pipe",
floatfmt=".1f",
headers=table_headers,
numalign="left",
)
logger.info("\n" + table)

aps = coco_eval.stats[:6]
eval_results = {}
for k, v in zip(self.metric_names, aps):
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
Cython
matplotlib
numpy
omegaconf>=2.0.1
onnx
onnx-simplifier
opencv-python
pyaml
pycocotools
pytorch-lightning>=1.4.0
omegaconf>=2.0.1
tabulate
tensorboard
termcolor
torch>=1.7
Expand Down

0 comments on commit 3996f81

Please sign in to comment.