Skip to content

Commit

Permalink
fix: Correctly deserialize Pinecone docstore in embedding retriever (d…
Browse files Browse the repository at this point in the history
…eepset-ai#636)

* fix: Correctly deserialize Pinecone docstore in embedding retriever

* Feature gate tests

* Lints
  • Loading branch information
shadeMe committed Apr 2, 2024
1 parent 0ed69df commit 958f44c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever":
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = default_from_dict(
PineconeDocumentStore, data["init_parameters"]["document_store"]
data["init_parameters"]["document_store"] = PineconeDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(
Creates a new PineconeDocumentStore instance.
It is meant to be connected to a Pinecone index and namespace.
:param api_key: The Pinecone API key. It can be explicitly provided or automatically read from the
environment variable `PINECONE_API_KEY` (recommended).
:param api_key: The Pinecone API key.
:param environment: The Pinecone environment to connect to.
:param index: The Pinecone index to connect to. If the index does not exist, it will be created.
:param namespace: The Pinecone namespace to connect to. If the namespace does not exist, it will be created
Expand All @@ -59,16 +58,9 @@ def __init__(
[API reference](https://docs.pinecone.io/reference/create_index).
"""
resolved_api_key = api_key.resolve_value()
if resolved_api_key is None:
msg = (
"PineconeDocumentStore expects an API key. "
"Set the PINECONE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.api_key = api_key

pinecone.init(api_key=resolved_api_key, environment=environment)
pinecone.init(api_key=api_key.resolve_value(), environment=environment)

if index not in pinecone.list_indexes():
logger.info(f"Index {index} does not exist. Creating a new index.")
Expand Down
33 changes: 23 additions & 10 deletions integrations/pinecone/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -50,7 +51,7 @@ def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch):


@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone")
def test_to_dict(mock_pinecone, monkeypatch):
def test_to_from_dict(mock_pinecone, monkeypatch):
mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30}
monkeypatch.setenv("PINECONE_API_KEY", "env-api-key")
document_store = PineconeDocumentStore(
Expand All @@ -61,7 +62,8 @@ def test_to_dict(mock_pinecone, monkeypatch):
dimension=30,
metric="euclidean",
)
assert document_store.to_dict() == {

dict_output = {
"type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore",
"init_parameters": {
"api_key": {
Expand All @@ -79,9 +81,28 @@ def test_to_dict(mock_pinecone, monkeypatch):
"metric": "euclidean",
},
}
assert document_store.to_dict() == dict_output

document_store = PineconeDocumentStore.from_dict(dict_output)
assert document_store.environment == "gcp-starter"
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "my_index"
assert document_store.namespace == "test"
assert document_store.batch_size == 50
assert document_store.dimension == 30


def test_init_fails_wo_api_key(monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY", raising=False)
with pytest.raises(ValueError):
PineconeDocumentStore(
environment="gcp-starter",
index="my_index",
)


@pytest.mark.integration
@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set")
class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest):
def test_write_documents(self, document_store: PineconeDocumentStore):
docs = [Document(id="1")]
Expand All @@ -96,14 +117,6 @@ def test_write_documents_duplicate_skip(self, document_store: PineconeDocumentSt
@pytest.mark.skip(reason="Pinecone creates a namespace only when the first document is written")
def test_delete_documents_empty_document_store(self, document_store: PineconeDocumentStore): ...

def test_init_fails_wo_api_key(self, monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY", raising=False)
with pytest.raises(ValueError):
PineconeDocumentStore(
environment="gcp-starter",
index="my_index",
)

def test_embedding_retrieval(self, document_store: PineconeDocumentStore):
query_embedding = [0.1] * 768
most_similar_embedding = [0.8] * 768
Expand Down
9 changes: 9 additions & 0 deletions integrations/pinecone/tests/test_emebedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import Mock, patch

from haystack.dataclasses import Document
from haystack.utils import Secret

from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever
from haystack_integrations.document_stores.pinecone import PineconeDocumentStore
Expand Down Expand Up @@ -63,6 +64,13 @@ def test_from_dict(mock_pinecone, monkeypatch):
"init_parameters": {
"document_store": {
"init_parameters": {
"api_key": {
"env_vars": [
"PINECONE_API_KEY",
],
"strict": True,
"type": "env_var",
},
"environment": "gcp-starter",
"index": "default",
"namespace": "test-namespace",
Expand All @@ -82,6 +90,7 @@ def test_from_dict(mock_pinecone, monkeypatch):

document_store = retriever.document_store
assert document_store.environment == "gcp-starter"
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "default"
assert document_store.namespace == "test-namespace"
assert document_store.batch_size == 50
Expand Down
2 changes: 2 additions & 0 deletions integrations/pinecone/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import List

import pytest
Expand All @@ -8,6 +9,7 @@


@pytest.mark.integration
@pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set")
class TestFilters(FilterDocumentsTest):
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
for doc in received:
Expand Down

0 comments on commit 958f44c

Please sign in to comment.