Skip to content

Commit

Permalink
support model link in model_dir params
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jun 10, 2021
1 parent 037e17f commit 7d47283
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool

__all__ = ['PaddleOCR']
Expand Down Expand Up @@ -192,20 +192,19 @@ def __init__(self, **kwargs):
'dict_path']

# init model dir
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if params.cls_model_dir is None:
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det', det_lang),
model_urls['det'][det_lang])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec', lang),
model_urls['rec'][lang]['url'])
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model
maybe_download(params.det_model_dir,
model_urls['det'][det_lang])
maybe_download(params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(params.cls_model_dir, model_urls['cls'])
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)

if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
Expand Down Expand Up @@ -277,7 +276,7 @@ def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
if image_dir.startswith('http'):
if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
Expand Down

0 comments on commit 7d47283

Please sign in to comment.