Skip to content

Commit

Permalink
[feature] Add reading pkl and parquet embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
JayYip committed May 27, 2019
1 parent 2f0f84e commit 15e9236
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
6 changes: 5 additions & 1 deletion docproduct/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ class FaissTopK(object):
def __init__(self, embedding_file):
super(FaissTopK, self).__init__()
self.embedding_file = embedding_file
self.df = pd.read_parquet(self.embedding_file)
_, ext = os.path.splitext(self.embedding_file)
if ext == '.pkl':
self.df = pd.read_pickle(self.embedding_file)
else:
self.df = pd.read_parquet(self.embedding_file)
self._get_faiss_index()
# self.df.drop(columns=["Q_FFNN_embeds", "A_FFNN_embeds"], inplace=True)

Expand Down
7 changes: 6 additions & 1 deletion docproduct/train_embedding_to_gpt2_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def train_embedding_to_gpt2_data(
batch_size {int} -- Retrieve batch size of FAISS (default: {512})
"""
qa = pd.read_parquet(data_path)
_, ext = os.path.splitext(data_path)
if ext == '.pkl':
qa = pd.read_pickle(data_path)
else:
qa = pd.read_parquet(data_path)
# qa = pd.read_parquet(data_path)
question_bert = qa["Q_FFNN_embeds"].tolist()
answer_bert = qa["A_FFNN_embeds"].tolist()
question_bert = np.array(question_bert)
Expand Down

0 comments on commit 15e9236

Please sign in to comment.