Commit

jdixosnd committed May 23, 2024
2 parents 803334d + 5edfd53 commit f724743
Showing 16 changed files with 179 additions and 19 deletions.
23 changes: 23 additions & 0 deletions .github/pull_request_template.md
@@ -0,0 +1,23 @@
### Related Issues

- fixes #issue-number

### Proposed Changes:

<!--- In case of a bug: Describe what caused the issue and how you solved it -->
<!--- In case of a feature: Describe what you added and how it works -->

### How did you test it?

<!-- unit tests, integration tests, manual verification, instructions for manual tests -->

### Notes for the reviewer

<!-- E.g. point out sections where the reviewer should pay particular attention -->

### Checklist

- I have read the [contributors guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md) and the [code of conduct](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CODE_OF_CONDUCT.md)
- I have updated the related issue with new insights and changes
- I added unit tests and updated the docstrings
- I've used one of the [conventional commit types](https://www.conventionalcommits.org/en/v1.0.0/) for my PR title: `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`.
@@ -336,7 +336,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
self.prompt_handler = DefaultPromptHandler(
tokenizer=tokenizer,
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_gen_len") or 512,
max_length=self.generation_kwargs.get("max_tokens") or 512,
)

def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
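The hunk above switches the key that the chat adapter's `DefaultPromptHandler` reads from `generation_kwargs` (`max_tokens` instead of `max_gen_len`). A minimal usage sketch, assuming the public `AmazonBedrockChatGenerator` component and working AWS credentials for Bedrock; the model id and prompt are illustrative:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# The prompt handler now looks up "max_tokens" (falling back to 512) from the
# same generation_kwargs that are forwarded to the Bedrock API.
generator = AmazonBedrockChatGenerator(
    model="mistral.mistral-large-2402-v1:0",  # illustrative Bedrock model id
    generation_kwargs={"max_tokens": 512},
)

result = generator.run(messages=[ChatMessage.from_user("Summarize what fastembed does.")])
print(result["replies"][0].content)
```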
@@ -208,6 +208,7 @@ def find_documents(self, find_query):
filter=find_query.get("filter"),
sort=find_query.get("sort"),
options=find_query.get("options"),
projection={"*": 1},
)

if "data" in response_dict and "documents" in response_dict["data"]:
@@ -273,6 +274,7 @@ def update_document(self, document: Dict, id_key: str):
filter={id_key: document_id},
update={"$set": document},
options={"returnDocument": "after"},
projection={"*": 1},
)

document[id_key] = document_id
@@ -19,13 +19,16 @@ def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedEmbeddingBackendFactory._instances:
return _FastembedEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedEmbeddingBackend(model_name=model_name, cache_dir=cache_dir, threads=threads)
embedding_backend = _FastembedEmbeddingBackend(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)
_FastembedEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

@@ -40,8 +43,11 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)
self.model = TextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)

def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]:
# the embed method returns an Iterable[np.ndarray], so we convert it to a list of lists
@@ -66,14 +72,15 @@ def get_embedding_backend(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
embedding_backend_id = f"{model_name}{cache_dir}{threads}"

if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances:
return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedSparseEmbeddingBackend(
model_name=model_name, cache_dir=cache_dir, threads=threads
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)
_FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@@ -89,8 +96,11 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
):
self.model = SparseTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)
self.model = SparseTextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
)

def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]:
# The embed method returns an Iterable[SparseEmbedding], so we convert to Haystack SparseEmbedding type.
@@ -65,6 +65,7 @@ def __init__(
batch_size: int = 256,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
@@ -80,11 +81,12 @@ def __init__(
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
@@ -97,6 +99,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

@@ -116,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
@@ -126,7 +130,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
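A short sketch of the new `local_files_only` flag on the document embedder, assuming the model has already been downloaded into the cache directory (the path below is illustrative):

```python
from haystack import Document
from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder

# With local_files_only=True, fastembed loads the model straight from cache_dir
# and never contacts the Hugging Face Hub, so warm_up() also works offline.
embedder = FastembedDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    cache_dir="/opt/models/fastembed",  # illustrative path; the model files must already be here
    local_files_only=True,
)
embedder.warm_up()

docs = [Document(content="fastembed computes dense embeddings on CPU")]
result = embedder.run(documents=docs)
print(len(result["documents"][0].embedding))  # embedding dimension, e.g. 384 for bge-small
```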
@@ -59,6 +59,7 @@ def __init__(
batch_size: int = 32,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
@@ -77,6 +78,7 @@ def __init__(
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
@@ -87,6 +89,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

@@ -104,6 +107,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
@@ -114,7 +118,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
@@ -34,6 +34,7 @@ def __init__(
threads: Optional[int] = None,
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
):
"""
Create a FastembedSparseTextEmbedder component.
@@ -43,18 +44,20 @@
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
:param threads: The number of threads a single onnxruntime session can use. Defaults to None.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
"""

self.model_name = model
self.cache_dir = cache_dir
self.threads = threads
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only

def to_dict(self) -> Dict[str, Any]:
"""
@@ -70,6 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
threads=self.threads,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
)

def warm_up(self):
@@ -78,7 +82,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

@component.output_types(sparse_embedding=SparseEmbedding)
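The sparse text embedder gains the same flag. A minimal sketch, assuming the component's default sparse model is already present in the local fastembed cache:

```python
from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder

# Uses the component's default sparse model; with local_files_only=True the
# model files must already exist locally, otherwise warm_up() will fail.
embedder = FastembedSparseTextEmbedder(local_files_only=True)
embedder.warm_up()

sparse = embedder.run(text="offline sparse embeddings with fastembed")["sparse_embedding"]
# SparseEmbedding stores parallel lists of token indices and weights
print(sparse.indices[:5], sparse.values[:5])
```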
@@ -35,6 +35,7 @@ def __init__(
suffix: str = "",
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
):
"""
Create a FastembedTextEmbedder component.
@@ -46,11 +47,12 @@
:param threads: The number of threads a single onnxruntime session can use. Defaults to None.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param progress_bar: If true, displays progress bar during embedding.
:param progress_bar: If `True`, displays progress bar during embedding.
:param parallel:
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
"""

self.model_name = model
@@ -60,6 +62,7 @@ def __init__(
self.suffix = suffix
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only

def to_dict(self) -> Dict[str, Any]:
"""
@@ -77,6 +80,7 @@ def to_dict(self) -> Dict[str, Any]:
suffix=self.suffix,
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
)

def warm_up(self):
@@ -85,7 +89,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
)

@component.output_types(embedding=List[float])
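Because `to_dict` now serializes the flag, it survives saving and reloading a pipeline. A quick check, assuming Haystack's standard `default_to_dict` layout with the constructor arguments nested under `init_parameters`:

```python
from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder

embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5", local_files_only=True)

config = embedder.to_dict()
assert config["init_parameters"]["local_files_only"] is True
```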
4 changes: 3 additions & 1 deletion integrations/fastembed/tests/test_fastembed_backend.py
@@ -27,7 +27,9 @@ def test_model_initialization(mock_instructor):
_FastembedEmbeddingBackendFactory.get_embedding_backend(
model_name="BAAI/bge-small-en-v1.5",
)
mock_instructor.assert_called_once_with(model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None)
mock_instructor.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5", cache_dir=None, threads=None, local_files_only=False
)
# restore the factory state
_FastembedEmbeddingBackendFactory._instances = {}
