diff --git a/docproduct/predictor.py b/docproduct/predictor.py index e5537f4..21aabed 100644 --- a/docproduct/predictor.py +++ b/docproduct/predictor.py @@ -288,7 +288,7 @@ def predict(self, questions, search_by='answer', topk=5, answer_only=False): gpt2_input = self._get_gpt2_inputs( questions[0], topk_question, topk_answer) gpt2_pred = self.estimator.predict( - lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size)) + lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size, checkpoint_path=gpt2_weight_file)) raw_output = gpt2_estimator.predictions_parsing( gpt2_pred, self.encoder) result_list = [re.search('`ANSWER:(.*)`QUESTION:', s)