feat: Retire openapi3, use openapi-service-client instead #7514

Closed · wants to merge 9 commits
mypy fixes
vblagoje committed May 23, 2024
commit 6712e425afa17502a6669c54974f8d05bd927398
8 changes: 5 additions & 3 deletions haystack/components/connectors/openapi_service.py
@@ -64,7 +64,7 @@ class OpenAPIServiceConnector:
 
     """
 
-    def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provider: Optional[str] = "openai"):
+    def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provider: Optional[str] = None):
         """
         Initializes the OpenAPIServiceConnector instance
 
@@ -77,6 +77,7 @@ def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provid
             "anthropic": AnthropicLLMProvider(),
             "cohere": CohereLLMProvider(),
         }
+        default_provider = default_provider or "openai"
         if default_provider not in self.provider_map:
             raise ValueError(f"Default provider {default_provider} not found in provider map.")
         self.default_provider = default_provider
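
The two hunks above keep None as the declared default and resolve the effective default ("openai") inside the body. A minimal sketch of the idea, using a hypothetical Connector class in place of OpenAPIServiceConnector:

from typing import Any, Dict, Optional

class Connector:
    def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provider: Optional[str] = None):
        # Stand-in map; the real component registers OpenAI/Anthropic/Cohere providers.
        self.provider_map = provider_map or {"openai": object()}
        # Resolving the default in the body rather than in the signature means
        # an explicit default_provider=None behaves the same as omitting the
        # argument, and the declared type stays an honest Optional[str].
        default_provider = default_provider or "openai"
        if default_provider not in self.provider_map:
            raise ValueError(f"Default provider {default_provider} not found in provider map.")
        self.default_provider = default_provider
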
@@ -87,7 +88,7 @@ def run(
         messages: List[ChatMessage],
         service_openapi_spec: Dict[str, Any],
         service_credentials: Optional[Union[dict, str]] = None,
-        llm_provider: Optional[str] = "openai",
+        llm_provider: Optional[str] = None,
     ) -> Dict[str, List[ChatMessage]]:
         """
         Processes a list of chat messages to invoke a method on an OpenAPI service.
@@ -119,7 +120,8 @@
         if not last_message.content:
             raise ValueError("Function calling message content is empty.")
 
-        llm_provider = self.provider_map.get(llm_provider, self.provider_map[self.default_provider])
+        default_provider = self.provider_map.get(self.default_provider, None)
+        llm_provider = self.provider_map.get(llm_provider or "openai", None) or default_provider
         logger.debug(f"Using LLM provider: {llm_provider.__class__.__name__}")
 
         builder = ClientConfigurationBuilder()
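
The rewritten lookup in run() is plausibly the core of this "mypy fixes" commit: dict.get on a Dict[str, Any] expects a str key, so passing the old Optional[str] argument straight through is an argument-type error, while `llm_provider or "openai"` narrows it to str. A standalone sketch of the resolution order, with a toy provider map (names illustrative):

from typing import Dict, Optional

provider_map: Dict[str, str] = {"openai": "OpenAILLMProvider", "anthropic": "AnthropicLLMProvider"}
default_provider = provider_map.get("openai")

def resolve(llm_provider: Optional[str]) -> Optional[str]:
    # `llm_provider or "openai"` narrows Optional[str] to str for the lookup;
    # None and unknown names both fall back to the default provider entry.
    return provider_map.get(llm_provider or "openai", None) or default_provider

assert resolve(None) == "OpenAILLMProvider"
assert resolve("anthropic") == "AnthropicLLMProvider"
assert resolve("unknown") == "OpenAILLMProvider"
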
11 changes: 5 additions & 6 deletions haystack/components/converters/openapi_functions.py
@@ -42,7 +42,7 @@ class OpenAPIServiceToFunctions:
 
     MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
 
-    def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provider: Optional[str] = "openai"):
+    def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provider: Optional[str] = None):
         """
         Create an OpenAPIServiceToFunctions component.
 
@@ -57,12 +57,10 @@ def __init__(self, provider_map: Optional[Dict[str, Any]] = None, default_provid
         }
         if default_provider not in self.provider_map:
             raise ValueError(f"Default provider {default_provider} not found in provider map.")
-        self.default_provider = default_provider
+        self.default_provider = default_provider or "openai"
 
     @component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]])
-    def run(
-        self, sources: List[Union[str, Path, ByteStream]], llm_provider: Optional[str] = "openai"
-    ) -> Dict[str, Any]:
+    def run(self, sources: List[Union[str, Path, ByteStream]], llm_provider: Optional[str] = None) -> Dict[str, Any]:
         """
         Converts OpenAPI definitions into LLM specific function calling format.
 
@@ -83,7 +81,8 @@
         """
         all_extracted_fc_definitions: List[Dict[str, Any]] = []
         all_openapi_specs = []
-        llm_provider = self.provider_map.get(llm_provider, self.provider_map.get(self.default_provider, None))
+        default_provider = self.provider_map.get(self.default_provider, "")
+        llm_provider = self.provider_map.get(llm_provider or "openai", None) or default_provider
         if llm_provider is None:
             raise ValueError(f"LLM provider {llm_provider} not found in provider map.")
         logger.debug(f"Using LLM provider: {llm_provider.__class__.__name__}")
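
Taken together, the new defaults mean callers can omit llm_provider entirely. A hypothetical invocation, assuming the run() signature and output types shown above (the spec filename is made up):

converter = OpenAPIServiceToFunctions()
# llm_provider=None (or leaving it out) resolves to the "openai" provider.
result = converter.run(sources=["service_spec.yaml"], llm_provider=None)
functions = result["functions"]          # LLM-specific function-calling definitions
openapi_specs = result["openapi_specs"]  # one parsed spec per source
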