-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added evaluation - have to wait til I have a trained model to fully t…
…est it Changed from "upload_tags" in config to using separate stages for train and eval upload
- Loading branch information
Showing
3 changed files
with
203 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
import fiftyone as fo | ||
from fiftyone import ViewField as F | ||
import os | ||
import numpy as np | ||
|
||
def evaluate_detection_model(dataset_name, prediction_field, evaluation_key): | ||
|
||
dataset = fo.load_dataset(dataset_name) | ||
|
||
view = dataset.match_tags(evaluation_name) | ||
|
||
# setting an empty detections field if there isn't one | ||
for sample in view: | ||
if sample["detections"] == None: | ||
sample["detections"] = fo.Detections(detections=[]) | ||
sample.save() | ||
|
||
results = view.evaluate_detections( prediction_field, gt_field="detections", eval_key=evaluation_key) | ||
|
||
# Get the 10 most common classes in the dataset | ||
counts = view.count_values("detections.detections.label") | ||
classes = sorted(counts, key=counts.get, reverse=True)[:10] | ||
|
||
# Print a classification report for the top-10 classes | ||
results.print_report(classes=classes) | ||
|
||
# Print some statistics about the total TP/FP/FN counts | ||
logging.info("TP: %d" % dataset.sum(evaluation_key + "_tp")) | ||
logging.info("FP: %d" % dataset.sum(evaluation_key + "_fp")) | ||
logging.info("FN: %d" % dataset.sum(evaluation_key + "_fn")) | ||
|
||
# Create a view that has samples with the most false positives first, and | ||
# only includes false positive boxes in the `predictions` field | ||
eval_view = (view | ||
.sort_by(evaluation_key + "_fp", reverse=True) | ||
.filter_labels(prediction_field, F(evaluation_key) == "fp") | ||
) | ||
logging.info("mAP: {}".format(results.mAP())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters