
Commit

save the json results in eval
wangyingming committed Jan 18, 2022
1 parent c39246b commit f2be216
Showing 3 changed files with 20 additions and 4 deletions.
8 changes: 7 additions & 1 deletion datasets/coco_eval.py
@@ -28,7 +28,7 @@
 
 
 class CocoEvaluator(object):
-    def __init__(self, coco_gt, iou_types):
+    def __init__(self, coco_gt, iou_types, save_json=False):
         assert isinstance(iou_types, (list, tuple))
         coco_gt = copy.deepcopy(coco_gt)
         self.coco_gt = coco_gt
@@ -41,12 +41,18 @@ def __init__(self, coco_gt, iou_types):
         self.img_ids = []
         self.eval_imgs = {k: [] for k in iou_types}
 
+        self.save_json = save_json
+        if save_json:
+            self.results = {k: [] for k in iou_types}
+
     def update(self, predictions):
         img_ids = list(np.unique(list(predictions.keys())))
         self.img_ids.extend(img_ids)
 
         for iou_type in self.iou_types:
             results = self.prepare(predictions, iou_type)
+            if self.save_json:
+                self.results[iou_type].extend(results)
 
             # suppress pycocotools prints
             with open(os.devnull, 'w') as devnull:
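For the 'bbox' iou_type, prepare() already returns plain COCO-style detection dicts, which is what makes self.results['bbox'] directly JSON-serializable later on. A minimal sketch of what one accumulated entry looks like; the values below are illustrative, not taken from this commit, and only the keys follow the standard COCO detection-result format:

# Illustrative shape of a single entry in CocoEvaluator.results['bbox'];
# keys follow the COCO detection-result convention that pycocotools expects,
# the numbers are made up for the example.
example_entry = {
    "image_id": 42,                    # id of the evaluated image
    "category_id": 1,                  # predicted COCO category
    "bbox": [10.0, 20.0, 50.0, 80.0],  # [x, y, width, height] in pixels
    "score": 0.87,                     # detection confidence
}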
11 changes: 9 additions & 2 deletions engine.py
@@ -81,7 +81,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
 
 @torch.no_grad()
-def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
+def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, save_json=False):
     model.eval()
     criterion.eval()
 
@@ -90,7 +90,7 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
     header = 'Test:'
 
     iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
-    coco_evaluator = CocoEvaluator(base_ds, iou_types)
+    coco_evaluator = CocoEvaluator(base_ds, iou_types, save_json=save_json)
     # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
 
     panoptic_evaluator = None
@@ -148,6 +148,13 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
     if panoptic_evaluator is not None:
         panoptic_evaluator.synchronize_between_processes()
 
+    if save_json:
+        all_res = utils.all_gather(coco_evaluator.results['bbox'])
+        results = []
+        for p in all_res:
+            results.extend(p)
+        coco_evaluator.results['bbox'] = results
+
     # accumulate predictions from all images
     if coco_evaluator is not None:
         coco_evaluator.accumulate()
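In the distributed case each process only evaluates its own shard of the validation set, so before anything is written to disk the per-rank result lists have to be gathered and flattened. A small self-contained sketch of that merge step, simulating the output of utils.all_gather with plain Python lists (the data below is illustrative only):

# Simulated output of utils.all_gather: one result list per process.
per_rank_results = [
    [{"image_id": 1, "score": 0.9}],   # detections from rank 0
    [{"image_id": 2, "score": 0.8}],   # detections from rank 1
]

merged = []
for p in per_rank_results:   # flatten the list-of-lists into a single list
    merged.extend(p)

# merged now holds every rank's detections, mirroring what is assigned back
# to coco_evaluator.results['bbox'] before the JSON file is written.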
5 changes: 4 additions & 1 deletion main.py
@@ -273,9 +273,12 @@ def match_name_keywords(n, name_keywords):
 
     if args.eval:
         test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
-                                              data_loader_val, base_ds, device, args.output_dir)
+                                              data_loader_val, base_ds, device, args.output_dir, save_json=True)
         if args.output_dir:
             utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
+            if utils.is_main_process():
+                with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
+                    json.dump(coco_evaluator.results['bbox'], f)
         return
 
     print("Start training")
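Once main.py has written results.json into args.output_dir, the file can be scored offline with pycocotools, independently of the training code. A sketch of that workflow, assuming the file follows the standard COCO detection-result format; the annotation and result paths below are placeholders to adapt, not part of this commit:

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("annotations/instances_val2017.json")  # ground-truth annotations (path is an assumption)
coco_dt = coco_gt.loadRes("output/results.json")      # detections dumped by main.py
coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints the standard COCO AP / AR table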
