Skip to content

Commit

Permalink
chore: patch streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 authored and gradenr committed May 22, 2024
1 parent bacc106 commit 9287ee7
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 5 deletions.
211 changes: 208 additions & 3 deletions src/groq/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Dict, List, Union, Iterable, Optional
from typing import Dict, List, Union, Iterable, Optional, overload
from typing_extensions import Literal

import httpx

Expand All @@ -19,10 +20,12 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..._streaming import Stream, AsyncStream
from ...types.chat import completion_create_params
from ..._base_client import (
make_request_options,
)
from ...lib.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.chat_completion import ChatCompletion

__all__ = ["Completions", "AsyncCompletions"]
Expand All @@ -37,6 +40,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse:
def with_streaming_response(self) -> CompletionsWithStreamingResponse:
return CompletionsWithStreamingResponse(self)

@overload
def create(
self,
*,
Expand All @@ -53,7 +57,7 @@ def create(
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[bool] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
Expand All @@ -67,6 +71,104 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Literal[True],
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Stream[ChatCompletionChunk]:
...

@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
...

def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
"""
Creates a model response for the given chat conversation.
Expand Down Expand Up @@ -203,6 +305,8 @@ def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
)


Expand All @@ -215,6 +319,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
return AsyncCompletionsWithStreamingResponse(self)

@overload
async def create(
self,
*,
Expand All @@ -231,7 +336,7 @@ async def create(
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[bool] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
Expand All @@ -245,6 +350,104 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Literal[True],
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[ChatCompletionChunk]:
...

@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
...

async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN,
functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN,
tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
user: Optional[str] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
"""
Creates a model response for the given chat conversation.
Expand Down Expand Up @@ -381,6 +584,8 @@ async def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=AsyncStream[ChatCompletionChunk],
)


Expand Down
4 changes: 2 additions & 2 deletions tests/api_resources/chat/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None:
response_format={"type": "string"},
seed=0,
stop="\n",
stream=True,
stream=False,
temperature=0,
tool_choice="none",
tools=[
Expand Down Expand Up @@ -252,7 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N
response_format={"type": "string"},
seed=0,
stop="\n",
stream=True,
stream=False,
temperature=0,
tool_choice="none",
tools=[
Expand Down

0 comments on commit 9287ee7

Please sign in to comment.