Skip to content

Commit

Permalink
refactor: optimize API keys reading (#6655)
Browse files Browse the repository at this point in the history
* centralize API keys handling

* fix mypy and pylint

* rm utility function, be more explicit
  • Loading branch information
anakin87 committed Jan 5, 2024
1 parent 1336456 commit bb2b1a2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
21 changes: 10 additions & 11 deletions haystack/components/converters/azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Union, Dict, Any, Optional
import os
import logging
import os

from haystack.lazy_imports import LazyImport
from haystack import component, Document, default_to_dict
Expand Down Expand Up @@ -51,16 +51,15 @@ def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str =
"""
azure_import.check()

if api_key is None:
try:
api_key = os.environ["AZURE_AI_API_KEY"]
except KeyError as e:
raise ValueError(
"AzureOCRDocumentConverter expects an Azure Credential key. "
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
) from e

self.api_key = api_key
api_key = api_key or os.environ.get("AZURE_AI_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"AzureOCRDocumentConverter expects an API key. "
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.document_analysis_client = DocumentAnalysisClient(
endpoint=endpoint, credential=AzureKeyCredential(api_key)
)
Expand Down
19 changes: 10 additions & 9 deletions haystack/components/websearch/searchapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import logging
from typing import Dict, List, Optional, Any
import os

import requests

Expand Down Expand Up @@ -42,14 +42,15 @@ def __init__(
For example, you can set 'num' to 100 to increase the number of search results.
See the [SearchApi website](https://www.searchapi.io/) for more details.
"""
if api_key is None:
try:
api_key = os.environ["SEARCHAPI_API_KEY"]
except KeyError as e:
raise ValueError(
"SearchApiWebSearch expects an API key. "
"Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
api_key = api_key or os.environ.get("SEARCHAPI_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SearchApiWebSearch expects an API key. "
"Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
Expand Down
20 changes: 10 additions & 10 deletions haystack/components/websearch/serper_dev.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import logging
from typing import Dict, List, Optional, Any
import os

import requests

Expand Down Expand Up @@ -42,15 +42,15 @@ def __init__(
For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details.
"""
if api_key is None:
try:
api_key = os.environ["SERPERDEV_API_KEY"]
except KeyError as e:
raise ValueError(
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
) from e
raise ValueError("API key for SerperDev API must be set.")
api_key = api_key or os.environ.get("SERPERDEV_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
Expand Down
2 changes: 1 addition & 1 deletion test/components/converters/test_azure_ocr_doc_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class TestAzureOCRDocumentConverter:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("AZURE_AI_API_KEY", raising=False)
with pytest.raises(ValueError, match="AzureOCRDocumentConverter expects an Azure Credential key"):
with pytest.raises(ValueError):
AzureOCRDocumentConverter(endpoint="test_endpoint")

def test_to_dict(self):
Expand Down

0 comments on commit bb2b1a2

Please sign in to comment.