Skip to content

Commit

Permalink
Compute Precision and recall per class
Browse files Browse the repository at this point in the history
Doing ./darknet detector map ... now returns precision and recall per class instead of a global precision and recall
  • Loading branch information
agirbau authored Mar 22, 2019
1 parent 1cd332e commit 198d169
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions src/detector.c
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,18 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
int unique_truth_count = 0;

int* truth_classes_count = (int*)calloc(classes, sizeof(int));


// For multi-class precision and recall computation
float avg_iou_per_class[classes];
int tp_for_thresh_per_class[classes];
int fp_for_thresh_per_class[classes];
int unique_truth_count_per_class[classes];

memset(avg_iou_per_class, 0.0, classes * sizeof(float));
memset(tp_for_thresh_per_class, 0, classes * sizeof(int));
memset(fp_for_thresh_per_class, 0, classes * sizeof(int));
memset(unique_truth_count_per_class, 0, classes * sizeof(int));

for (t = 0; t < nthreads; ++t) {
args.path = paths[i + t];
args.im = &buf[t];
Expand Down Expand Up @@ -800,6 +811,7 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
int i, j;
for (j = 0; j < num_labels; ++j) {
truth_classes_count[truth[j].id]++;
unique_truth_count_per_class[truth[j].id]++;
}

// difficult
Expand Down Expand Up @@ -876,9 +888,13 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
if (truth_index > -1 && found == 0) {
avg_iou += max_iou;
++tp_for_thresh;
avg_iou_per_class[class_id] += max_iou;
tp_for_thresh_per_class[class_id]++;
}
else
else{
fp_for_thresh++;
fp_for_thresh_per_class[class_id]++;
}
}
}
}
Expand All @@ -905,7 +921,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
if ((tp_for_thresh + fp_for_thresh) > 0)
avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);


for(int class_id = 0; class_id < classes; class_id++){
if ((tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]) > 0)
avg_iou_per_class[class_id] = avg_iou_per_class[class_id] / (tp_for_thresh_per_class[class_id] + fp_for_thresh_per_class[class_id]);
}

// SORT(detections)
qsort(detections, detections_count, sizeof(box_prob), detections_comparator);

Expand Down Expand Up @@ -983,6 +1003,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa

for (i = 0; i < classes; ++i) {
double avg_precision = 0;
float class_precision = 0.0;
float class_recall = 0.0;
int point;
for (point = 0; point < 11; ++point) {
double cur_recall = point * 0.1;
Expand All @@ -1000,10 +1022,17 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa
avg_precision += cur_precision;
}
avg_precision = avg_precision / 11;
printf("class_id = %d, name = %s, \t ap = %2.2f %% \n", i, names[i], avg_precision * 100);

printf("\nTP = %d, FP = %d \n", tp_for_thresh_per_class[i], fp_for_thresh_per_class[i]);
class_precision = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)fp_for_thresh_per_class[i]);
class_recall = (float)tp_for_thresh_per_class[i] / ((float)tp_for_thresh_per_class[i] + (float)(unique_truth_count_per_class[i] - tp_for_thresh_per_class[i]));

printf("class_id = %d, name = %s, \t P = %1.2f, \t R = %1.2f, \t avg IOU = %2.2f %%, \t ap = %2.2f %% \n", i, names[i], class_precision, class_recall, avg_iou_per_class[i], avg_precision * 100);
mean_average_precision += avg_precision;
}

printf("TP = %1.2f, FP = %1.2f \n", (float)tp_for_thresh, (float)fp_for_thresh);

const float cur_precision = (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh);
const float cur_recall = (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh));
const float f1_score = 2.F * cur_precision * cur_recall / (cur_precision + cur_recall);
Expand Down

0 comments on commit 198d169

Please sign in to comment.