Skip to content

Commit

Permalink
fix broken serialization of HFAPI components (#7661)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed May 8, 2024
1 parent 9446714 commit 7c9532b
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
api_type=self.api_type,
api_type=str(self.api_type),
api_params=self.api_params,
prefix=self.prefix,
suffix=self.suffix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
api_type=self.api_type,
api_type=str(self.api_type),
api_params=self.api_params,
prefix=self.prefix,
suffix=self.suffix,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
api_type=self.api_type,
api_type=str(self.api_type),
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
api_type=self.api_type,
api_type=str(self.api_type),
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fix the broken serialization of HuggingFaceAPITextEmbedder, HuggingFaceAPIDocumentEmbedder,
HuggingFaceAPIGenerator, and HuggingFaceAPIChatGenerator.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_to_dict(self, mock_check_valid_model):
assert data == {
"type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
"init_parameters": {
"api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
"api_type": "serverless_inference_api",
"api_params": {"model": "BAAI/bge-small-en-v1.5"},
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"prefix": "prefix",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_to_dict(self, mock_check_valid_model):
assert data == {
"type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
"init_parameters": {
"api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
"api_type": "serverless_inference_api",
"api_params": {"model": "BAAI/bge-small-en-v1.5"},
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"prefix": "prefix",
Expand Down
2 changes: 1 addition & 1 deletion test/components/generators/chat/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_to_dict(self, mock_check_valid_model):
result = generator.to_dict()
init_params = result["init_parameters"]

assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert init_params["api_type"] == "serverless_inference_api"
assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
Expand Down
2 changes: 1 addition & 1 deletion test/components/generators/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_to_dict(self, mock_check_valid_model):
result = generator.to_dict()
init_params = result["init_parameters"]

assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert init_params["api_type"] == "serverless_inference_api"
assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {
Expand Down

0 comments on commit 7c9532b

Please sign in to comment.