Skip to content

Commit

Permalink
Merge pull request #1168 from efenocchi/main
Browse files Browse the repository at this point in the history
fix(dspy): fixed bug in deeplake_rm retriever part
  • Loading branch information
arnavsinghvi11 authored Jun 18, 2024
2 parents 05a4923 + 386aa53 commit 081b637
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions dspy/retrieve/deeplake_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
from dsp.utils import dotdict

try:
import openai.error
import openai

ERRORS = (
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.APIError,
openai.RateLimitError,
openai.APIError,
)
except Exception:
ERRORS = (openai.error.RateLimitError, openai.error.APIError)
ERRORS = (openai.RateLimitError, openai.APIError)


class DeeplakeRM(dspy.Retrieve):
Expand Down Expand Up @@ -58,13 +57,15 @@ def __init__(
k: int = 3,
):
try:
from deeplake import VectorStore
from deeplake import VectorStore
except ImportError:
raise ImportError(
"The 'deeplake' extra is required to use DeepLakeRM. Install it with `pip install dspy-ai[deeplake]`",
)
raise ImportError("The 'deeplake' extra is required to use DeepLakeRM. Install it with `pip install dspy-ai[deeplake]`",)

self._deeplake_vectorstore_name = deeplake_vectorstore_name
self._deeplake_client = deeplake_client
self._deeplake_client = deeplake_client(
path=self._deeplake_vectorstore_name,
embedding_function=self.embedding_function,
)

super().__init__(k=k)

Expand All @@ -73,11 +74,9 @@ def embedding_function(self, texts, model="text-embedding-ada-002"):
texts = [texts]

texts = [t.replace("\n", " ") for t in texts]
return [
data["embedding"]
for data in openai.Embedding.create(input=texts, model=model)["data"]
]


return [data.embedding for data in openai.embeddings.create(input = texts, model=model).data]

def forward(
self, query_or_queries: Union[str, List[str]], k: Optional[int],**kwargs,
) -> dspy.Prediction:
Expand All @@ -103,10 +102,7 @@ def forward(
passages = defaultdict(float)
# Deep Lake doesn't support batch querying, so run each query individually and accumulate the results
for query in queries:
results = self._deeplake_client(
path=self._deeplake_vectorstore_name,
embedding_function=self.embedding_function,
).search(query, k=k,**kwargs)
results = self._deeplake_client.search(query, k=k, **kwargs)

for score,text in zip(results.get('score',0.0),results.get('text',"")):
passages[text] += score
Expand Down

0 comments on commit 081b637

Please sign in to comment.