Skip to content

Commit

Permalink
Update predictor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Santosh-Gupta committed May 29, 2019
1 parent aa7bbfa commit 9a8146f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docproduct/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def __init__(self,
config = tf.estimator.RunConfig(
session_config=session_config)
self.batch_size = 1
self.gpt2_weight_file = gpt2_weight_file
gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
accumulate_gradients=5,
learning_rate=0.1,
Expand Down Expand Up @@ -288,7 +289,7 @@ def predict(self, questions, search_by='answer', topk=5, answer_only=False):
gpt2_input = self._get_gpt2_inputs(
questions[0], topk_question, topk_answer)
gpt2_pred = self.estimator.predict(
lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size, checkpoint_path=gpt2_weight_file))
lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size, checkpoint_path=self.gpt2_weight_file))
raw_output = gpt2_estimator.predictions_parsing(
gpt2_pred, self.encoder)
result_list = [re.search('`ANSWER:(.*)`QUESTION:', s)
Expand Down

0 comments on commit 9a8146f

Please sign in to comment.