Skip to content

Commit

Permalink
add config parameters to OpenAIResource
Browse files Browse the repository at this point in the history
  • Loading branch information
chasleslr committed Jul 26, 2024
1 parent 7a71aa8 commit d013640
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
"""

api_key: str = Field(description=("OpenAI API key. See https://platform.openai.com/api-keys"))
organization: Optional[str] = None
project: Optional[str] = None
base_url: Optional[str] = None

_client: Client = PrivateAttr()

Expand Down Expand Up @@ -212,7 +215,17 @@ def _wrap_with_usage_metadata(

def setup_for_execution(self, context: InitResourceContext) -> None:
# Set up an OpenAI client based on the API key.
self._client = Client(api_key=self.api_key)
kwargs = {}
if self.organization:
kwargs["organization"] = self.organization
if self.project:
kwargs["project"] = self.project
if self.base_url:
kwargs["base_url"] = self.base_url
self._client = Client(
api_key=self.api_key,
**kwargs
)

@public
@contextmanager
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mock.mock import Base
import pytest
from dagster import (
AssetExecutionContext,
Expand All @@ -19,6 +20,7 @@
from dagster._utils.test import wrap_op_in_graph_and_execute
from dagster_openai import OpenAIResource, with_usage_metadata
from mock import ANY, MagicMock, patch
from openai import OpenAI


@patch("dagster_openai.resources.Client")
Expand All @@ -31,6 +33,26 @@ def test_openai_client(mock_client) -> None:
mock_client.assert_called_once_with(api_key="xoxp-1234123412341234-12341234-1234")


@patch("dagster_openai.resources.Client")
def test_openai_client_with_config(mock_client) -> None:
openai_resource = OpenAIResource(
api_key="xoxp-1234123412341234-12341234-1234",
organization="foo",
project="bar",
base_url="https://foo.bar"
)
openai_resource.setup_for_execution(build_init_resource_context())

mock_context = MagicMock()
with openai_resource.get_client(mock_context):
mock_client.assert_called_once_with(
api_key="xoxp-1234123412341234-12341234-1234",
organization="foo",
project="bar",
base_url="https://foo.bar"
)


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
@patch("dagster_openai.resources.Client")
Expand Down

0 comments on commit d013640

Please sign in to comment.