Skip to content

Commit

Permalink
feat: Qdrant filtering support
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Mar 12, 2024
1 parent f95e4d0 commit 3c80ed9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union
from datetime import datetime
from typing import List, Optional, Union

from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
from qdrant_client.http import models
Expand All @@ -10,15 +10,7 @@
LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys()


class BaseFilterConverter(ABC):
"""Converts Haystack filters to a format accepted by an external tool."""

@abstractmethod
def convert(self, filter_term: Optional[Union[List[dict], dict]]) -> Optional[Any]:
raise NotImplementedError


class QdrantFilterConverter(BaseFilterConverter):
class QdrantFilterConverter:
"""Converts Haystack filters to the format used by Qdrant."""

def __init__(self):
Expand Down Expand Up @@ -141,34 +133,47 @@ def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> m
must_not=[
(
models.FieldCondition(key=key, match=models.MatchText(text=item))
if isinstance(item, str) and " " not in item
if isinstance(item, str) and " " in item
else models.FieldCondition(key=key, match=models.MatchValue(value=item))
)
for item in value
]
)

def _build_lt_condition(self, key: str, value: float) -> models.Condition:
def _build_lt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition:
if isinstance(value, str) and is_datetime_string(value):
return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value))

if not isinstance(value, (int, float)):
msg = f"Value {value} is not an int or float"
msg = f"Value {value} is not an int or float or datetime string"
raise FilterError(msg)
return models.FieldCondition(key=key, range=models.Range(lt=value))

def _build_lte_condition(self, key: str, value: float) -> models.Condition:
def _build_lte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition:
if isinstance(value, str) and is_datetime_string(value):
return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value))

if not isinstance(value, (int, float)):
msg = f"Value {value} is not an int or float"
msg = f"Value {value} is not an int or float or datetime string"
raise FilterError(msg)
return models.FieldCondition(key=key, range=models.Range(lte=value))

def _build_gt_condition(self, key: str, value: float) -> models.Condition:
def _build_gt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition:
if isinstance(value, str) and is_datetime_string(value):
return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value))

if not isinstance(value, (int, float)):
msg = f"Value {value} is not an int or float"
msg = f"Value {value} is not an int or float or datetime string"
raise FilterError(msg)

return models.FieldCondition(key=key, range=models.Range(gt=value))

def _build_gte_condition(self, key: str, value: float) -> models.Condition:
def _build_gte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition:
if isinstance(value, str) and is_datetime_string(value):
return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value))

if not isinstance(value, (int, float)):
msg = f"Value {value} is not an int or float"
msg = f"Value {value} is not an int or float or datetime string"
raise FilterError(msg)
return models.FieldCondition(key=key, range=models.Range(gte=value))

Expand Down Expand Up @@ -215,3 +220,11 @@ def _squeeze_filter(self, payload_filter: models.Filter) -> models.Filter:
return models.Filter(**{part_name: subfilter.must})

return payload_filter


def is_datetime_string(value: str) -> bool:
try:
datetime.fromisoformat(value)
return True
except ValueError:
return False
12 changes: 0 additions & 12 deletions integrations/qdrant/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,5 @@ def test_comparison_equal_with_dataframe(self, document_store, filterable_docs):
@pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe")
def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ...

@pytest.mark.skip(reason="Cannot distinguish errors yet")
def test_missing_top_level_operator_key(self, document_store, filterable_docs): ...

0 comments on commit 3c80ed9

Please sign in to comment.