Skip to content

Commit

Permalink
chore: Fix streaming before release
Browse files Browse the repository at this point in the history
  • Loading branch information
gradenr committed Mar 8, 2024
1 parent a1594a7 commit 7dec9b7
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ chat_completion = client.chat.completions.create(
],
model="mixtral-8x7b-32768",
)
print(chat_completion.choices_0.message.content)
print(chat_completion.choices[0].message.content)
```

While you can provide an `api_key` keyword argument,
Expand Down Expand Up @@ -72,7 +72,7 @@ async def main() -> None:
],
model="mixtral-8x7b-32768",
)
print(chat_completion.choices_0.message.content)
print(chat_completion.choices[0].message.content)


asyncio.run(main())
Expand Down
6 changes: 6 additions & 0 deletions src/groq/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __stream__(self) -> Iterator[_T]:
iterator = self._iter_events()

for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
Expand Down Expand Up @@ -114,9 +116,13 @@ async def __aiter__(self) -> AsyncIterator[_T]:
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
    """Yield decoded server-sent events until a ``[DONE]`` sentinel is seen.

    When the decoder is an ``SSEBytesDecoder`` the response is consumed as
    raw bytes; otherwise it is consumed line by line. An event whose data
    starts with ``"[DONE]"`` ends iteration and is not yielded.
    """
    decoder = self._decoder
    # Choose the event source once so the sentinel check lives in one loop.
    if isinstance(decoder, SSEBytesDecoder):
        events = decoder.aiter_bytes(self.response.aiter_bytes())
    else:
        events = decoder.aiter(self.response.aiter_lines())

    async for event in events:
        if event.data.startswith("[DONE]"):
            return
        yield event

async def __stream__(self) -> AsyncIterator[_T]:
Expand Down
199 changes: 196 additions & 3 deletions src/groq/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Dict, List, Union, Iterable, Optional
from typing import Dict, List, Union, Iterable, Optional, overload
from typing_extensions import Literal

import httpx

Expand All @@ -19,10 +20,12 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..._streaming import Stream, AsyncStream
from ...types.chat import ChatCompletion, completion_create_params
from ..._base_client import (
make_request_options,
)
from ...lib.chat_completion_chunk import ChatCompletionChunk

__all__ = ["Completions", "AsyncCompletions"]

Expand All @@ -36,6 +39,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse:
def with_streaming_response(self) -> CompletionsWithStreamingResponse:
return CompletionsWithStreamingResponse(self)

@overload
def create(
self,
*,
Expand All @@ -50,7 +54,7 @@ def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -64,6 +68,98 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing overload for streaming calls: when `stream` is the literal `True`,
# the declared return type is `Stream[ChatCompletionChunk]`.
# `stream` has no default in this overload, so it must be passed explicitly
# for this signature to be selected.
@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required literal `True`: this is what distinguishes the streaming overload.
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Stream[ChatCompletionChunk]:
...

# Typing overload for calls where `stream` is passed as a plain `bool`
# rather than a literal: the declared return type is the union
# `ChatCompletion | Stream[ChatCompletionChunk]`.
@overload
def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required non-literal `bool`: no default in this overload.
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
...

def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -108,6 +204,8 @@ def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
)


Expand All @@ -120,6 +218,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
return AsyncCompletionsWithStreamingResponse(self)

@overload
async def create(
self,
*,
Expand All @@ -134,7 +233,7 @@ async def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -148,6 +247,98 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing overload for async streaming calls: when `stream` is the literal
# `True`, the declared return type is `AsyncStream[ChatCompletionChunk]`.
# `stream` has no default in this overload, so it must be passed explicitly
# for this signature to be selected.
@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required literal `True`: this is what distinguishes the streaming overload.
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[ChatCompletionChunk]:
...

# Typing overload for async calls where `stream` is passed as a plain `bool`
# rather than a literal: the declared return type is the union
# `ChatCompletion | AsyncStream[ChatCompletionChunk]`.
@overload
async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required non-literal `bool`: no default in this overload.
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
...

async def create(
self,
*,
messages: Iterable[completion_create_params.Message],
model: str,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -192,6 +383,8 @@ async def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=AsyncStream[ChatCompletionChunk],
)


Expand Down

0 comments on commit 7dec9b7

Please sign in to comment.