Skip to content

Commit

Permalink
chore: patch streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 committed May 15, 2024
1 parent 961159c commit c8eaad3
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ chat_completion = client.chat.completions.create(
],
model="mixtral-8x7b-32768",
)
print(chat_completion.choices_0.message.content)
print(chat_completion.choices[0].message.content)
```

While you can provide an `api_key` keyword argument,
Expand Down Expand Up @@ -66,7 +66,7 @@ async def main() -> None:
],
model="mixtral-8x7b-32768",
)
print(chat_completion.choices_0.message.content)
print(chat_completion.choices[0].message.content)


asyncio.run(main())
Expand Down
4 changes: 4 additions & 0 deletions src/groq/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __stream__(self) -> Iterator[_T]:
iterator = self._iter_events()

for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
Expand Down Expand Up @@ -119,6 +121,8 @@ async def __stream__(self) -> AsyncIterator[_T]:
iterator = self._iter_events()

async for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
Expand Down
199 changes: 196 additions & 3 deletions src/groq/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Dict, List, Union, Iterable, Optional
from typing import Dict, List, Union, Iterable, Optional, overload
from typing_extensions import Literal

import httpx

Expand All @@ -19,10 +20,12 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..._streaming import Stream, AsyncStream
from ...types.chat import completion_create_params
from ..._base_client import (
make_request_options,
)
from ...lib.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.chat_completion import ChatCompletion

__all__ = ["CompletionsResource", "AsyncCompletionsResource"]
Expand All @@ -37,6 +40,7 @@ def with_raw_response(self) -> CompletionsResourceWithRawResponse:
def with_streaming_response(self) -> CompletionsResourceWithStreamingResponse:
return CompletionsResourceWithStreamingResponse(self)

@overload
def create(
self,
*,
Expand All @@ -51,7 +55,7 @@ def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -65,6 +69,98 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing-only overload for the streaming case: `stream` is a required
# `Literal[True]` keyword here (no default), and the declared return type is
# `Stream[ChatCompletionChunk]` instead of `ChatCompletion`.
# The `...` body is a stub; the undecorated `create` defined after the
# overloads carries the implementation.
@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Stream[ChatCompletionChunk]:
...

# Typing-only overload for callers that pass `stream` as a plain `bool`
# (required keyword, no default) rather than a literal `True`/`False`.
# Because the value is not fixed by the annotation, the declared return type is
# the union `ChatCompletion | Stream[ChatCompletionChunk]`.
# The `...` body is a stub; the undecorated `create` defined after the
# overloads carries the implementation.
@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
...

def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -109,6 +205,8 @@ def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
)


Expand All @@ -121,6 +219,7 @@ def with_raw_response(self) -> AsyncCompletionsResourceWithRawResponse:
def with_streaming_response(self) -> AsyncCompletionsResourceWithStreamingResponse:
return AsyncCompletionsResourceWithStreamingResponse(self)

@overload
async def create(
self,
*,
Expand All @@ -135,7 +234,7 @@ async def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -149,6 +248,98 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing-only overload for the async streaming case: `stream` is a required
# `Literal[True]` keyword here (no default), and the declared return type is
# `AsyncStream[ChatCompletionChunk]` instead of `ChatCompletion`.
# The `...` body is a stub; the undecorated `async def create` defined after
# the overloads carries the implementation.
@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[ChatCompletionChunk]:
...

# Typing-only overload for async callers that pass `stream` as a plain `bool`
# (required keyword, no default) rather than a literal `True`/`False`.
# Because the value is not fixed by the annotation, the declared return type is
# the union `ChatCompletion | AsyncStream[ChatCompletionChunk]`.
# The `...` body is a stub; the undecorated `async def create` defined after
# the overloads carries the implementation.
@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
...

async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -193,6 +384,8 @@ async def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=AsyncStream[ChatCompletionChunk],
)


Expand Down

0 comments on commit c8eaad3

Please sign in to comment.