Skip to content

Commit

Permalink
DensePassageRetriever: Add Training, Refactor Inference to FARM modul…
Browse files Browse the repository at this point in the history
…es (#527)

* dpr training and inference code refactored with FARM modules

* dpr test cases modified

* docstring and default arguments updated

* dpr training docstring updated

* bugfix in dense retriever inference, DPR tutorials modified

* Bump FARM to 0.5.0

* update README for DPR

* dpr training and inference code refactored with FARM modules

* dpr test cases modified

* docstring and default arguments updated

* dpr training docstring updated

* bugfix in dense retriever inference, DPR tutorials modified

* Bump FARM to 0.5.0

* update README for DPR

* mypy errors fix

* DPR instantiation bugfix

* Fix DPR init in RAG Tutorial

Co-authored-by: Malte Pietsch <[email protected]>
  • Loading branch information
kolk and tholor committed Oct 30, 2020
1 parent f134430 commit 72b637a
Show file tree
Hide file tree
Showing 9 changed files with 420 additions and 383 deletions.
7 changes: 5 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,11 @@ Example
.. code-block:: python
retriever = DensePassageRetriever(document_store=document_store,
embedding_model="dpr-bert-base-nq",
do_lower_case=True, use_gpu=True)
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True,
batch_size=16,
embed_title=True)
retriever.retrieve(query="Why did the revenue increase?")
# returns: [Document, Document]
Expand Down
369 changes: 202 additions & 167 deletions haystack/retriever/dense.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ tika
uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
more_itertools
14 changes: 5 additions & 9 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def dpr_retriever(faiss_document_store):
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
embed_title=True,
remove_sep_tok_from_untitled_passages=True
use_fast_tokenizers=True
)


Expand Down Expand Up @@ -288,14 +288,10 @@ def get_document_store(document_store_type, faiss_document_store, inmemory_docum
def get_retriever(retriever_type, document_store):

if retriever_type == "dpr":
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
embed_title=True,
remove_sep_tok_from_untitled_passages=True
)
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True)
elif retriever_type == "tfidf":
return TfidfRetriever(document_store=document_store)
elif retriever_type == "embedding":
Expand Down
10 changes: 5 additions & 5 deletions test/test_dpr_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
# FAISSDocumentStore doesn't return embeddings, so these tests only work with ElasticsearchDocumentStore
if isinstance(document_store, ElasticsearchDocumentStore):
assert (len(docs_with_emb[0].embedding) == 768)
assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.37449)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.24695)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.08017)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.01534)) < 0.001)
assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)

res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")

Expand Down
7 changes: 4 additions & 3 deletions tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -515,11 +515,12 @@
"retriever = DensePassageRetriever(document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" max_seq_len_query=64,\n",
" max_seq_len_passage=256,\n",
" batch_size=16,\n",
" use_gpu=True,\n",
" embed_title=True,\n",
" max_seq_len=256,\n",
" batch_size=16,\n",
" remove_sep_tok_from_untitled_passages=True)\n",
" use_fast_tokenizers=True)\n",
"# Important: \n",
"# Now that after we have the DPR initialized, we need to call update_embeddings() to iterate over all\n",
"# previously indexed documents and update their embedding representation. \n",
Expand Down
6 changes: 5 additions & 1 deletion tutorials/Tutorial6_Better_Retrieval_via_DPR.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
max_seq_len_query=64,
max_seq_len_passage=256,
batch_size=2,
use_gpu=True,
embed_title=True,
remove_sep_tok_from_untitled_passages=True)
use_fast_tokenizers=True
)

# Important:
# Now that after we have the DPR initialized, we need to call update_embeddings() to iterate over all
Expand Down
Loading

0 comments on commit 72b637a

Please sign in to comment.