From 7dec9b711c222dec4f4a86d7e6c650525066d2aa Mon Sep 17 00:00:00 2001 From: Graden Rea Date: Fri, 1 Mar 2024 14:03:21 -0800 Subject: [PATCH] chore: Fix streaming before release --- README.md | 4 +- src/groq/_streaming.py | 6 + src/groq/resources/chat/completions.py | 199 ++++++++++++++++++++++++- 3 files changed, 204 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6ef2d7c..76bbd7d 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ chat_completion = client.chat.completions.create( ], model="mixtral-8x7b-32768", ) -print(chat_completion.choices_0.message.content) +print(chat_completion.choices[0].message.content) ``` While you can provide an `api_key` keyword argument, @@ -72,7 +72,7 @@ async def main() -> None: ], model="mixtral-8x7b-32768", ) - print(chat_completion.choices_0.message.content) + print(chat_completion.choices[0].message.content) asyncio.run(main()) diff --git a/src/groq/_streaming.py b/src/groq/_streaming.py index 80b2aa6..c408423 100644 --- a/src/groq/_streaming.py +++ b/src/groq/_streaming.py @@ -58,6 +58,8 @@ def __stream__(self) -> Iterator[_T]: iterator = self._iter_events() for sse in iterator: + if sse.data.startswith("[DONE]"): + break yield process_data(data=sse.json(), cast_to=cast_to, response=response) # Ensure the entire stream is consumed @@ -114,9 +116,13 @@ async def __aiter__(self) -> AsyncIterator[_T]: async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: if isinstance(self._decoder, SSEBytesDecoder): async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + if sse.data.startswith("[DONE]"): + break yield sse else: async for sse in self._decoder.aiter(self.response.aiter_lines()): + if sse.data.startswith("[DONE]"): + break yield sse async def __stream__(self) -> AsyncIterator[_T]: diff --git a/src/groq/resources/chat/completions.py b/src/groq/resources/chat/completions.py index 38c433c..277aa58 100644 --- a/src/groq/resources/chat/completions.py +++ b/src/groq/resources/chat/completions.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Dict, List, Union, Iterable, Optional +from typing import Dict, List, Union, Iterable, Optional, overload +from typing_extensions import Literal import httpx @@ -19,10 +20,12 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._streaming import Stream, AsyncStream from ...types.chat import ChatCompletion, completion_create_params from ..._base_client import ( make_request_options, ) +from ...lib.chat_completion_chunk import ChatCompletionChunk __all__ = ["Completions", "AsyncCompletions"] @@ -36,6 +39,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse: def with_streaming_response(self) -> CompletionsWithStreamingResponse: return CompletionsWithStreamingResponse(self) + @overload def create( self, *, @@ -50,7 +54,7 @@ def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -64,6 +68,98 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[ChatCompletionChunk]: + ... + + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: + ... + + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -108,6 +204,8 @@ def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], ) @@ -120,6 +218,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse: def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse: return AsyncCompletionsWithStreamingResponse(self) + @overload async def create( self, *, @@ -134,7 +233,7 @@ async def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -148,6 +247,98 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[ChatCompletionChunk]: + ... + + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + ... + + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -192,6 +383,8 @@ async def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=AsyncStream[ChatCompletionChunk], )