Skip to content

Commit

Permalink
[bugfix] Generation parsing fix
Browse files: browse the repository at this point in the history
  • Loading branch information
JayYip committed May 27, 2019
1 parent 15e9236 commit 603595e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
34 changes: 16 additions & 18 deletions docproduct/predictor.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import json
import os
import re
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from time import time

import tensorflow as tf
import faiss
import numpy as np
from time import time
from tqdm import tqdm
import pandas as pd
from multiprocessing import Pool, cpu_count
import faiss
import json
import tensorflow as tf
from tqdm import tqdm

import gpt2_estimator
from docproduct.dataset import convert_text_to_feature
from docproduct.models import MedicalQAModelwithBert
from docproduct.tokenization import FullTokenizer
from keras_bert.loader import checkpoint_loader
import gpt2_estimator


def load_weight(model, bert_ffn_weight_file=None, ffn_weight_file=None):
Expand Down
Expand Up
@@ -290,17 +291,14 @@ def predict(self, questions, search_by='answer', topk=5, answer_only=False):
lambda: gpt2_estimator.predict_input_fn(inputs=gpt2_input, batch_size=self.batch_size))
raw_output = gpt2_estimator.predictions_parsing(
gpt2_pred, self.encoder)
# original_line = '`QUESTION: %s `ANSWER: ' % questions
# output_list = []
# for output_ind, output_chunk in enumerate(raw_output[0].split(original_line)):
# if output_ind == 0:
# pass
# else:
# output_list.append(output_chunk.split('`QUESTION')[0])

# clipped_output = raw_output[0].split(
# '`QUESTION')[1].split('`ANSWER:')[1]
return raw_output
result_list = [re.search('`ANSWER:(.*)`QUESTION:', s)
for s in raw_output]
result_list = [s for s in result_list if s]
try:
r = result_list[0].group(1)
except AttributeError:
r = ''
return r


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up
@@ -11,10 +11,10 @@


setup(
name='MedicalQA',
name='docproduct',
version='0.2.0',
packages=find_packages(),
url='https://github.com/Santosh-Gupta/MedicalQA',
url='https://github.com/re-search/DocProduct',
license='MIT',
author='MedicalQATeam',
author_email='[email protected]',
Expand Down

0 comments on commit 603595e

Please sign in to comment.