Skip to content

Commit

Permalink
Merge pull request #57 from deepset-ai/new-device-mgmt
Browse files Browse the repository at this point in the history
update cookbooks for new device mgmt in Sentence Transformers embedders
  • Loading branch information
TuanaCelik committed Feb 22, 2024
2 parents 08c8ab8 + 42d660a commit 4cad8ce
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
10 changes: 7 additions & 3 deletions notebooks/improve-retrieval-by-embedding-metadata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@
"from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder\n",
"from haystack.components.writers import DocumentWriter\n",
"from haystack.document_stores.types import DuplicatePolicy\n",
"from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever"
"from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n",
"from haystack.utils import ComponentDevice"
]
},
{
Expand All @@ -145,7 +146,9 @@
" indexing.add_component(\"splitter\", DocumentSplitter(split_by='sentence', split_length=2))\n",
"\n",
" # in the following component, we can specify the parameter `meta_fields_to_embed`, with the metadata fields to embed\n",
" indexing.add_component(\"doc_embedder\", SentenceTransformersDocumentEmbedder(model=\"thenlper/gte-large\", device=\"cuda:0\", meta_fields_to_embed=metadata_fields_to_embed)\n",
" indexing.add_component(\"doc_embedder\", SentenceTransformersDocumentEmbedder(model=\"thenlper/gte-large\",\n",
" device=ComponentDevice.from_str(\"cuda:0\"),\n",
" meta_fields_to_embed=metadata_fields_to_embed)\n",
" )\n",
" indexing.add_component(\"writer\", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))\n",
"\n",
Expand All @@ -167,7 +170,8 @@
"def create_retrieval_pipeline(document_store):\n",
"\n",
" retrieval = Pipeline()\n",
" retrieval.add_component(\"text_embedder\", SentenceTransformersTextEmbedder(model=\"thenlper/gte-large\", device=\"cuda:0\"))\n",
" retrieval.add_component(\"text_embedder\", SentenceTransformersTextEmbedder(model=\"thenlper/gte-large\",\n",
" device=ComponentDevice.from_str(\"cuda:0\")))\n",
" retrieval.add_component(\"retriever\", InMemoryEmbeddingRetriever(document_store=document_store, scale_score=False, top_k=3))\n",
"\n",
" retrieval.connect(\"text_embedder\", \"retriever\")\n",
Expand Down
8 changes: 5 additions & 3 deletions notebooks/multilingual_rag_podcast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@
"outputs": [],
"source": [
"# from haystack.components.audio import LocalWhisperTranscriber\n",
"# from haystack.utils import ComponentDevice\n",
"\n",
"# whisper = LocalWhisperTranscriber(model=\"small\", device=\"cuda:0\")\n",
"# whisper = LocalWhisperTranscriber(model=\"small\", device=ComponentDevice.from_str(\"cuda:0\"),)\n",
"# whisper.warm_up()\n",
"# transcription = whisper.run(audio_files=[\"podcast.mp3\"])\n",
"\n",
Expand Down Expand Up @@ -188,6 +189,7 @@
"from haystack.components.converters import TextFileToDocument\n",
"from haystack.components.writers import DocumentWriter\n",
"from haystack.components.preprocessors import DocumentSplitter\n",
"from haystack.utils import ComponentDevice\n",
"\n",
"# initialize the Document store\n",
"document_store = QdrantDocumentStore(\n",
Expand All @@ -203,7 +205,7 @@
" \"embedder\",\n",
" SentenceTransformersDocumentEmbedder(\n",
" model=\"intfloat/multilingual-e5-large\", # good multilingual model: https://huggingface.co/intfloat/multilingual-e5-large\n",
" device=\"cuda:0\", # load the model on GPU\n",
" device=ComponentDevice.from_str(\"cuda:0\"), # load the model on GPU\n",
" prefix=\"passage:\", # as explained in the model card (https://huggingface.co/intfloat/multilingual-e5-large#faq), documents should be prefixed with \"passage:\"\n",
" ))\n",
"indexing_pipeline.add_component(\"writer\", DocumentWriter(document_store=document_store))\n",
Expand Down Expand Up @@ -1035,7 +1037,7 @@
" \"text_embedder\",\n",
" SentenceTransformersTextEmbedder(\n",
" model=\"intfloat/multilingual-e5-large\", # good multilingual model: https://huggingface.co/intfloat/multilingual-e5-large\n",
" device=\"cuda:0\", # load the model on GPU\n",
" device=ComponentDevice.from_str(\"cuda:0\"), # load the model on GPU\n",
" prefix=\"query:\", # as explained in the model card (https://huggingface.co/intfloat/multilingual-e5-large#faq), queries should be prefixed with \"query:\"\n",
" ))\n",
"query_pipeline.add_component(\"retriever\", QdrantEmbeddingRetriever(document_store=document_store))\n",
Expand Down
10 changes: 7 additions & 3 deletions notebooks/zephyr-7b-beta-for-rag.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@
"from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter\n",
"from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder\n",
"from haystack.components.writers import DocumentWriter\n",
"from haystack.document_stores.types import DuplicatePolicy"
"from haystack.document_stores.types import DuplicatePolicy\n",
"from haystack.utils import ComponentDevice"
]
},
{
Expand Down Expand Up @@ -207,7 +208,9 @@
"indexing = Pipeline()\n",
"indexing.add_component(\"cleaner\", DocumentCleaner())\n",
"indexing.add_component(\"splitter\", DocumentSplitter(split_by='sentence', split_length=2))\n",
"indexing.add_component(\"doc_embedder\", SentenceTransformersDocumentEmbedder(model=\"thenlper/gte-large\", device=\"cuda:0\", metadata_fields_to_embed=[\"title\"]))\n",
"indexing.add_component(\"doc_embedder\", SentenceTransformersDocumentEmbedder(model=\"thenlper/gte-large\",\n",
" device=ComponentDevice.from_str(\"cuda:0\"), \n",
" meta_fields_to_embed=[\"title\"]))\n",
"indexing.add_component(\"writer\", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))\n",
"\n",
"indexing.connect(\"cleaner\", \"splitter\")\n",
Expand Down Expand Up @@ -1541,7 +1544,8 @@
"outputs": [],
"source": [
"rag = Pipeline()\n",
"rag.add_component(\"text_embedder\", SentenceTransformersTextEmbedder(model=\"thenlper/gte-large\", device=\"cuda:0\"))\n",
"rag.add_component(\"text_embedder\", SentenceTransformersTextEmbedder(model=\"thenlper/gte-large\", \n",
" device=ComponentDevice.from_str(\"cuda:0\")))\n",
"rag.add_component(\"retriever\", InMemoryEmbeddingRetriever(document_store=document_store, top_k=5))\n",
"rag.add_component(\"prompt_builder\", prompt_builder)\n",
"rag.add_component(\"llm\", generator)\n",
Expand Down

0 comments on commit 4cad8ce

Please sign in to comment.