refactor!: Sentence Transformers Embedders - new devices mgmt (#7033)
* new device mgmt for Sentence Transformers embedders

* reno
anakin87 committed Feb 19, 2024
1 parent cb01cb4 commit d00f171
Showing 5 changed files with 114 additions and 34 deletions.
haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -4,7 +4,7 @@
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils import Secret, deserialize_secrets_inplace, ComponentDevice


@component
@@ -31,7 +31,7 @@ class SentenceTransformersDocumentEmbedder:
def __init__(
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
@@ -46,8 +46,8 @@ def __init__(
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
:param prefix: A string to add to the beginning of each Document text before embedding.
Can be used to prepend the text with an instruction, as required by some embedding models,
@@ -61,8 +61,7 @@ def __init__(
"""

self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.prefix = prefix
self.suffix = suffix
@@ -85,7 +84,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model,
device=self.device,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
prefix=self.prefix,
suffix=self.suffix,
@@ -98,6 +97,9 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
serialized_device = data["init_parameters"]["device"]
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)

deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)

@@ -107,7 +109,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device, auth_token=self.token
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
)

@component.output_types(documents=List[Document])
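As a quick illustration of what the document embedder changes above amount to in practice, here is a minimal sketch (assuming Haystack 2.x with the `ComponentDevice` utility introduced in this diff) that constructs the embedder with an explicit device and round-trips it through `to_dict()`/`from_dict()`:

```python
# Minimal sketch, assuming Haystack 2.x as changed in this commit.
# The device is now a ComponentDevice and survives serialization.
from haystack.utils import ComponentDevice
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    device=ComponentDevice.from_str("cuda:0"),
)

data = embedder.to_dict()          # "device" is serialized via ComponentDevice.to_dict()
restored = SentenceTransformersDocumentEmbedder.from_dict(data)

assert restored.device == ComponentDevice.from_str("cuda:0")
```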
haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -4,7 +4,7 @@
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils import Secret, deserialize_secrets_inplace, ComponentDevice


@component
@@ -30,7 +30,7 @@ class SentenceTransformersTextEmbedder:
def __init__(
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
@@ -43,8 +43,8 @@ def __init__(
:param model: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
:param prefix: A string to add to the beginning of each Document text before embedding.
Can be used to prepend the text with an instruction, as required by some embedding models,
@@ -56,8 +56,7 @@ def __init__(
"""

self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.prefix = prefix
self.suffix = suffix
@@ -78,7 +77,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model,
device=self.device,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
prefix=self.prefix,
suffix=self.suffix,
@@ -89,6 +88,9 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
serialized_device = data["init_parameters"]["device"]
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)

deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)

@@ -98,7 +100,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device, auth_token=self.token
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
)

@component.output_types(embedding=List[float])
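Similarly, the text embedder now resolves its device through `ComponentDevice` and hands the Sentence Transformers backend a plain torch device string. A small sketch of that resolution path (the printed value depends on the hardware the code runs on):

```python
# Sketch of the device resolution used in warm_up() above.
from haystack.utils import ComponentDevice

resolved = ComponentDevice.resolve_device(None)  # None -> pick a sensible default device
print(resolved.to_torch_str())                   # e.g. "cuda:0" or "cpu"; this string is what
                                                 # the Sentence Transformers backend receives
```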
New release note (reno)
@@ -0,0 +1,20 @@
---
upgrade:
- |
Adopt the new framework-agnostic device management in Sentence Transformers Embedders.
Before this change:
```python
from haystack.components.embedders import SentenceTransformersTextEmbedder
embedder = SentenceTransformersTextEmbedder(device="cuda:0")
```
After this change:
```python
from haystack.utils.device import ComponentDevice, Device
from haystack.components.embedders import SentenceTransformersTextEmbedder
device = ComponentDevice.from_single(Device.gpu(id=0))
# or
# device = ComponentDevice.from_str("cuda:0")
embedder = SentenceTransformersTextEmbedder(device=device)
```
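Note that passing a device remains optional: when `device` is left as `None`, the embedders fall back to `ComponentDevice.resolve_device(None)`, so code that never set a device keeps working, although the automatically selected device may now be a GPU rather than always CPU. A minimal sketch:

```python
from haystack.components.embedders import SentenceTransformersTextEmbedder

# With no device argument, the embedder resolves a default device automatically
# (via ComponentDevice.resolve_device(None) in __init__).
embedder = SentenceTransformersTextEmbedder()
```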
test_sentence_transformers_document_embedder.py
@@ -1,7 +1,7 @@
from unittest.mock import patch, MagicMock
import pytest
import numpy as np
from haystack.utils.auth import Secret
from haystack.utils import Secret, ComponentDevice

from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
@@ -11,7 +11,7 @@ class TestSentenceTransformersDocumentEmbedder:
def test_init_default(self):
embedder = SentenceTransformersDocumentEmbedder(model="model")
assert embedder.model == "model"
assert embedder.device == "cpu"
assert embedder.device == ComponentDevice.resolve_device(None)
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
@@ -24,7 +24,7 @@ def test_init_default(self):
def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
model="model",
device="cuda",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
@@ -35,7 +35,7 @@ def test_init_with_parameters(self):
embedding_separator=" | ",
)
assert embedder.model == "model"
assert embedder.device == "cuda"
assert embedder.device == ComponentDevice.from_str("cuda:0")
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
@@ -46,13 +46,13 @@ def test_init_with_parameters(self):
assert embedder.embedding_separator == " | "

def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model="model")
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model": "model",
"device": "cpu",
"device": ComponentDevice.from_str("cpu").to_dict(),
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"prefix": "",
"suffix": "",
@@ -67,7 +67,7 @@ def test_to_dict(self):
def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersDocumentEmbedder(
model="model",
device="cuda",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
@@ -83,7 +83,7 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model": "model",
"device": "cuda",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"prefix": "prefix",
"suffix": "suffix",
@@ -95,11 +95,43 @@ def test_to_dict_with_custom_init_parameters(self):
},
}

def test_from_dict(self):
init_parameters = {
"model": "model",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
"init_parameters": init_parameters,
}
)
assert component.model == "model"
assert component.device == ComponentDevice.from_str("cuda:0")
assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is True
assert component.embedding_separator == " - "
assert component.meta_fields_to_embed == ["meta_field"]

@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
embedder = SentenceTransformersDocumentEmbedder(model="model", token=None)
embedder = SentenceTransformersDocumentEmbedder(
model="model", token=None, device=ComponentDevice.from_str("cpu")
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
test_sentence_transformers_text_embedder.py
@@ -1,6 +1,6 @@
from unittest.mock import patch, MagicMock
import pytest
from haystack.utils.auth import Secret
from haystack.utils import Secret, ComponentDevice

import numpy as np

@@ -11,7 +11,7 @@ class TestSentenceTransformersTextEmbedder:
def test_init_default(self):
embedder = SentenceTransformersTextEmbedder(model="model")
assert embedder.model == "model"
assert embedder.device == "cpu"
assert embedder.device == ComponentDevice.resolve_device(None)
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
@@ -22,7 +22,7 @@ def test_init_default(self):
def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
model="model",
device="cuda",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
@@ -31,7 +31,7 @@ def test_init_with_parameters(self):
normalize_embeddings=True,
)
assert embedder.model == "model"
assert embedder.device == "cuda"
assert embedder.device == ComponentDevice.from_str("cuda:0")
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
@@ -40,14 +40,14 @@ def test_init_with_parameters(self):
assert embedder.normalize_embeddings is True

def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model")
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": "cpu",
"device": ComponentDevice.from_str("cpu").to_dict(),
"prefix": "",
"suffix": "",
"batch_size": 32,
@@ -59,7 +59,7 @@ def test_to_dict(self):
def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersTextEmbedder(
model="model",
device="cuda",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
@@ -73,7 +73,7 @@ def test_to_dict_with_custom_init_parameters(self):
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "model",
"device": "cuda",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
@@ -87,11 +87,35 @@ def test_to_dict_not_serialize_token(self):
with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
component.to_dict()

def test_from_dict(self):
data = {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": ComponentDevice.from_str("cpu").to_dict(),
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model == "model"
assert component.device == ComponentDevice.from_str("cpu")
assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert component.prefix == ""
assert component.suffix == ""
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(model="model", token=None)
embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
