Skip to content

Commit

Permalink
cache_url fix for inference tools
Browse files Browse the repository at this point in the history
Reviewed By: ir413

Differential Revision: D7364656

fbshipit-source-id: bfd31bc7c95b9606037c2f5546c9edd0e0318272
  • Loading branch information
rbgirshick authored and facebook-github-bot committed Mar 22, 2018
1 parent eddb130 commit e7ad1b4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tools/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import logging
import os
import sys
import yaml
Expand All @@ -39,6 +40,7 @@
from core.config import cfg
from core.config import merge_cfg_from_cfg
from core.config import merge_cfg_from_file
from utils.io import cache_url
import core.rpn_generator as rpn_engine
import core.test_engine as model_engine
import datasets.dummy_datasets as dummy_datasets
Expand Down Expand Up @@ -112,6 +114,7 @@ def get_rpn_box_proposals(im, args):


def main(args):
logger = logging.getLogger(__name__)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()
cfg_orig = yaml.load(yaml.dump(cfg))
im = cv2.imread(args.im_file)
Expand Down Expand Up @@ -144,6 +147,11 @@ def main(args):
cls_keyps = cls_keyps_ if cls_keyps_ is not None else cls_keyps
workspace.ResetWorkspace()

out_name = os.path.join(
args.output_dir, '{}'.format(os.path.basename(args.im_file) + '.pdf')
)
logger.info('Processing {} -> {}'.format(args.im_file, out_name))

vis_utils.vis_one_image(
im[:, :, ::-1],
args.im_file,
Expand All @@ -165,13 +173,18 @@ def check_args(args):
(args.rpn_pkl is None and args.rpn_cfg is None)
)
if args.rpn_pkl is not None:
args.rpn_pkl = cache_url(args.rpn_pkl, cfg.DOWNLOAD_CACHE)
assert os.path.exists(args.rpn_pkl)
assert os.path.exists(args.rpn_cfg)
if args.models_to_run is not None:
assert len(args.models_to_run) % 2 == 0
for model_file in args.models_to_run:
for i, model_file in enumerate(args.models_to_run):
if len(model_file) > 0:
assert os.path.exists(model_file)
if i % 2 == 0:
model_file = cache_url(model_file, cfg.DOWNLOAD_CACHE)
args.models_to_run[i] = model_file
assert os.path.exists(model_file), \
'\'{}\' does not exist'.format(model_file)


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions tools/infer_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from utils.io import cache_url
from utils.timer import Timer
import core.test_engine as infer_engine
import datasets.dummy_datasets as dummy_datasets
Expand Down Expand Up @@ -94,6 +95,7 @@ def main(args):
logger = logging.getLogger(__name__)
merge_cfg_from_file(args.cfg)
cfg.NUM_GPUS = 1
args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
assert_and_infer_cfg()
model = infer_engine.initialize_model_from_cfg(args.weights)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()
Expand Down

0 comments on commit e7ad1b4

Please sign in to comment.