Skip to content

Commit

Permalink
Restructure update embeddings (#304)
Browse files Browse the repository at this point in the history
* Restructure update embeddings

* Adapt FAISSDocStore

* Adapt test and tutorial

Co-authored-by: Timo Moeller <[email protected]>
  • Loading branch information
bogdankostic and Timoeller authored Aug 18, 2020
1 parent 8a3eca0 commit 72b1013
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 41 deletions.
22 changes: 3 additions & 19 deletions haystack/database/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,26 +455,10 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

# TODO Index embeddings every X batches to avoid OOM for huge document collections
docs = self.get_all_documents(index)
passages = [d.text for d in docs]

#TODO Index embeddings every X batches to avoid OOM for huge document collections
logger.info(f"Updating embeddings for {len(passages)} docs ...")

# TODO send whole Document to retriever and let retriever decide what fields to embed
from haystack.retriever.dense import DensePassageRetriever
if isinstance(retriever,DensePassageRetriever):
titles = []
for d in docs:
if d.meta is not None:
titles.append(d.meta['name'] if 'name' in d.meta.keys() else None)
if len(titles) == len(passages):
embeddings = retriever.embed_passages(passages,titles) # type: ignore
else:
embeddings = retriever.embed_passages(passages) # type: ignore
else: #EmbeddingRetriever
embeddings = retriever.embed_passages(passages) # type: ignore

logger.info(f"Updating embeddings for {len(docs)} docs ...")
embeddings = retriever.embed_passages(docs) # type: ignore
assert len(docs) == len(embeddings)

if embeddings[0].shape[0] != self.embedding_dim:
Expand Down
8 changes: 5 additions & 3 deletions haystack/database/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
index = index or self.index

documents = self.get_all_documents(index=index)
for doc in documents:
embedding = retriever.embed_passages([doc.text])[0] # type: ignore
doc.embedding = embedding
logger.info(f"Updating embeddings for {len(documents)} docs ...")
embeddings = retriever.embed_passages(documents) # type: ignore
assert len(documents) == len(embeddings)
for i, doc in enumerate(documents):
doc.embedding = embeddings[i]

phi = self._get_phi(documents)

Expand Down
16 changes: 14 additions & 2 deletions haystack/retriever/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self,
batch_size: int = 16,
do_lower_case: bool = False,
use_amp: str = None,
embed_title: bool = True
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(self,

self.use_amp = use_amp
self.do_lower_case = do_lower_case
self.embed_title = embed_title

# Load checkpoint (incl. additional model params)
saved_state = load_states_from_checkpoint(self.embedding_model)
Expand Down Expand Up @@ -122,14 +124,23 @@ def embed_queries(self, texts: List[str]) -> List[np.array]:
tensorizer=self.tensorizer, batch_size=self.batch_size)
return result

def embed_passages(self, texts: List[str], titles: Optional[List[str]] = None) -> List[np.array]:
def embed_passages(self, docs: List[Document]) -> List[np.array]:
"""
Create embeddings for a list of passages using the passage encoder
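:param docs: Documents to embed. The text of each document is embedded as the passage. If embed_title is enabled
and every document has meta, meta["name"] (when present) is also passed to the encoder as the passage title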
:return: embeddings, one per input passage
"""
texts = [d.text for d in docs]
titles = []
if self.embed_title:
for d in docs:
if d.meta is not None:
titles.append(d.meta["name"] if "name" in d.meta.keys() else None)
if len(titles) != len(texts):
titles = None # type: ignore

result = self._generate_batch_predictions(texts=texts, titles=titles, model=self.passage_encoder,
tensorizer=self.tensorizer, batch_size=self.batch_size)
return result
Expand Down Expand Up @@ -284,12 +295,13 @@ def embed_queries(self, texts: List[str]) -> List[np.array]:
"""
return self.embed(texts)

def embed_passages(self, texts: List[str]) -> List[np.array]:
def embed_passages(self, docs: List[Document]) -> List[np.array]:
"""
Create embeddings for a list of passages. For this Retriever type: The same as calling .embed()
:param docs: Documents to embed; the text of each document is embedded
:return: embeddings, one per input passage
"""
texts = [d.text for d in docs]

return self.embed(texts)
46 changes: 31 additions & 15 deletions test/test_dpr_retriever.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,44 @@
import pytest
import time

from haystack.retriever.dense import DensePassageRetriever
from haystack.database.base import Document
from haystack.database.elasticsearch import ElasticsearchDocumentStore


@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
def test_dpr_inmemory_retrieval(document_store):

documents = [
{'name': '0', 'text': """Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from"""},
{'name': '1', 'text': """Schopenhauer, describing him as an ultimately shallow thinker: ""Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end."" His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous ""History of Western Philosophy"" for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are"""},
{'name': '2', 'text': """Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with"""},
Document(
text="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
meta={"name": "0"}
),
Document(
text="""Schopenhauer, describing him as an ultimately shallow thinker: ""Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end."" His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous ""History of Western Philosophy"" for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are""",
meta={"name": "1"}
),
Document(
text="""Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with""",
)
]

retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq", use_gpu=False)

embedded = []
for doc in documents:
embedding = retriever.embed_passages([doc['text']])[0]
doc['embedding'] = embedding
embedded.append(doc)
document_store.write_documents(documents, index="test_dpr")
retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq", use_gpu=False, embed_title=True)
document_store.update_embeddings(retriever=retriever, index="test_dpr")
time.sleep(2)

assert (embedding.shape[0] == 768)
assert (embedding[0] - 0.52872 < 0.001)
docs_with_emb = document_store.get_all_documents(index="test_dpr")

document_store.write_documents(embedded)
# FAISSDocumentStore doesn't return embeddings, so these tests only work with ElasticsearchDocumentStore
if isinstance(document_store, ElasticsearchDocumentStore):
assert (len(docs_with_emb[0].embedding) == 768)
assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.24695)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.37449)) < 0.001)

res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", index="test_dpr")
assert res[0].meta["name"] == "1"

# clean up
document_store.delete_all_documents(index="test_dpr")
2 changes: 1 addition & 1 deletion tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@
"source": [
"from haystack.retriever.dense import DensePassageRetriever\n",
"retriever = DensePassageRetriever(document_store=document_store, embedding_model=\"dpr-bert-base-nq\",\n",
" do_lower_case=True, use_gpu=True)\n",
" do_lower_case=True, use_gpu=True, embed_title=True)\n",
"\n",
"# Important: \n",
"# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all\n",
Expand Down
2 changes: 1 addition & 1 deletion tutorials/Tutorial6_Better_Retrieval_via_DPR.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

### Retriever
retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq",
do_lower_case=True, use_gpu=True)
do_lower_case=True, use_gpu=True, embed_title=True)

# Important:
# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all
Expand Down

0 comments on commit 72b1013

Please sign in to comment.