Skip to content

Commit

Permalink
[minor] Some params changed
Browse files Browse the repository at this point in the history
  • Loading branch information
JayYip committed May 20, 2019
1 parent 42443e9 commit 910d76c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ data/
pubmed_pmc_470k/
gpt_2/output.txt
models/
qa_embeddings/
qa_embeddings/
*.code-work*
9 changes: 7 additions & 2 deletions docproduct/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def getEmbedding(self, questions, search_by='answer', topk=5, answer_only=True):

class GenerateQADoc(object):
def __init__(self,
pretrained_path='pubmed_pmc_470k/',
pretrained_path='models/pubmed_pmc_470k/',
ffn_weight_file=None,
bert_ffn_weight_file='models/bertffn_crossentropy/bertffn',
gpt2_weight_file='models/gpt2',
Expand Down Expand Up @@ -281,7 +281,7 @@ def predict(self, questions, search_by='answer', topk=5, answer_only=False):
embedding, search_by, topk, answer_only)

gpt2_input = self._get_gpt2_inputs(
questions, topk_question, topk_answer)
questions[0], topk_question, topk_answer)
gpt2_pred = self.estimator.predict(
lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size))
raw_output = gpt2_estimator.predictions_parsing(
Expand All @@ -297,3 +297,8 @@ def predict(self, questions, search_by='answer', topk=5, answer_only=False):
clipped_output = raw_output[0].split(
'`QUESTION')[1].split('`ANSWER:')[1]
return clipped_output


if __name__ == "__main__":
gen = GenerateQADoc()
print(gen.predict('my eyes hurt'))
8 changes: 4 additions & 4 deletions docproduct/train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def train_gpt2(
model_dir='models/gpt2',
pretrained_path='models/117M',
steps=100000,
batch_size=4,
batch_size=3,
num_gpu=4,
learning_rate=0.0001):
"""Function to train the GPT2 model
Expand Down Expand Up @@ -53,12 +53,12 @@ def train_gpt2(
session_config=session_config,
train_distribute=mirrored_strategy,
eval_distribute=mirrored_strategy,
log_step_count_steps=500)
log_step_count_steps=50)

gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn(
accumulate_gradients=5,
accumulate_gradients=1,
learning_rate=learning_rate,
length=512,
length=600,
batch_size=batch_size,
temperature=0.7,
top_k=0
Expand Down

0 comments on commit 910d76c

Please sign in to comment.