From 1326ee97ac1f032fdfbd5245c8356f59b254a9b5 Mon Sep 17 00:00:00 2001 From: Svebor Karaman Date: Thu, 6 Feb 2020 11:01:38 -0500 Subject: [PATCH] base_model_path in conf --- cufacesearch/cufacesearch/searcher/searcher_lopqhbase.py | 3 ++- setup/ConfGenerator/create_conf_searcher.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cufacesearch/cufacesearch/searcher/searcher_lopqhbase.py b/cufacesearch/cufacesearch/searcher/searcher_lopqhbase.py index 5b12e32..2708cf9 100644 --- a/cufacesearch/cufacesearch/searcher/searcher_lopqhbase.py +++ b/cufacesearch/cufacesearch/searcher/searcher_lopqhbase.py @@ -136,6 +136,7 @@ def init_searcher(self): try: # This can fail with a "retrieval incomplete: got only" ... # Or can stall... why? + # Could we change that to download using an s3_storer? download_file(os.path.join(self.base_model_path, self.build_model_str()), self.build_model_str()) lopq_model = pickle.load(open(self.build_model_str(), 'rb')) @@ -157,7 +158,7 @@ def init_searcher(self): full_trace_error(log_msg.format(self.pp, self.build_model_str(), inst)) sys.stdout.flush() else: - log_msg = "[{}: Warning] Could not retrieve pre-trained model as `` was not set." + log_msg = "[{}: Warning] Could not retrieve pre-trained model as `base_model_path` was not set." print(log_msg.format(self.pp, self.build_model_str())) else: log_msg = "[{}: log] Skipped retrieving pre-trained model from s3 as requested." diff --git a/setup/ConfGenerator/create_conf_searcher.py b/setup/ConfGenerator/create_conf_searcher.py index 5037903..f4c88f6 100644 --- a/setup/ConfGenerator/create_conf_searcher.py +++ b/setup/ConfGenerator/create_conf_searcher.py @@ -139,6 +139,7 @@ conf[search_prefix + 'lopq_M'] = int(os.environ['lopq_M']) conf[search_prefix + 'lopq_subq'] = int(os.environ['lopq_subq']) conf[search_prefix + 'reranking'] = bool(int(os.getenv('reranking', 1))) + conf[search_prefix + 'base_model_path'] = os.getenv('base_model_path', None) conf[search_prefix + 'skip_get_sim_info'] = bool(int(os.getenv('skip_get_sim_info', 0))) if conf[search_prefix + 'model_type'] == "lopq_pca": conf[search_prefix + 'nb_train_pca'] = int(os.environ['nb_train_pca'])