Commit

test: update filtering of Pinecone mock to imitate doc store (deepset-ai#3020)

* updated filtering of doc store to imitate pinecone

* Update test/mocks/pinecone.py
jamescalam committed Aug 18, 2022
1 parent 74b7c2c commit 82c9cff
Showing 1 changed file with 143 additions and 7 deletions.
150 changes: 143 additions & 7 deletions test/mocks/pinecone.py
@@ -1,10 +1,9 @@
from typing import Optional, List
from typing import Optional, List, Dict, Union

import logging

logger = logging.getLogger(__name__)


# Mock Pinecone instance
CONFIG: dict = {"api_key": None, "environment": None, "indexes": {}}

@@ -58,10 +57,23 @@ def upsert(self, vectors: List[tuple], namespace: str = ""):
upsert_count += 1
return {"upserted_count": upsert_count}

def describe_index_stats(self):
def update(self, namespace: str, id: str, set_metadata: dict):
# Get existing item metadata
meta = self.index_config.namespaces[namespace][id]["metadata"]
# Add new metadata to existing item metadata
self.index_config.namespaces[namespace][id]["metadata"] = {**meta, **set_metadata}

def describe_index_stats(self, filter=None):
namespaces = {}
for namespace in self.index_config.namespaces.items():
namespaces[namespace[0]] = {"vector_count": len(namespace[1])}
records = self.index_config.namespaces[namespace[0]]
if filter:
filtered_records = []
for record in records.values():
if self._filter(metadata=record["metadata"], filters=filter, top_level=True):
filtered_records.append(record)
records = filtered_records
namespaces[namespace[0]] = {"vector_count": len(records)}
return {"dimension": self.index_config.dimension, "index_fullness": 0.0, "namespaces": namespaces}

def query(
@@ -87,7 +99,11 @@ def query(
if include_metadata:
match["metadata"] = records[_id]["metadata"].copy()
match["score"] = 0.0
response["matches"].append(match)
if filter is None or (
filter is not None and self._filter(records[_id]["metadata"], filter, top_level=True)
):
# filter if needed
response["matches"].append(match)
return response

def fetch(self, ids: List[str], namespace: str = ""):
@@ -107,22 +123,142 @@ def fetch(self, ids: List[str], namespace: str = ""):
}
return response

def _filter(
self,
metadata: dict,
filters: Dict[str, Union[str, int, float, bool, list]],
mode: Optional[str] = "$and",
top_level=False,
) -> dict:
"""
Mock filtering function
"""
bools = []
if type(filters) is list:
list_bools = []
for _filter in filters:
res = self._filter(metadata, _filter, mode=mode)
for key, value in res.items():
if key == "$and":
list_bools.append(all(value))
else:
list_bools.append(any(value))
if mode == "$and":
bools.append(all(list_bools))
elif mode == "$or":
bools.append(any(list_bools))
else:
for field, potential_value in filters.items():
if field in ["$and", "$or"]:
bools.append(self._filter(metadata, potential_value, mode=field))
mode = field
cond = field
else:
if type(potential_value) is dict:
sub_bool = []
for cond, value in potential_value.items():
if len(potential_value.keys()) > 1:
sub_filter = {field: {cond: value}}
bools.append(self._filter(metadata, sub_filter))
if len(sub_bool) > 1:
if field == "$or":
bools.append(any(sub_bool))
else:
bools.append(all(sub_bool))
elif type(potential_value) is list:
cond = "$in"
value = potential_value
else:
cond = "$eq"
value = potential_value
# main chunk of condition checks
if cond == "$eq":
if field in metadata and metadata[field] == value:
bools.append(True)
else:
bools.append(False)
elif cond == "$ne":
if field in metadata and metadata[field] != value:
bools.append(True)
else:
bools.append(False)
elif cond == "$in":
if field in metadata and metadata[field] in value:
bools.append(True)
else:
bools.append(False)
elif cond == "$nin":
if field in metadata and metadata[field] not in value:
bools.append(True)
else:
bools.append(False)
elif cond == "$gt":
if field in metadata and metadata[field] > value:
bools.append(True)
else:
bools.append(False)
elif cond == "$lt":
if field in metadata and metadata[field] < value:
bools.append(True)
else:
bools.append(False)
elif cond == "$gte":
if field in metadata and metadata[field] >= value:
bools.append(True)
else:
bools.append(False)
elif cond == "$lte":
if field in metadata and metadata[field] <= value:
bools.append(True)
else:
bools.append(False)
if top_level:
final = []
for item in bools:
if type(item) is dict:
for key, value in item.items():
if key == "$and":
final.append(all(value))
else:
final.append(any(value))
else:
final.append(item)
if mode == "$and":
bools = all(final)
else:
bools = any(final)
else:
if mode == "$and":
return {"$and": bools}
else:
return {"$or": bools}
return bools

def delete(
self,
ids: Optional[List[str]] = None,
namespace: str = "",
filters: Optional[dict] = None,
delete_all: bool = False,
):
if delete_all:
if filters:
# Get a filtered list of IDs
matches = self.query(filters=filters, namespace=namespace, include_values=False, include_metadata=False)[
"vectors"
]
filter_ids: List[str] = matches.keys() # .keys() returns an object that supports set operators already
elif delete_all:
self.index_config.namespaces[namespace] = {}

if namespace not in self.index_config.namespaces:
pass
elif ids is not None:
id_list: List[str] = ids
if filters:
# We find the intersect between the IDs and filtered IDs
id_list = set(id_list).intersection(filter_ids)
records = self.index_config.namespaces[namespace]
for _id in list(records.keys()):
for _id in records.keys():
if _id in id_list:
del records[_id]
else:

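For reference, the Pinecone-style metadata filters that the updated mock's _filter method is written to handle combine per-field conditions ($eq, $ne, $in, $nin, $gt, $gte, $lt, $lte) under the logical operators $and and $or. The snippet below is a minimal, hypothetical sketch of how such a filter dict evaluates against one record's metadata; the matches helper is written purely for illustration here and is not part of the commit or of the Pinecone client API.

# Hypothetical sketch (not part of the commit): mirrors the subset of
# Pinecone-style filter syntax that the mock's _filter method handles.
from typing import Any, Dict

OPS = {
    "$eq": lambda a, b: a == b,
    "$ne": lambda a, b: a != b,
    "$in": lambda a, b: a in b,
    "$nin": lambda a, b: a not in b,
    "$gt": lambda a, b: a > b,
    "$lt": lambda a, b: a < b,
    "$gte": lambda a, b: a >= b,
    "$lte": lambda a, b: a <= b,
}

def matches(metadata: Dict[str, Any], flt: Dict[str, Any]) -> bool:
    """Evaluate a Pinecone-style filter dict against one record's metadata."""
    results = []
    for key, value in flt.items():
        if key in ("$and", "$or"):
            # logical operators take a list of sub-filters
            sub_results = [matches(metadata, sub) for sub in value]
            results.append(all(sub_results) if key == "$and" else any(sub_results))
        elif isinstance(value, dict):
            # explicit operator form, e.g. {"year": {"$gte": 2020}}
            results.append(
                all(key in metadata and OPS[op](metadata[key], operand) for op, operand in value.items())
            )
        else:
            # bare value is shorthand for $eq; a bare list is shorthand for $in
            op = "$in" if isinstance(value, list) else "$eq"
            results.append(key in metadata and OPS[op](metadata[key], value))
    return all(results)

# e.g. records from 2020 onwards whose type is "article" or "blog"
flt = {"$and": [{"year": {"$gte": 2020}}, {"type": {"$in": ["article", "blog"]}}]}
print(matches({"year": 2021, "type": "article"}, flt))  # True
print(matches({"year": 2019, "type": "article"}, flt))  # False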