diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2851dae..1e9f15a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: echo "$HOME/.rye/shims" >> $GITHUB_PATH env: RYE_VERSION: 0.24.0 - RYE_INSTALL_OPTION: "--yes" + RYE_INSTALL_OPTION: '--yes' - name: Install dependencies run: | @@ -39,3 +39,24 @@ jobs: - name: Ensure importable run: | rye run python -c 'import groq' + test: + name: test + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install Rye + run: | + curl -sSf https://rye-up.com/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: 0.24.0 + RYE_INSTALL_OPTION: '--yes' + + - name: Bootstrap + run: ./scripts/bootstrap + + - name: Run tests + run: ./scripts/test + diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 962b8c4..2658f61 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Rye run: | diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml index c8b83af..58b5273 100644 --- a/.github/workflows/release-doctor.yml +++ b/.github/workflows/release-doctor.yml @@ -10,7 +10,7 @@ jobs: if: github.repository == 'groq/groq-python' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Check release environment run: | diff --git a/.gitignore b/.gitignore index a4b2f8c..0f9a66a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ dist .env .envrc codegen.log +Brewfile.lock.json diff --git a/.stats.yml b/.stats.yml index 2b7dbf3..8c47545 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1 +1,2 @@ configured_endpoints: 6 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/groqcloud%2Fgroqcloud-0a3089666368ff1ff668f2a73ea3b40d8b20420d8403a18579a1168dd67f2220.yml diff --git a/Brewfile b/Brewfile new file mode 100644 index 0000000..492ca37 --- /dev/null +++ b/Brewfile @@ -0,0 +1,2 @@ +brew "rye" + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0650ff4..143dd07 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -86,7 +86,7 @@ Most tests require you to [set up a mock server](https://github.com/stoplightio/ ```bash # you will need npm installed -npx prism path/to/your/openapi.yml +npx prism mock path/to/your/openapi.yml ``` ```bash @@ -121,5 +121,5 @@ You can release to package managers by using [the `Publish PyPI` GitHub action]( ### Publish manually -If you need to manually release a package, you can run the `bin/publish-pypi` script with an `PYPI_TOKEN` set on +If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set on the environment. diff --git a/README.md b/README.md index 76bbd7d..3e2d442 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ The Groq Python library provides convenient access to the Groq REST API from any application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). +It is generated with [Stainless](https://www.stainlessapi.com/). + ## Documentation The REST API documentation can be found [on console.groq.com](https://console.groq.com/docs). The full API of this library can be found in [api.md](api.md). @@ -22,13 +24,9 @@ pip install groq The full API of this library can be found in [api.md](api.md). ```python -import os from groq import Groq -client = Groq( - # This is the default and can be omitted - api_key=os.environ.get("GROQ_API_KEY"), -) +client = Groq() chat_completion = client.chat.completions.create( messages=[ @@ -39,7 +37,7 @@ chat_completion = client.chat.completions.create( ], model="mixtral-8x7b-32768", ) -print(chat_completion.choices[0].message.content) +print(chat_completion.choices_0.message.content) ``` While you can provide an `api_key` keyword argument, @@ -52,14 +50,10 @@ so that your API Key is not stored in source control. Simply import `AsyncGroq` instead of `Groq` and use `await` with each API call: ```python -import os import asyncio from groq import AsyncGroq -client = AsyncGroq( - # This is the default and can be omitted - api_key=os.environ.get("GROQ_API_KEY"), -) +client = AsyncGroq() async def main() -> None: @@ -72,7 +66,7 @@ async def main() -> None: ], model="mixtral-8x7b-32768", ) - print(chat_completion.choices[0].message.content) + print(chat_completion.choices_0.message.content) asyncio.run(main()) @@ -82,10 +76,10 @@ Functionality between the synchronous and asynchronous clients is otherwise iden ## Using types -Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev), which provide helper methods for things like: +Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like: -- Serializing back into JSON, `model.model_dump_json(indent=2, exclude_unset=True)` -- Converting to a dictionary, `model.model_dump(exclude_unset=True)` +- Serializing back into JSON, `model.to_json()` +- Converting to a dictionary, `model.to_dict()` Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`. @@ -195,7 +189,7 @@ client = Groq( ) # Override per-request: -client.with_options(timeout=5 * 1000).chat.completions.create( +client.with_options(timeout=5.0).chat.completions.create( messages=[ { "role": "system", @@ -294,6 +288,41 @@ with client.chat.completions.with_streaming_response.create( The context manager is required so that the response will reliably be closed. +### Making custom/undocumented requests + +This library is typed for convenient access to the documented API. + +If you need to access undocumented endpoints, params, or response properties, the library can still be used. + +#### Undocumented endpoints + +To make requests to undocumented endpoints, you can make requests using `client.get`, `client.post`, and other +http verbs. Options on the client will be respected (such as retries) will be respected when making this +request. + +```py +import httpx + +response = client.post( + "/foo", + cast_to=httpx.Response, + body={"my_param": True}, +) + +print(response.headers.get("x-foo")) +``` + +#### Undocumented request params + +If you want to explicitly send an extra param, you can do so with the `extra_query`, `extra_body`, and `extra_headers` request +options. + +#### Undocumented response properties + +To access undocumented response properties, you can access the extra fields like `response.unknown_prop`. You +can also get all the extra fields on the Pydantic model as a dict with +[`response.model_extra`](https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_extra). + ### Configuring the HTTP client You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including: @@ -303,13 +332,12 @@ You can directly override the [httpx client](https://www.python-httpx.org/api/#c - Additional [advanced](https://www.python-httpx.org/advanced/#client-instances) functionality ```python -import httpx -from groq import Groq +from groq import Groq, DefaultHttpxClient client = Groq( # Or use the `GROQ_BASE_URL` env var base_url="http://my.test.server.example.com:8083", - http_client=httpx.Client( + http_client=DefaultHttpxClient( proxies="http://my.test.proxy.example.com", transport=httpx.HTTPTransport(local_address="0.0.0.0"), ), diff --git a/bin/check-env-state.py b/bin/check-env-state.py deleted file mode 100644 index e1b8b6c..0000000 --- a/bin/check-env-state.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Script that exits 1 if the current environment is not -in sync with the `requirements-dev.lock` file. -""" - -from pathlib import Path - -import importlib_metadata - - -def should_run_sync() -> bool: - dev_lock = Path(__file__).parent.parent.joinpath("requirements-dev.lock") - - for line in dev_lock.read_text().splitlines(): - if not line or line.startswith("#") or line.startswith("-e"): - continue - - dep, lock_version = line.split("==") - - try: - version = importlib_metadata.version(dep) - - if lock_version != version: - print(f"mismatch for {dep} current={version} lock={lock_version}") - return True - except Exception: - print(f"could not import {dep}") - return True - - return False - - -def main() -> None: - if should_run_sync(): - exit(1) - else: - exit(0) - - -if __name__ == "__main__": - main() diff --git a/bin/check-test-server b/bin/check-test-server deleted file mode 100755 index a6fa349..0000000 --- a/bin/check-test-server +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env bash - -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[0;33m' -NC='\033[0m' # No Color - -function prism_is_running() { - curl --silent "http://localhost:4010" >/dev/null 2>&1 -} - -function is_overriding_api_base_url() { - [ -n "$TEST_API_BASE_URL" ] -} - -if is_overriding_api_base_url ; then - # If someone is running the tests against the live API, we can trust they know - # what they're doing and exit early. - echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" - - exit 0 -elif prism_is_running ; then - echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" - echo - - exit 0 -else - echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" - echo -e "running against your OpenAPI spec." - echo - echo -e "${YELLOW}To fix:${NC}" - echo - echo -e "1. Install Prism (requires Node 16+):" - echo - echo -e " With npm:" - echo -e " \$ ${YELLOW}npm install -g @stoplight/prism-cli${NC}" - echo - echo -e " With yarn:" - echo -e " \$ ${YELLOW}yarn global add @stoplight/prism-cli${NC}" - echo - echo -e "2. Run the mock server" - echo - echo -e " To run the server, pass in the path of your OpenAPI" - echo -e " spec to the prism command:" - echo - echo -e " \$ ${YELLOW}prism mock path/to/your.openapi.yml${NC}" - echo - - exit 1 -fi diff --git a/bin/test b/bin/test deleted file mode 100755 index 60ede7a..0000000 --- a/bin/test +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash - -bin/check-test-server && rye run pytest "$@" diff --git a/pyproject.toml b/pyproject.toml index 6d07531..26bb110 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "groq" version = "0.5.0" description = "The official Python library for the groq API" -readme = "README.md" +dynamic = ["readme"] license = "Apache-2.0" authors = [ { name = "Groq", email = "support@groq.com" }, @@ -48,7 +48,7 @@ Repository = "https://github.com/groq/groq-python" managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "pyright", + "pyright>=1.1.359", "mypy", "respx", "pytest", @@ -68,7 +68,7 @@ format = { chain = [ "fix:ruff", ]} "format:black" = "black ." -"format:docs" = "python bin/ruffen-docs.py README.md api.md" +"format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md" "format:ruff" = "ruff format" "format:isort" = "isort ." @@ -88,7 +88,7 @@ typecheck = { chain = [ "typecheck:mypy" = "mypy ." [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-fancy-pypi-readme"] build-backend = "hatchling.build" [tool.hatch.build] @@ -99,6 +99,17 @@ include = [ [tool.hatch.build.targets.wheel] packages = ["src/groq"] +[tool.hatch.metadata.hooks.fancy-pypi-readme] +content-type = "text/markdown" + +[[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]] +path = "README.md" + +[[tool.hatch.metadata.hooks.fancy-pypi-readme.substitutions]] +# replace relative links with absolute links +pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' +replacement = '[\1](https://github.com/groq/groq-python/tree/main/\g<2>)' + [tool.black] line-length = 120 target-version = ["py37"] @@ -130,6 +141,7 @@ reportImplicitOverride = true reportImportCycles = false reportPrivateUsage = false + [tool.ruff] line-length = 120 output-format = "grouped" @@ -149,7 +161,9 @@ select = [ "T201", "T203", # misuse of typing.TYPE_CHECKING - "TCH004" + "TCH004", + # import rules + "TID251", ] ignore = [ # mutable defaults @@ -165,6 +179,9 @@ ignore-init-module-imports = true [tool.ruff.format] docstring-code-format = true +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" + [tool.ruff.lint.isort] length-sort = true length-sort-straight = true @@ -174,5 +191,6 @@ known-first-party = ["groq", "tests"] [tool.ruff.per-file-ignores] "bin/**.py" = ["T201", "T203"] +"scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] "examples/**.py" = ["T201", "T203"] diff --git a/requirements-dev.lock b/requirements-dev.lock index b67c731..bcf3065 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -63,7 +63,7 @@ pydantic==2.4.2 # via groq pydantic-core==2.10.1 # via pydantic -pyright==1.1.351 +pyright==1.1.359 pytest==7.1.1 # via pytest-asyncio pytest-asyncio==0.21.1 diff --git a/scripts/bootstrap b/scripts/bootstrap new file mode 100755 index 0000000..29df07e --- /dev/null +++ b/scripts/bootstrap @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ]; then + brew bundle check >/dev/null 2>&1 || { + echo "==> Installing Homebrew dependencies…" + brew bundle + } +fi + +echo "==> Installing Python dependencies…" + +# experimental uv support makes installations significantly faster +rye config --set-bool behavior.use-uv=true + +rye sync diff --git a/scripts/format b/scripts/format new file mode 100755 index 0000000..2a9ea46 --- /dev/null +++ b/scripts/format @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +rye run format + diff --git a/scripts/lint b/scripts/lint new file mode 100755 index 0000000..0cc68b5 --- /dev/null +++ b/scripts/lint @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +rye run lint + diff --git a/scripts/mock b/scripts/mock new file mode 100755 index 0000000..fe89a1d --- /dev/null +++ b/scripts/mock @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [[ -n "$1" && "$1" != '--'* ]]; then + URL="$1" + shift +else + URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)" +fi + +# Check if the URL is empty +if [ -z "$URL" ]; then + echo "Error: No OpenAPI spec path/url provided or found in .stats.yml" + exit 1 +fi + +echo "==> Starting mock server with URL ${URL}" + +# Run prism mock on the given spec +if [ "$1" == "--daemon" ]; then + npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" &> .prism.log & + + # Wait for server to come online + echo -n "Waiting for server" + while ! grep -q "✖ fatal\|Prism is listening" ".prism.log" ; do + echo -n "." + sleep 0.1 + done + + if grep -q "✖ fatal" ".prism.log"; then + cat .prism.log + exit 1 + fi + + echo +else + npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" +fi diff --git a/scripts/test b/scripts/test new file mode 100755 index 0000000..be01d04 --- /dev/null +++ b/scripts/test @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +NC='\033[0m' # No Color + +function prism_is_running() { + curl --silent "http://localhost:4010" >/dev/null 2>&1 +} + +kill_server_on_port() { + pids=$(lsof -t -i tcp:"$1" || echo "") + if [ "$pids" != "" ]; then + kill "$pids" + echo "Stopped $pids." + fi +} + +function is_overriding_api_base_url() { + [ -n "$TEST_API_BASE_URL" ] +} + +if ! is_overriding_api_base_url && ! prism_is_running ; then + # When we exit this script, make sure to kill the background mock server process + trap 'kill_server_on_port 4010' EXIT + + # Start the dev server + ./scripts/mock --daemon +fi + +if is_overriding_api_base_url ; then + echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" + echo +elif ! prism_is_running ; then + echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" + echo -e "running against your OpenAPI spec." + echo + echo -e "To run the server, pass in the path or url of your OpenAPI" + echo -e "spec to the prism command:" + echo + echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" + echo + + exit 1 +else + echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" + echo +fi + +# Run tests +echo "==> Running tests" +rye run pytest "$@" diff --git a/bin/ruffen-docs.py b/scripts/utils/ruffen-docs.py similarity index 100% rename from bin/ruffen-docs.py rename to scripts/utils/ruffen-docs.py diff --git a/src/groq/__init__.py b/src/groq/__init__.py index 105e1f4..25f858e 100644 --- a/src/groq/__init__.py +++ b/src/groq/__init__.py @@ -1,12 +1,13 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from . import types -from ._types import NoneType, Transport, ProxiesTypes +from ._types import NOT_GIVEN, NoneType, NotGiven, Transport, ProxiesTypes from ._utils import file_from_path from ._client import Groq, Client, Stream, Timeout, AsyncGroq, Transport, AsyncClient, AsyncStream, RequestOptions from ._models import BaseModel from ._version import __title__, __version__ from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse +from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS from ._exceptions import ( APIError, GroqError, @@ -23,6 +24,7 @@ UnprocessableEntityError, APIResponseValidationError, ) +from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging __all__ = [ @@ -32,6 +34,8 @@ "NoneType", "Transport", "ProxiesTypes", + "NotGiven", + "NOT_GIVEN", "GroqError", "APIError", "APIStatusError", @@ -56,6 +60,11 @@ "AsyncGroq", "file_from_path", "BaseModel", + "DEFAULT_TIMEOUT", + "DEFAULT_MAX_RETRIES", + "DEFAULT_CONNECTION_LIMITS", + "DefaultHttpxClient", + "DefaultAsyncHttpxClient", ] _setup_logging() diff --git a/src/groq/_base_client.py b/src/groq/_base_client.py index 2b3a1f9..e97817e 100644 --- a/src/groq/_base_client.py +++ b/src/groq/_base_client.py @@ -29,7 +29,6 @@ cast, overload, ) -from functools import lru_cache from typing_extensions import Literal, override, get_origin import anyio @@ -61,7 +60,7 @@ RequestOptions, ModelBuilderProtocol, ) -from ._utils import is_dict, is_list, is_given, is_mapping +from ._utils import is_dict, is_list, is_given, lru_cache, is_mapping from ._compat import model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( @@ -71,13 +70,13 @@ extract_response_type, ) from ._constants import ( - DEFAULT_LIMITS, DEFAULT_TIMEOUT, MAX_RETRY_DELAY, DEFAULT_MAX_RETRIES, INITIAL_RETRY_DELAY, RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER, + DEFAULT_CONNECTION_LIMITS, ) from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder from ._exceptions import ( @@ -360,6 +359,11 @@ def __init__( self._strict_response_validation = _strict_response_validation self._idempotency_header = None + if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( + "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `groq.DEFAULT_MAX_RETRIES`" + ) + def _enforce_trailing_slash(self, url: URL) -> URL: if url.raw_path.endswith(b"/"): return url @@ -710,7 +714,27 @@ def _idempotency_key(self) -> str: return f"stainless-python-retry-{uuid.uuid4()}" -class SyncHttpxClientWrapper(httpx.Client): +class _DefaultHttpxClient(httpx.Client): + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) + + +if TYPE_CHECKING: + DefaultHttpxClient = httpx.Client + """An alias to `httpx.Client` that provides the same defaults that this SDK + uses internally. + + This is useful because overriding the `http_client` with your own instance of + `httpx.Client` will result in httpx's defaults being used, not ours. + """ +else: + DefaultHttpxClient = _DefaultHttpxClient + + +class SyncHttpxClientWrapper(DefaultHttpxClient): def __del__(self) -> None: try: self.close() @@ -746,7 +770,7 @@ def __init__( if http_client is not None: raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") else: - limits = DEFAULT_LIMITS + limits = DEFAULT_CONNECTION_LIMITS if transport is not None: warnings.warn( @@ -921,6 +945,8 @@ def _request( if self.custom_auth is not None: kwargs["auth"] = self.custom_auth + log.debug("Sending HTTP Request: %s %s", request.method, request.url) + try: response = self._client.send( request, @@ -959,7 +985,12 @@ def _request( raise APIConnectionError(request=request) from err log.debug( - 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, ) try: @@ -1243,7 +1274,27 @@ def get_api_list( return self._request_api_list(model, page, opts) -class AsyncHttpxClientWrapper(httpx.AsyncClient): +class _DefaultAsyncHttpxClient(httpx.AsyncClient): + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) + + +if TYPE_CHECKING: + DefaultAsyncHttpxClient = httpx.AsyncClient + """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK + uses internally. + + This is useful because overriding the `http_client` with your own instance of + `httpx.AsyncClient` will result in httpx's defaults being used, not ours. + """ +else: + DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient + + +class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient): def __del__(self) -> None: try: # TODO(someday): support non asyncio runtimes here @@ -1280,7 +1331,7 @@ def __init__( if http_client is not None: raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") else: - limits = DEFAULT_LIMITS + limits = DEFAULT_CONNECTION_LIMITS if transport is not None: warnings.warn( diff --git a/src/groq/_client.py b/src/groq/_client.py index 4ba4e27..8567162 100644 --- a/src/groq/_client.py +++ b/src/groq/_client.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations @@ -46,9 +46,9 @@ class Groq(SyncAPIClient): - chat: resources.Chat - audio: resources.Audio - models: resources.Models + chat: resources.ChatResource + audio: resources.AudioResource + models: resources.ModelsResource with_raw_response: GroqWithRawResponse with_streaming_response: GroqWithStreamedResponse @@ -64,7 +64,9 @@ def __init__( max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, - # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + # Configure a custom httpx client. + # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. + # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised @@ -104,9 +106,9 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.chat = resources.Chat(self) - self.audio = resources.Audio(self) - self.models = resources.Models(self) + self.chat = resources.ChatResource(self) + self.audio = resources.AudioResource(self) + self.models = resources.ModelsResource(self) self.with_raw_response = GroqWithRawResponse(self) self.with_streaming_response = GroqWithStreamedResponse(self) @@ -216,9 +218,9 @@ def _make_status_error( class AsyncGroq(AsyncAPIClient): - chat: resources.AsyncChat - audio: resources.AsyncAudio - models: resources.AsyncModels + chat: resources.AsyncChatResource + audio: resources.AsyncAudioResource + models: resources.AsyncModelsResource with_raw_response: AsyncGroqWithRawResponse with_streaming_response: AsyncGroqWithStreamedResponse @@ -234,7 +236,9 @@ def __init__( max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, - # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. + # Configure a custom httpx client. + # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. + # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised @@ -274,9 +278,9 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.chat = resources.AsyncChat(self) - self.audio = resources.AsyncAudio(self) - self.models = resources.AsyncModels(self) + self.chat = resources.AsyncChatResource(self) + self.audio = resources.AsyncAudioResource(self) + self.models = resources.AsyncModelsResource(self) self.with_raw_response = AsyncGroqWithRawResponse(self) self.with_streaming_response = AsyncGroqWithStreamedResponse(self) @@ -387,30 +391,30 @@ def _make_status_error( class GroqWithRawResponse: def __init__(self, client: Groq) -> None: - self.chat = resources.ChatWithRawResponse(client.chat) - self.audio = resources.AudioWithRawResponse(client.audio) - self.models = resources.ModelsWithRawResponse(client.models) + self.chat = resources.ChatResourceWithRawResponse(client.chat) + self.audio = resources.AudioResourceWithRawResponse(client.audio) + self.models = resources.ModelsResourceWithRawResponse(client.models) class AsyncGroqWithRawResponse: def __init__(self, client: AsyncGroq) -> None: - self.chat = resources.AsyncChatWithRawResponse(client.chat) - self.audio = resources.AsyncAudioWithRawResponse(client.audio) - self.models = resources.AsyncModelsWithRawResponse(client.models) + self.chat = resources.AsyncChatResourceWithRawResponse(client.chat) + self.audio = resources.AsyncAudioResourceWithRawResponse(client.audio) + self.models = resources.AsyncModelsResourceWithRawResponse(client.models) class GroqWithStreamedResponse: def __init__(self, client: Groq) -> None: - self.chat = resources.ChatWithStreamingResponse(client.chat) - self.audio = resources.AudioWithStreamingResponse(client.audio) - self.models = resources.ModelsWithStreamingResponse(client.models) + self.chat = resources.ChatResourceWithStreamingResponse(client.chat) + self.audio = resources.AudioResourceWithStreamingResponse(client.audio) + self.models = resources.ModelsResourceWithStreamingResponse(client.models) class AsyncGroqWithStreamedResponse: def __init__(self, client: AsyncGroq) -> None: - self.chat = resources.AsyncChatWithStreamingResponse(client.chat) - self.audio = resources.AsyncAudioWithStreamingResponse(client.audio) - self.models = resources.AsyncModelsWithStreamingResponse(client.models) + self.chat = resources.AsyncChatResourceWithStreamingResponse(client.chat) + self.audio = resources.AsyncAudioResourceWithStreamingResponse(client.audio) + self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models) Client = Groq diff --git a/src/groq/_constants.py b/src/groq/_constants.py index bf15141..a2ac3b6 100644 --- a/src/groq/_constants.py +++ b/src/groq/_constants.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. import httpx @@ -8,7 +8,7 @@ # default timeout is 1 minute DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0) DEFAULT_MAX_RETRIES = 2 -DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) INITIAL_RETRY_DELAY = 0.5 MAX_RETRY_DELAY = 8.0 diff --git a/src/groq/_exceptions.py b/src/groq/_exceptions.py index ccc539a..ca69070 100644 --- a/src/groq/_exceptions.py +++ b/src/groq/_exceptions.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/_models.py b/src/groq/_models.py index 8108914..ff3f54e 100644 --- a/src/groq/_models.py +++ b/src/groq/_models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import inspect from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast from datetime import date, datetime @@ -10,6 +11,7 @@ Protocol, Required, TypedDict, + TypeGuard, final, override, runtime_checkable, @@ -30,7 +32,20 @@ AnyMapping, HttpxRequestFiles, ) -from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given +from ._utils import ( + PropertyInfo, + is_list, + is_given, + lru_cache, + is_mapping, + parse_date, + coerce_boolean, + parse_datetime, + strip_not_given, + extract_type_arg, + is_annotated_type, + strip_annotated_type, +) from ._compat import ( PYDANTIC_V2, ConfigDict, @@ -46,6 +61,9 @@ ) from ._constants import RAW_RESPONSE_HEADER +if TYPE_CHECKING: + from pydantic_core.core_schema import ModelField, ModelFieldsSchema + __all__ = ["BaseModel", "GenericModel"] _T = TypeVar("_T") @@ -58,7 +76,9 @@ class _ConfigProtocol(Protocol): class BaseModel(pydantic.BaseModel): if PYDANTIC_V2: - model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) else: @property @@ -70,6 +90,79 @@ def model_fields_set(self) -> set[str]: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore + def to_dict( + self, + *, + mode: Literal["json", "python"] = "python", + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> dict[str, object]: + """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + mode: + If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`. + If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)` + + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. + """ + return self.model_dump( + mode=mode, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + def to_json( + self, + *, + indent: int | None = 2, + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> str: + """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. + """ + return self.model_dump_json( + indent=indent, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + @override def __str__(self) -> str: # mypy complains about an invalid self arg @@ -259,7 +352,6 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: def is_basemodel(type_: type) -> bool: """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" - origin = get_origin(type_) or type_ if is_union(type_): for variant in get_args(type_): if is_basemodel(variant): @@ -267,14 +359,29 @@ def is_basemodel(type_: type) -> bool: return False + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) -def construct_type(*, value: object, type_: type) -> object: +def construct_type(*, value: object, type_: object) -> object: """Loose coercion to the expected type with construction of nested values. If the given value does not match the expected type then it is returned as-is. """ + # we allow `object` as the input type because otherwise, passing things like + # `Literal['value']` will be reported as a type error by type checkers + type_ = cast("type[object]", type_) + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + meta: tuple[Any, ...] = get_args(type_)[1:] + type_ = extract_type_arg(type_, 0) + else: + meta = tuple() # we need to use the origin class for any types that are subscripted generics # e.g. Dict[str, object] @@ -287,6 +394,28 @@ def construct_type(*, value: object, type_: type) -> object: except Exception: pass + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. + discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + # if the data is not valid, use the first variant that doesn't fail while deserializing for variant in args: try: @@ -344,6 +473,129 @@ def construct_type(*, value: object, type_: type) -> object: return value +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +class DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + if isinstance(union, CachedDiscriminatorType): + return union.__discriminator__ + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V2: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in field_schema["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + else: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if field_info.annotation and is_literal_type(field_info.annotation): + for entry in get_args(field_info.annotation): + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + cast(CachedDiscriminatorType, union).__discriminator__ = details + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] != "model": + return None + + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + def validate_type(*, type_: type[_T], value: object) -> _T: """Strict validation that the given value matches the expected type""" if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): @@ -363,7 +615,14 @@ class GenericModel(BaseGenericModel, BaseModel): if PYDANTIC_V2: - from pydantic import TypeAdapter + from pydantic import TypeAdapter as _TypeAdapter + + _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) + + if TYPE_CHECKING: + from pydantic import TypeAdapter + else: + TypeAdapter = _CachedTypeAdapter def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: return TypeAdapter(type_).validate_python(value) diff --git a/src/groq/_resource.py b/src/groq/_resource.py index d6fc089..fc39158 100644 --- a/src/groq/_resource.py +++ b/src/groq/_resource.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/_response.py b/src/groq/_response.py index f7620f9..c4bc069 100644 --- a/src/groq/_response.py +++ b/src/groq/_response.py @@ -25,7 +25,7 @@ import pydantic from ._types import NoneType -from ._utils import is_given, extract_type_var_from_base +from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type @@ -121,6 +121,10 @@ def __repr__(self) -> str: ) def _parse(self, *, to: type[_T] | None = None) -> R | _T: + # unwrap `Annotated[T, ...]` -> `T` + if to and is_annotated_type(to): + to = extract_type_arg(to, 0) + if self._is_sse_stream: if to: if not is_stream_class_type(to): @@ -162,6 +166,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: ) cast_to = to if to is not None else self._cast_to + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(cast_to): + cast_to = extract_type_arg(cast_to, 0) + if cast_to is NoneType: return cast(R, None) @@ -630,7 +639,7 @@ def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseCo @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" kwargs["extra_headers"] = extra_headers @@ -651,7 +660,7 @@ def async_to_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" kwargs["extra_headers"] = extra_headers @@ -675,7 +684,7 @@ def to_custom_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -700,7 +709,7 @@ def async_to_custom_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -720,7 +729,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]] @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers @@ -737,7 +746,7 @@ def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P @functools.wraps(func) async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers @@ -759,7 +768,7 @@ def to_custom_raw_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -782,7 +791,7 @@ def async_to_custom_raw_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls diff --git a/src/groq/_streaming.py b/src/groq/_streaming.py index c408423..70e074f 100644 --- a/src/groq/_streaming.py +++ b/src/groq/_streaming.py @@ -23,7 +23,7 @@ class Stream(Generic[_T]): response: httpx.Response - _decoder: SSEDecoder | SSEBytesDecoder + _decoder: SSEBytesDecoder def __init__( self, @@ -46,10 +46,7 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - if isinstance(self._decoder, SSEBytesDecoder): - yield from self._decoder.iter_bytes(self.response.iter_bytes()) - else: - yield from self._decoder.iter(self.response.iter_lines()) + yield from self._decoder.iter_bytes(self.response.iter_bytes()) def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -58,8 +55,6 @@ def __stream__(self) -> Iterator[_T]: iterator = self._iter_events() for sse in iterator: - if sse.data.startswith("[DONE]"): - break yield process_data(data=sse.json(), cast_to=cast_to, response=response) # Ensure the entire stream is consumed @@ -114,16 +109,8 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - if isinstance(self._decoder, SSEBytesDecoder): - async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): - if sse.data.startswith("[DONE]"): - break - yield sse - else: - async for sse in self._decoder.aiter(self.response.aiter_lines()): - if sse.data.startswith("[DONE]"): - break - yield sse + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) @@ -211,21 +198,49 @@ def __init__(self) -> None: self._last_event_id = None self._retry = None - def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]: - """Given an iterator that yields lines, iterate over it & yield every event encountered""" - for line in iterator: - line = line.rstrip("\n") - sse = self.decode(line) - if sse is not None: - yield sse - - async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]: - """Given an async iterator that yields lines, iterate over it & yield every event encountered""" - async for line in iterator: - line = line.rstrip("\n") - sse = self.decode(line) - if sse is not None: - yield sse + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + for chunk in self._iter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + async for chunk in self._aiter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + async for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data def decode(self, line: str) -> ServerSentEvent | None: # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 diff --git a/src/groq/_utils/__init__.py b/src/groq/_utils/__init__.py index 5697894..31b5b22 100644 --- a/src/groq/_utils/__init__.py +++ b/src/groq/_utils/__init__.py @@ -6,6 +6,7 @@ is_list as is_list, is_given as is_given, is_tuple as is_tuple, + lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, parse_date as parse_date, diff --git a/src/groq/_utils/_proxy.py b/src/groq/_utils/_proxy.py index b9c12dc..c46a62a 100644 --- a/src/groq/_utils/_proxy.py +++ b/src/groq/_utils/_proxy.py @@ -10,7 +10,7 @@ class LazyProxy(Generic[T], ABC): """Implements data methods to pretend that an instance is another instance. - This includes forwarding attribute access and othe methods. + This includes forwarding attribute access and other methods. """ # Note: we have to special case proxies that themselves return proxies diff --git a/src/groq/_utils/_transform.py b/src/groq/_utils/_transform.py index 1bd1330..47e262a 100644 --- a/src/groq/_utils/_transform.py +++ b/src/groq/_utils/_transform.py @@ -51,6 +51,7 @@ class MyParams(TypedDict): alias: str | None format: PropertyFormat | None format_template: str | None + discriminator: str | None def __init__( self, @@ -58,14 +59,16 @@ def __init__( alias: str | None = None, format: PropertyFormat | None = None, format_template: str | None = None, + discriminator: str | None = None, ) -> None: self.alias = alias self.format = format self.format_template = format_template + self.discriminator = discriminator @override def __repr__(self) -> str: - return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')" + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" def maybe_transform( diff --git a/src/groq/_utils/_utils.py b/src/groq/_utils/_utils.py index 93c9551..17904ce 100644 --- a/src/groq/_utils/_utils.py +++ b/src/groq/_utils/_utils.py @@ -265,6 +265,8 @@ def wrapper(*args: object, **kwargs: object) -> object: ) msg = f"Missing required arguments; Expected either {variations} arguments to be given" else: + assert len(variants) > 0 + # TODO: this error message is not deterministic missing = list(set(variants[0]) - given_params) if len(missing) > 1: @@ -389,3 +391,13 @@ def get_async_library() -> str: return sniffio.current_async_library() except Exception: return "false" + + +def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: + """A version of functools.lru_cache that retains the type signature + for the wrapped function arguments. + """ + wrapper = functools.lru_cache( # noqa: TID251 + maxsize=maxsize, + ) + return cast(Any, wrapper) # type: ignore[no-any-return] diff --git a/src/groq/_version.py b/src/groq/_version.py index 051c4d5..8343644 100644 --- a/src/groq/_version.py +++ b/src/groq/_version.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "groq" __version__ = "0.5.0" # x-release-please-version diff --git a/src/groq/resources/__init__.py b/src/groq/resources/__init__.py index 56b9aed..66ea74d 100644 --- a/src/groq/resources/__init__.py +++ b/src/groq/resources/__init__.py @@ -1,47 +1,47 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from .chat import ( - Chat, - AsyncChat, - ChatWithRawResponse, - AsyncChatWithRawResponse, - ChatWithStreamingResponse, - AsyncChatWithStreamingResponse, + ChatResource, + AsyncChatResource, + ChatResourceWithRawResponse, + AsyncChatResourceWithRawResponse, + ChatResourceWithStreamingResponse, + AsyncChatResourceWithStreamingResponse, ) from .audio import ( - Audio, - AsyncAudio, - AudioWithRawResponse, - AsyncAudioWithRawResponse, - AudioWithStreamingResponse, - AsyncAudioWithStreamingResponse, + AudioResource, + AsyncAudioResource, + AudioResourceWithRawResponse, + AsyncAudioResourceWithRawResponse, + AudioResourceWithStreamingResponse, + AsyncAudioResourceWithStreamingResponse, ) from .models import ( - Models, - AsyncModels, - ModelsWithRawResponse, - AsyncModelsWithRawResponse, - ModelsWithStreamingResponse, - AsyncModelsWithStreamingResponse, + ModelsResource, + AsyncModelsResource, + ModelsResourceWithRawResponse, + AsyncModelsResourceWithRawResponse, + ModelsResourceWithStreamingResponse, + AsyncModelsResourceWithStreamingResponse, ) __all__ = [ - "Chat", - "AsyncChat", - "ChatWithRawResponse", - "AsyncChatWithRawResponse", - "ChatWithStreamingResponse", - "AsyncChatWithStreamingResponse", - "Audio", - "AsyncAudio", - "AudioWithRawResponse", - "AsyncAudioWithRawResponse", - "AudioWithStreamingResponse", - "AsyncAudioWithStreamingResponse", - "Models", - "AsyncModels", - "ModelsWithRawResponse", - "AsyncModelsWithRawResponse", - "ModelsWithStreamingResponse", - "AsyncModelsWithStreamingResponse", + "ChatResource", + "AsyncChatResource", + "ChatResourceWithRawResponse", + "AsyncChatResourceWithRawResponse", + "ChatResourceWithStreamingResponse", + "AsyncChatResourceWithStreamingResponse", + "AudioResource", + "AsyncAudioResource", + "AudioResourceWithRawResponse", + "AsyncAudioResourceWithRawResponse", + "AudioResourceWithStreamingResponse", + "AsyncAudioResourceWithStreamingResponse", + "ModelsResource", + "AsyncModelsResource", + "ModelsResourceWithRawResponse", + "AsyncModelsResourceWithRawResponse", + "ModelsResourceWithStreamingResponse", + "AsyncModelsResourceWithStreamingResponse", ] diff --git a/src/groq/resources/audio/__init__.py b/src/groq/resources/audio/__init__.py index 78a82de..934f87d 100644 --- a/src/groq/resources/audio/__init__.py +++ b/src/groq/resources/audio/__init__.py @@ -1,47 +1,47 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from .audio import ( - Audio, - AsyncAudio, - AudioWithRawResponse, - AsyncAudioWithRawResponse, - AudioWithStreamingResponse, - AsyncAudioWithStreamingResponse, + AudioResource, + AsyncAudioResource, + AudioResourceWithRawResponse, + AsyncAudioResourceWithRawResponse, + AudioResourceWithStreamingResponse, + AsyncAudioResourceWithStreamingResponse, ) from .translations import ( - Translations, - AsyncTranslations, - TranslationsWithRawResponse, - AsyncTranslationsWithRawResponse, - TranslationsWithStreamingResponse, - AsyncTranslationsWithStreamingResponse, + TranslationsResource, + AsyncTranslationsResource, + TranslationsResourceWithRawResponse, + AsyncTranslationsResourceWithRawResponse, + TranslationsResourceWithStreamingResponse, + AsyncTranslationsResourceWithStreamingResponse, ) from .transcriptions import ( - Transcriptions, - AsyncTranscriptions, - TranscriptionsWithRawResponse, - AsyncTranscriptionsWithRawResponse, - TranscriptionsWithStreamingResponse, - AsyncTranscriptionsWithStreamingResponse, + TranscriptionsResource, + AsyncTranscriptionsResource, + TranscriptionsResourceWithRawResponse, + AsyncTranscriptionsResourceWithRawResponse, + TranscriptionsResourceWithStreamingResponse, + AsyncTranscriptionsResourceWithStreamingResponse, ) __all__ = [ - "Transcriptions", - "AsyncTranscriptions", - "TranscriptionsWithRawResponse", - "AsyncTranscriptionsWithRawResponse", - "TranscriptionsWithStreamingResponse", - "AsyncTranscriptionsWithStreamingResponse", - "Translations", - "AsyncTranslations", - "TranslationsWithRawResponse", - "AsyncTranslationsWithRawResponse", - "TranslationsWithStreamingResponse", - "AsyncTranslationsWithStreamingResponse", - "Audio", - "AsyncAudio", - "AudioWithRawResponse", - "AsyncAudioWithRawResponse", - "AudioWithStreamingResponse", - "AsyncAudioWithStreamingResponse", + "TranscriptionsResource", + "AsyncTranscriptionsResource", + "TranscriptionsResourceWithRawResponse", + "AsyncTranscriptionsResourceWithRawResponse", + "TranscriptionsResourceWithStreamingResponse", + "AsyncTranscriptionsResourceWithStreamingResponse", + "TranslationsResource", + "AsyncTranslationsResource", + "TranslationsResourceWithRawResponse", + "AsyncTranslationsResourceWithRawResponse", + "TranslationsResourceWithStreamingResponse", + "AsyncTranslationsResourceWithStreamingResponse", + "AudioResource", + "AsyncAudioResource", + "AudioResourceWithRawResponse", + "AsyncAudioResourceWithRawResponse", + "AudioResourceWithStreamingResponse", + "AsyncAudioResourceWithStreamingResponse", ] diff --git a/src/groq/resources/audio/audio.py b/src/groq/resources/audio/audio.py index 958a363..003dc19 100644 --- a/src/groq/resources/audio/audio.py +++ b/src/groq/resources/audio/audio.py @@ -1,112 +1,112 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from .translations import ( - Translations, - AsyncTranslations, - TranslationsWithRawResponse, - AsyncTranslationsWithRawResponse, - TranslationsWithStreamingResponse, - AsyncTranslationsWithStreamingResponse, + TranslationsResource, + AsyncTranslationsResource, + TranslationsResourceWithRawResponse, + AsyncTranslationsResourceWithRawResponse, + TranslationsResourceWithStreamingResponse, + AsyncTranslationsResourceWithStreamingResponse, ) from .transcriptions import ( - Transcriptions, - AsyncTranscriptions, - TranscriptionsWithRawResponse, - AsyncTranscriptionsWithRawResponse, - TranscriptionsWithStreamingResponse, - AsyncTranscriptionsWithStreamingResponse, + TranscriptionsResource, + AsyncTranscriptionsResource, + TranscriptionsResourceWithRawResponse, + AsyncTranscriptionsResourceWithRawResponse, + TranscriptionsResourceWithStreamingResponse, + AsyncTranscriptionsResourceWithStreamingResponse, ) -__all__ = ["Audio", "AsyncAudio"] +__all__ = ["AudioResource", "AsyncAudioResource"] -class Audio(SyncAPIResource): +class AudioResource(SyncAPIResource): @cached_property - def transcriptions(self) -> Transcriptions: - return Transcriptions(self._client) + def transcriptions(self) -> TranscriptionsResource: + return TranscriptionsResource(self._client) @cached_property - def translations(self) -> Translations: - return Translations(self._client) + def translations(self) -> TranslationsResource: + return TranslationsResource(self._client) @cached_property - def with_raw_response(self) -> AudioWithRawResponse: - return AudioWithRawResponse(self) + def with_raw_response(self) -> AudioResourceWithRawResponse: + return AudioResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AudioWithStreamingResponse: - return AudioWithStreamingResponse(self) + def with_streaming_response(self) -> AudioResourceWithStreamingResponse: + return AudioResourceWithStreamingResponse(self) -class AsyncAudio(AsyncAPIResource): +class AsyncAudioResource(AsyncAPIResource): @cached_property - def transcriptions(self) -> AsyncTranscriptions: - return AsyncTranscriptions(self._client) + def transcriptions(self) -> AsyncTranscriptionsResource: + return AsyncTranscriptionsResource(self._client) @cached_property - def translations(self) -> AsyncTranslations: - return AsyncTranslations(self._client) + def translations(self) -> AsyncTranslationsResource: + return AsyncTranslationsResource(self._client) @cached_property - def with_raw_response(self) -> AsyncAudioWithRawResponse: - return AsyncAudioWithRawResponse(self) + def with_raw_response(self) -> AsyncAudioResourceWithRawResponse: + return AsyncAudioResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncAudioWithStreamingResponse: - return AsyncAudioWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncAudioResourceWithStreamingResponse: + return AsyncAudioResourceWithStreamingResponse(self) -class AudioWithRawResponse: - def __init__(self, audio: Audio) -> None: +class AudioResourceWithRawResponse: + def __init__(self, audio: AudioResource) -> None: self._audio = audio @cached_property - def transcriptions(self) -> TranscriptionsWithRawResponse: - return TranscriptionsWithRawResponse(self._audio.transcriptions) + def transcriptions(self) -> TranscriptionsResourceWithRawResponse: + return TranscriptionsResourceWithRawResponse(self._audio.transcriptions) @cached_property - def translations(self) -> TranslationsWithRawResponse: - return TranslationsWithRawResponse(self._audio.translations) + def translations(self) -> TranslationsResourceWithRawResponse: + return TranslationsResourceWithRawResponse(self._audio.translations) -class AsyncAudioWithRawResponse: - def __init__(self, audio: AsyncAudio) -> None: +class AsyncAudioResourceWithRawResponse: + def __init__(self, audio: AsyncAudioResource) -> None: self._audio = audio @cached_property - def transcriptions(self) -> AsyncTranscriptionsWithRawResponse: - return AsyncTranscriptionsWithRawResponse(self._audio.transcriptions) + def transcriptions(self) -> AsyncTranscriptionsResourceWithRawResponse: + return AsyncTranscriptionsResourceWithRawResponse(self._audio.transcriptions) @cached_property - def translations(self) -> AsyncTranslationsWithRawResponse: - return AsyncTranslationsWithRawResponse(self._audio.translations) + def translations(self) -> AsyncTranslationsResourceWithRawResponse: + return AsyncTranslationsResourceWithRawResponse(self._audio.translations) -class AudioWithStreamingResponse: - def __init__(self, audio: Audio) -> None: +class AudioResourceWithStreamingResponse: + def __init__(self, audio: AudioResource) -> None: self._audio = audio @cached_property - def transcriptions(self) -> TranscriptionsWithStreamingResponse: - return TranscriptionsWithStreamingResponse(self._audio.transcriptions) + def transcriptions(self) -> TranscriptionsResourceWithStreamingResponse: + return TranscriptionsResourceWithStreamingResponse(self._audio.transcriptions) @cached_property - def translations(self) -> TranslationsWithStreamingResponse: - return TranslationsWithStreamingResponse(self._audio.translations) + def translations(self) -> TranslationsResourceWithStreamingResponse: + return TranslationsResourceWithStreamingResponse(self._audio.translations) -class AsyncAudioWithStreamingResponse: - def __init__(self, audio: AsyncAudio) -> None: +class AsyncAudioResourceWithStreamingResponse: + def __init__(self, audio: AsyncAudioResource) -> None: self._audio = audio @cached_property - def transcriptions(self) -> AsyncTranscriptionsWithStreamingResponse: - return AsyncTranscriptionsWithStreamingResponse(self._audio.transcriptions) + def transcriptions(self) -> AsyncTranscriptionsResourceWithStreamingResponse: + return AsyncTranscriptionsResourceWithStreamingResponse(self._audio.transcriptions) @cached_property - def translations(self) -> AsyncTranslationsWithStreamingResponse: - return AsyncTranslationsWithStreamingResponse(self._audio.translations) + def translations(self) -> AsyncTranslationsResourceWithStreamingResponse: + return AsyncTranslationsResourceWithStreamingResponse(self._audio.translations) diff --git a/src/groq/resources/audio/transcriptions.py b/src/groq/resources/audio/transcriptions.py index bb9c523..0ab55c5 100644 --- a/src/groq/resources/audio/transcriptions.py +++ b/src/groq/resources/audio/transcriptions.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations @@ -22,22 +22,23 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ...types.audio import Transcription, transcription_create_params +from ...types.audio import transcription_create_params from ..._base_client import ( make_request_options, ) +from ...types.audio.transcription import Transcription -__all__ = ["Transcriptions", "AsyncTranscriptions"] +__all__ = ["TranscriptionsResource", "AsyncTranscriptionsResource"] -class Transcriptions(SyncAPIResource): +class TranscriptionsResource(SyncAPIResource): @cached_property - def with_raw_response(self) -> TranscriptionsWithRawResponse: - return TranscriptionsWithRawResponse(self) + def with_raw_response(self) -> TranscriptionsResourceWithRawResponse: + return TranscriptionsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> TranscriptionsWithStreamingResponse: - return TranscriptionsWithStreamingResponse(self) + def with_streaming_response(self) -> TranscriptionsResourceWithStreamingResponse: + return TranscriptionsResourceWithStreamingResponse(self) def create( self, @@ -125,14 +126,14 @@ def create( ) -class AsyncTranscriptions(AsyncAPIResource): +class AsyncTranscriptionsResource(AsyncAPIResource): @cached_property - def with_raw_response(self) -> AsyncTranscriptionsWithRawResponse: - return AsyncTranscriptionsWithRawResponse(self) + def with_raw_response(self) -> AsyncTranscriptionsResourceWithRawResponse: + return AsyncTranscriptionsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncTranscriptionsWithStreamingResponse: - return AsyncTranscriptionsWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncTranscriptionsResourceWithStreamingResponse: + return AsyncTranscriptionsResourceWithStreamingResponse(self) async def create( self, @@ -220,8 +221,8 @@ async def create( ) -class TranscriptionsWithRawResponse: - def __init__(self, transcriptions: Transcriptions) -> None: +class TranscriptionsResourceWithRawResponse: + def __init__(self, transcriptions: TranscriptionsResource) -> None: self._transcriptions = transcriptions self.create = to_raw_response_wrapper( @@ -229,8 +230,8 @@ def __init__(self, transcriptions: Transcriptions) -> None: ) -class AsyncTranscriptionsWithRawResponse: - def __init__(self, transcriptions: AsyncTranscriptions) -> None: +class AsyncTranscriptionsResourceWithRawResponse: + def __init__(self, transcriptions: AsyncTranscriptionsResource) -> None: self._transcriptions = transcriptions self.create = async_to_raw_response_wrapper( @@ -238,8 +239,8 @@ def __init__(self, transcriptions: AsyncTranscriptions) -> None: ) -class TranscriptionsWithStreamingResponse: - def __init__(self, transcriptions: Transcriptions) -> None: +class TranscriptionsResourceWithStreamingResponse: + def __init__(self, transcriptions: TranscriptionsResource) -> None: self._transcriptions = transcriptions self.create = to_streamed_response_wrapper( @@ -247,8 +248,8 @@ def __init__(self, transcriptions: Transcriptions) -> None: ) -class AsyncTranscriptionsWithStreamingResponse: - def __init__(self, transcriptions: AsyncTranscriptions) -> None: +class AsyncTranscriptionsResourceWithStreamingResponse: + def __init__(self, transcriptions: AsyncTranscriptionsResource) -> None: self._transcriptions = transcriptions self.create = async_to_streamed_response_wrapper( diff --git a/src/groq/resources/audio/translations.py b/src/groq/resources/audio/translations.py index 37745c4..6267909 100644 --- a/src/groq/resources/audio/translations.py +++ b/src/groq/resources/audio/translations.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations @@ -7,7 +7,6 @@ import httpx -from ...types import Translation from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes from ..._utils import ( extract_files, @@ -27,18 +26,19 @@ from ..._base_client import ( make_request_options, ) +from ...types.translation import Translation -__all__ = ["Translations", "AsyncTranslations"] +__all__ = ["TranslationsResource", "AsyncTranslationsResource"] -class Translations(SyncAPIResource): +class TranslationsResource(SyncAPIResource): @cached_property - def with_raw_response(self) -> TranslationsWithRawResponse: - return TranslationsWithRawResponse(self) + def with_raw_response(self) -> TranslationsResourceWithRawResponse: + return TranslationsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> TranslationsWithStreamingResponse: - return TranslationsWithStreamingResponse(self) + def with_streaming_response(self) -> TranslationsResourceWithStreamingResponse: + return TranslationsResourceWithStreamingResponse(self) def create( self, @@ -111,14 +111,14 @@ def create( ) -class AsyncTranslations(AsyncAPIResource): +class AsyncTranslationsResource(AsyncAPIResource): @cached_property - def with_raw_response(self) -> AsyncTranslationsWithRawResponse: - return AsyncTranslationsWithRawResponse(self) + def with_raw_response(self) -> AsyncTranslationsResourceWithRawResponse: + return AsyncTranslationsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncTranslationsWithStreamingResponse: - return AsyncTranslationsWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncTranslationsResourceWithStreamingResponse: + return AsyncTranslationsResourceWithStreamingResponse(self) async def create( self, @@ -191,8 +191,8 @@ async def create( ) -class TranslationsWithRawResponse: - def __init__(self, translations: Translations) -> None: +class TranslationsResourceWithRawResponse: + def __init__(self, translations: TranslationsResource) -> None: self._translations = translations self.create = to_raw_response_wrapper( @@ -200,8 +200,8 @@ def __init__(self, translations: Translations) -> None: ) -class AsyncTranslationsWithRawResponse: - def __init__(self, translations: AsyncTranslations) -> None: +class AsyncTranslationsResourceWithRawResponse: + def __init__(self, translations: AsyncTranslationsResource) -> None: self._translations = translations self.create = async_to_raw_response_wrapper( @@ -209,8 +209,8 @@ def __init__(self, translations: AsyncTranslations) -> None: ) -class TranslationsWithStreamingResponse: - def __init__(self, translations: Translations) -> None: +class TranslationsResourceWithStreamingResponse: + def __init__(self, translations: TranslationsResource) -> None: self._translations = translations self.create = to_streamed_response_wrapper( @@ -218,8 +218,8 @@ def __init__(self, translations: Translations) -> None: ) -class AsyncTranslationsWithStreamingResponse: - def __init__(self, translations: AsyncTranslations) -> None: +class AsyncTranslationsResourceWithStreamingResponse: + def __init__(self, translations: AsyncTranslationsResource) -> None: self._translations = translations self.create = async_to_streamed_response_wrapper( diff --git a/src/groq/resources/chat/__init__.py b/src/groq/resources/chat/__init__.py index a966805..ec960eb 100644 --- a/src/groq/resources/chat/__init__.py +++ b/src/groq/resources/chat/__init__.py @@ -1,33 +1,33 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from .chat import ( - Chat, - AsyncChat, - ChatWithRawResponse, - AsyncChatWithRawResponse, - ChatWithStreamingResponse, - AsyncChatWithStreamingResponse, + ChatResource, + AsyncChatResource, + ChatResourceWithRawResponse, + AsyncChatResourceWithRawResponse, + ChatResourceWithStreamingResponse, + AsyncChatResourceWithStreamingResponse, ) from .completions import ( - Completions, - AsyncCompletions, - CompletionsWithRawResponse, - AsyncCompletionsWithRawResponse, - CompletionsWithStreamingResponse, - AsyncCompletionsWithStreamingResponse, + CompletionsResource, + AsyncCompletionsResource, + CompletionsResourceWithRawResponse, + AsyncCompletionsResourceWithRawResponse, + CompletionsResourceWithStreamingResponse, + AsyncCompletionsResourceWithStreamingResponse, ) __all__ = [ - "Completions", - "AsyncCompletions", - "CompletionsWithRawResponse", - "AsyncCompletionsWithRawResponse", - "CompletionsWithStreamingResponse", - "AsyncCompletionsWithStreamingResponse", - "Chat", - "AsyncChat", - "ChatWithRawResponse", - "AsyncChatWithRawResponse", - "ChatWithStreamingResponse", - "AsyncChatWithStreamingResponse", + "CompletionsResource", + "AsyncCompletionsResource", + "CompletionsResourceWithRawResponse", + "AsyncCompletionsResourceWithRawResponse", + "CompletionsResourceWithStreamingResponse", + "AsyncCompletionsResourceWithStreamingResponse", + "ChatResource", + "AsyncChatResource", + "ChatResourceWithRawResponse", + "AsyncChatResourceWithRawResponse", + "ChatResourceWithStreamingResponse", + "AsyncChatResourceWithStreamingResponse", ] diff --git a/src/groq/resources/chat/chat.py b/src/groq/resources/chat/chat.py index b6effa4..3a74254 100644 --- a/src/groq/resources/chat/chat.py +++ b/src/groq/resources/chat/chat.py @@ -1,80 +1,80 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from .completions import ( - Completions, - AsyncCompletions, - CompletionsWithRawResponse, - AsyncCompletionsWithRawResponse, - CompletionsWithStreamingResponse, - AsyncCompletionsWithStreamingResponse, + CompletionsResource, + AsyncCompletionsResource, + CompletionsResourceWithRawResponse, + AsyncCompletionsResourceWithRawResponse, + CompletionsResourceWithStreamingResponse, + AsyncCompletionsResourceWithStreamingResponse, ) -__all__ = ["Chat", "AsyncChat"] +__all__ = ["ChatResource", "AsyncChatResource"] -class Chat(SyncAPIResource): +class ChatResource(SyncAPIResource): @cached_property - def completions(self) -> Completions: - return Completions(self._client) + def completions(self) -> CompletionsResource: + return CompletionsResource(self._client) @cached_property - def with_raw_response(self) -> ChatWithRawResponse: - return ChatWithRawResponse(self) + def with_raw_response(self) -> ChatResourceWithRawResponse: + return ChatResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> ChatWithStreamingResponse: - return ChatWithStreamingResponse(self) + def with_streaming_response(self) -> ChatResourceWithStreamingResponse: + return ChatResourceWithStreamingResponse(self) -class AsyncChat(AsyncAPIResource): +class AsyncChatResource(AsyncAPIResource): @cached_property - def completions(self) -> AsyncCompletions: - return AsyncCompletions(self._client) + def completions(self) -> AsyncCompletionsResource: + return AsyncCompletionsResource(self._client) @cached_property - def with_raw_response(self) -> AsyncChatWithRawResponse: - return AsyncChatWithRawResponse(self) + def with_raw_response(self) -> AsyncChatResourceWithRawResponse: + return AsyncChatResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncChatWithStreamingResponse: - return AsyncChatWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncChatResourceWithStreamingResponse: + return AsyncChatResourceWithStreamingResponse(self) -class ChatWithRawResponse: - def __init__(self, chat: Chat) -> None: +class ChatResourceWithRawResponse: + def __init__(self, chat: ChatResource) -> None: self._chat = chat @cached_property - def completions(self) -> CompletionsWithRawResponse: - return CompletionsWithRawResponse(self._chat.completions) + def completions(self) -> CompletionsResourceWithRawResponse: + return CompletionsResourceWithRawResponse(self._chat.completions) -class AsyncChatWithRawResponse: - def __init__(self, chat: AsyncChat) -> None: +class AsyncChatResourceWithRawResponse: + def __init__(self, chat: AsyncChatResource) -> None: self._chat = chat @cached_property - def completions(self) -> AsyncCompletionsWithRawResponse: - return AsyncCompletionsWithRawResponse(self._chat.completions) + def completions(self) -> AsyncCompletionsResourceWithRawResponse: + return AsyncCompletionsResourceWithRawResponse(self._chat.completions) -class ChatWithStreamingResponse: - def __init__(self, chat: Chat) -> None: +class ChatResourceWithStreamingResponse: + def __init__(self, chat: ChatResource) -> None: self._chat = chat @cached_property - def completions(self) -> CompletionsWithStreamingResponse: - return CompletionsWithStreamingResponse(self._chat.completions) + def completions(self) -> CompletionsResourceWithStreamingResponse: + return CompletionsResourceWithStreamingResponse(self._chat.completions) -class AsyncChatWithStreamingResponse: - def __init__(self, chat: AsyncChat) -> None: +class AsyncChatResourceWithStreamingResponse: + def __init__(self, chat: AsyncChatResource) -> None: self._chat = chat @cached_property - def completions(self) -> AsyncCompletionsWithStreamingResponse: - return AsyncCompletionsWithStreamingResponse(self._chat.completions) + def completions(self) -> AsyncCompletionsResourceWithStreamingResponse: + return AsyncCompletionsResourceWithStreamingResponse(self._chat.completions) diff --git a/src/groq/resources/chat/completions.py b/src/groq/resources/chat/completions.py index 277aa58..03b1d2b 100644 --- a/src/groq/resources/chat/completions.py +++ b/src/groq/resources/chat/completions.py @@ -1,9 +1,8 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations -from typing import Dict, List, Union, Iterable, Optional, overload -from typing_extensions import Literal +from typing import Dict, List, Union, Iterable, Optional import httpx @@ -20,26 +19,24 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._streaming import Stream, AsyncStream -from ...types.chat import ChatCompletion, completion_create_params +from ...types.chat import completion_create_params from ..._base_client import ( make_request_options, ) -from ...lib.chat_completion_chunk import ChatCompletionChunk +from ...types.chat.chat_completion import ChatCompletion -__all__ = ["Completions", "AsyncCompletions"] +__all__ = ["CompletionsResource", "AsyncCompletionsResource"] -class Completions(SyncAPIResource): +class CompletionsResource(SyncAPIResource): @cached_property - def with_raw_response(self) -> CompletionsWithRawResponse: - return CompletionsWithRawResponse(self) + def with_raw_response(self) -> CompletionsResourceWithRawResponse: + return CompletionsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> CompletionsWithStreamingResponse: - return CompletionsWithStreamingResponse(self) + def with_streaming_response(self) -> CompletionsResourceWithStreamingResponse: + return CompletionsResourceWithStreamingResponse(self) - @overload def create( self, *, @@ -54,7 +51,7 @@ def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, + stream: bool | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -68,98 +65,6 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: - ... - - @overload - def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Literal[True], - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Stream[ChatCompletionChunk]: - ... - - @overload - def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool, - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletion | Stream[ChatCompletionChunk]: - ... - - def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletion | Stream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -204,21 +109,18 @@ def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, - stream=stream or False, - stream_cls=Stream[ChatCompletionChunk], ) -class AsyncCompletions(AsyncAPIResource): +class AsyncCompletionsResource(AsyncAPIResource): @cached_property - def with_raw_response(self) -> AsyncCompletionsWithRawResponse: - return AsyncCompletionsWithRawResponse(self) + def with_raw_response(self) -> AsyncCompletionsResourceWithRawResponse: + return AsyncCompletionsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse: - return AsyncCompletionsWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncCompletionsResourceWithStreamingResponse: + return AsyncCompletionsResourceWithStreamingResponse(self) - @overload async def create( self, *, @@ -233,7 +135,7 @@ async def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, + stream: bool | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -247,98 +149,6 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: - ... - - @overload - async def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Literal[True], - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncStream[ChatCompletionChunk]: - ... - - @overload - async def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool, - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: - ... - - async def create( - self, - *, - messages: Iterable[completion_create_params.Message], - model: str, - frequency_penalty: float | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, - logprobs: bool | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - n: int | NotGiven = NOT_GIVEN, - presence_penalty: float | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, - top_logprobs: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -383,13 +193,11 @@ async def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, - stream=stream or False, - stream_cls=AsyncStream[ChatCompletionChunk], ) -class CompletionsWithRawResponse: - def __init__(self, completions: Completions) -> None: +class CompletionsResourceWithRawResponse: + def __init__(self, completions: CompletionsResource) -> None: self._completions = completions self.create = to_raw_response_wrapper( @@ -397,8 +205,8 @@ def __init__(self, completions: Completions) -> None: ) -class AsyncCompletionsWithRawResponse: - def __init__(self, completions: AsyncCompletions) -> None: +class AsyncCompletionsResourceWithRawResponse: + def __init__(self, completions: AsyncCompletionsResource) -> None: self._completions = completions self.create = async_to_raw_response_wrapper( @@ -406,8 +214,8 @@ def __init__(self, completions: AsyncCompletions) -> None: ) -class CompletionsWithStreamingResponse: - def __init__(self, completions: Completions) -> None: +class CompletionsResourceWithStreamingResponse: + def __init__(self, completions: CompletionsResource) -> None: self._completions = completions self.create = to_streamed_response_wrapper( @@ -415,8 +223,8 @@ def __init__(self, completions: Completions) -> None: ) -class AsyncCompletionsWithStreamingResponse: - def __init__(self, completions: AsyncCompletions) -> None: +class AsyncCompletionsResourceWithStreamingResponse: + def __init__(self, completions: AsyncCompletionsResource) -> None: self._completions = completions self.create = async_to_streamed_response_wrapper( diff --git a/src/groq/resources/models.py b/src/groq/resources/models.py index 5962f77..3fdb482 100644 --- a/src/groq/resources/models.py +++ b/src/groq/resources/models.py @@ -1,10 +1,9 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations import httpx -from ..types import Model, ModelList from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource @@ -14,21 +13,23 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..types.model import Model from .._base_client import ( make_request_options, ) +from ..types.model_list import ModelList -__all__ = ["Models", "AsyncModels"] +__all__ = ["ModelsResource", "AsyncModelsResource"] -class Models(SyncAPIResource): +class ModelsResource(SyncAPIResource): @cached_property - def with_raw_response(self) -> ModelsWithRawResponse: - return ModelsWithRawResponse(self) + def with_raw_response(self) -> ModelsResourceWithRawResponse: + return ModelsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> ModelsWithStreamingResponse: - return ModelsWithStreamingResponse(self) + def with_streaming_response(self) -> ModelsResourceWithStreamingResponse: + return ModelsResourceWithStreamingResponse(self) def retrieve( self, @@ -117,14 +118,14 @@ def delete( ) -class AsyncModels(AsyncAPIResource): +class AsyncModelsResource(AsyncAPIResource): @cached_property - def with_raw_response(self) -> AsyncModelsWithRawResponse: - return AsyncModelsWithRawResponse(self) + def with_raw_response(self) -> AsyncModelsResourceWithRawResponse: + return AsyncModelsResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncModelsWithStreamingResponse: - return AsyncModelsWithStreamingResponse(self) + def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse: + return AsyncModelsResourceWithStreamingResponse(self) async def retrieve( self, @@ -213,8 +214,8 @@ async def delete( ) -class ModelsWithRawResponse: - def __init__(self, models: Models) -> None: +class ModelsResourceWithRawResponse: + def __init__(self, models: ModelsResource) -> None: self._models = models self.retrieve = to_raw_response_wrapper( @@ -228,8 +229,8 @@ def __init__(self, models: Models) -> None: ) -class AsyncModelsWithRawResponse: - def __init__(self, models: AsyncModels) -> None: +class AsyncModelsResourceWithRawResponse: + def __init__(self, models: AsyncModelsResource) -> None: self._models = models self.retrieve = async_to_raw_response_wrapper( @@ -243,8 +244,8 @@ def __init__(self, models: AsyncModels) -> None: ) -class ModelsWithStreamingResponse: - def __init__(self, models: Models) -> None: +class ModelsResourceWithStreamingResponse: + def __init__(self, models: ModelsResource) -> None: self._models = models self.retrieve = to_streamed_response_wrapper( @@ -258,8 +259,8 @@ def __init__(self, models: Models) -> None: ) -class AsyncModelsWithStreamingResponse: - def __init__(self, models: AsyncModels) -> None: +class AsyncModelsResourceWithStreamingResponse: + def __init__(self, models: AsyncModelsResource) -> None: self._models = models self.retrieve = async_to_streamed_response_wrapper( diff --git a/src/groq/types/__init__.py b/src/groq/types/__init__.py index b7f0319..ee038d4 100644 --- a/src/groq/types/__init__.py +++ b/src/groq/types/__init__.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/audio/__init__.py b/src/groq/types/audio/__init__.py index 6b1acc9..ae3a015 100644 --- a/src/groq/types/audio/__init__.py +++ b/src/groq/types/audio/__init__.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/audio/transcription.py b/src/groq/types/audio/transcription.py index 6532611..0b6ab39 100644 --- a/src/groq/types/audio/transcription.py +++ b/src/groq/types/audio/transcription.py @@ -1,4 +1,6 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + + from ..._models import BaseModel diff --git a/src/groq/types/audio/transcription_create_params.py b/src/groq/types/audio/transcription_create_params.py index c92ba89..373793c 100644 --- a/src/groq/types/audio/transcription_create_params.py +++ b/src/groq/types/audio/transcription_create_params.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/audio/translation_create_params.py b/src/groq/types/audio/translation_create_params.py index 541542b..c98b320 100644 --- a/src/groq/types/audio/translation_create_params.py +++ b/src/groq/types/audio/translation_create_params.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/chat/__init__.py b/src/groq/types/chat/__init__.py index 00f0222..fa7a61a 100644 --- a/src/groq/types/chat/__init__.py +++ b/src/groq/types/chat/__init__.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/chat/chat_completion.py b/src/groq/types/chat/chat_completion.py index 2b8b3e2..9e36c15 100644 --- a/src/groq/types/chat/chat_completion.py +++ b/src/groq/types/chat/chat_completion.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import List, Optional diff --git a/src/groq/types/chat/completion_create_params.py b/src/groq/types/chat/completion_create_params.py index 0f9712b..7f44523 100644 --- a/src/groq/types/chat/completion_create_params.py +++ b/src/groq/types/chat/completion_create_params.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/src/groq/types/model.py b/src/groq/types/model.py index a8a5229..ed49a06 100644 --- a/src/groq/types/model.py +++ b/src/groq/types/model.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Optional diff --git a/src/groq/types/model_list.py b/src/groq/types/model_list.py index 26fe9e5..8315560 100644 --- a/src/groq/types/model_list.py +++ b/src/groq/types/model_list.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import List, Optional diff --git a/src/groq/types/translation.py b/src/groq/types/translation.py index 76cdbb5..f36fade 100644 --- a/src/groq/types/translation.py +++ b/src/groq/types/translation.py @@ -1,4 +1,6 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + + from .._models import BaseModel diff --git a/tests/__init__.py b/tests/__init__.py index 1016754..fd8019a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/__init__.py b/tests/api_resources/__init__.py index 1016754..fd8019a 100644 --- a/tests/api_resources/__init__.py +++ b/tests/api_resources/__init__.py @@ -1 +1 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/audio/__init__.py b/tests/api_resources/audio/__init__.py index 1016754..fd8019a 100644 --- a/tests/api_resources/audio/__init__.py +++ b/tests/api_resources/audio/__init__.py @@ -1 +1 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/audio/test_transcriptions.py b/tests/api_resources/audio/test_transcriptions.py index a54fdd9..b54784b 100644 --- a/tests/api_resources/audio/test_transcriptions.py +++ b/tests/api_resources/audio/test_transcriptions.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/tests/api_resources/audio/test_translations.py b/tests/api_resources/audio/test_translations.py index ccd5a5d..eae2a01 100644 --- a/tests/api_resources/audio/test_translations.py +++ b/tests/api_resources/audio/test_translations.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/tests/api_resources/chat/__init__.py b/tests/api_resources/chat/__init__.py index 1016754..fd8019a 100644 --- a/tests/api_resources/chat/__init__.py +++ b/tests/api_resources/chat/__init__.py @@ -1 +1 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py index 1fdfc34..9a3c0d8 100644 --- a/tests/api_resources/chat/test_completions.py +++ b/tests/api_resources/chat/test_completions.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 64fb2e9..e07c0d8 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations diff --git a/tests/test_client.py b/tests/test_client.py index 158bd01..6426d8a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,4 @@ -# File generated from our OpenAPI spec by Stainless. +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations @@ -17,7 +17,6 @@ from pydantic import ValidationError from groq import Groq, AsyncGroq, APIResponseValidationError -from groq._client import Groq, AsyncGroq from groq._models import BaseModel, FinalRequestOptions from groq._constants import RAW_RESPONSE_HEADER from groq._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError @@ -630,6 +629,10 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + Groq(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + @pytest.mark.respx(base_url=base_url) def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): @@ -1332,6 +1335,10 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncGroq(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 095b12e..af5307e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,14 +1,15 @@ import json from typing import Any, Dict, List, Union, Optional, cast from datetime import datetime, timezone -from typing_extensions import Literal +from typing_extensions import Literal, Annotated import pytest import pydantic from pydantic import Field +from groq._utils import PropertyInfo from groq._compat import PYDANTIC_V2, parse_obj, model_dump, model_json -from groq._models import BaseModel +from groq._models import BaseModel, construct_type class BasicModel(BaseModel): @@ -500,6 +501,42 @@ class Model(BaseModel): assert "resource_id" in m.model_fields_set +def test_to_dict() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert m.to_dict() == {"FOO": "hello"} + assert m.to_dict(use_api_names=False) == {"foo": "hello"} + + m2 = Model() + assert m2.to_dict() == {} + assert m2.to_dict(exclude_unset=False) == {"FOO": None} + assert m2.to_dict(exclude_unset=False, exclude_none=True) == {} + assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {} + + m3 = Model(FOO=None) + assert m3.to_dict() == {"FOO": None} + assert m3.to_dict(exclude_none=True) == {} + assert m3.to_dict(exclude_defaults=True) == {} + + if PYDANTIC_V2: + + class Model2(BaseModel): + created_at: datetime + + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + else: + with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): + m.to_dict(mode="json") + + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_dict(warnings=False) + + def test_forwards_compat_model_dump_method() -> None: class Model(BaseModel): foo: Optional[str] = Field(alias="FOO", default=None) @@ -531,6 +568,34 @@ class Model(BaseModel): m.model_dump(warnings=False) +def test_to_json() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert json.loads(m.to_json()) == {"FOO": "hello"} + assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} + + if PYDANTIC_V2: + assert m.to_json(indent=None) == '{"FOO":"hello"}' + else: + assert m.to_json(indent=None) == '{"FOO": "hello"}' + + m2 = Model() + assert json.loads(m2.to_json()) == {} + assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None} + assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {} + assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {} + + m3 = Model(FOO=None) + assert json.loads(m3.to_json()) == {"FOO": None} + assert json.loads(m3.to_json(exclude_none=True)) == {} + + if not PYDANTIC_V2: + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_json(warnings=False) + + def test_forwards_compat_model_dump_json_method() -> None: class Model(BaseModel): foo: Optional[str] = Field(alias="FOO", default=None) @@ -571,3 +636,194 @@ class OurModel(BaseModel): foo: Optional[str] = None takes_pydantic(OurModel()) + + +def test_annotated_types() -> None: + class Model(BaseModel): + value: str + + m = construct_type( + value={"value": "foo"}, + type_=cast(Any, Annotated[Model, "random metadata"]), + ) + assert isinstance(m, Model) + assert m.value == "foo" + + +def test_discriminated_unions_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, A) + assert m.type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_unknown_variant() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "c", "data": None, "new_thing": "bar"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + + # just chooses the first variant + assert isinstance(m, A) + assert m.type == "c" # type: ignore[comparison-overlap] + assert m.data == None # type: ignore[unreachable] + assert m.new_thing == "bar" + + +def test_discriminated_unions_invalid_data_nested_unions() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + class C(BaseModel): + type: Literal["c"] + + data: bool + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "c", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, C) + assert m.type == "c" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_with_aliases_invalid_data() -> None: + class A(BaseModel): + foo_type: Literal["a"] = Field(alias="type") + + data: str + + class B(BaseModel): + foo_type: Literal["b"] = Field(alias="type") + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, B) + assert m.foo_type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, A) + assert m.foo_type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["a"] + + data: int + + m = construct_type( + value={"type": "a", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "a" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_invalid_data_uses_cache() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + UnionType = cast(Any, Union[A, B]) + + assert not hasattr(UnionType, "__discriminator__") + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + discriminator = UnionType.__discriminator__ + assert discriminator is not None + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + # if the discriminator details object stays the same between invocations then + # we hit the cache + assert UnionType.__discriminator__ is discriminator diff --git a/tests/test_response.py b/tests/test_response.py index eea1264..a89f629 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,5 +1,6 @@ import json -from typing import List +from typing import List, cast +from typing_extensions import Annotated import httpx import pytest @@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncGroq) -> Non obj = await response.parse(to=CustomModel) assert obj.foo == "hello!" assert obj.bar == 2 + + +def test_response_parse_annotated_type(client: Groq) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +async def test_async_response_parse_annotated_type(async_client: AsyncGroq) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 diff --git a/tests/test_streaming.py b/tests/test_streaming.py index be16333..6209220 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,104 +1,248 @@ +from __future__ import annotations + from typing import Iterator, AsyncIterator +import httpx import pytest -from groq._streaming import SSEDecoder +from groq import Groq, AsyncGroq +from groq._streaming import Stream, AsyncStream, ServerSentEvent @pytest.mark.asyncio -async def test_basic_async() -> None: - async def body() -> AsyncIterator[str]: - yield "event: completion" - yield 'data: {"foo":true}' - yield "" - - async for sse in SSEDecoder().aiter(body()): - assert sse.event == "completion" - assert sse.json() == {"foo": True} +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_basic(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: completion\n" + yield b'data: {"foo":true}\n' + yield b"\n" + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) -def test_basic() -> None: - def body() -> Iterator[str]: - yield "event: completion" - yield 'data: {"foo":true}' - yield "" - - it = SSEDecoder().iter(body()) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.json() == {"foo": True} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) -def test_data_missing_event() -> None: - def body() -> Iterator[str]: - yield 'data: {"foo":true}' - yield "" +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_missing_event(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"foo":true}\n' + yield b"\n" - it = SSEDecoder().iter(body()) - sse = next(it) + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) assert sse.event is None assert sse.json() == {"foo": True} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_event_missing_data(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" -def test_event_missing_data() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield "" + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) - it = SSEDecoder().iter(body()) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.data == "" - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) -def test_multiple_events() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield "" - yield "event: completion" - yield "" +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + yield b"event: completion\n" + yield b"\n" - it = SSEDecoder().iter(body()) + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.data == "" - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.data == "" - with pytest.raises(StopIteration): - next(it) - - -def test_multiple_events_with_data() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield 'data: {"foo":true}' - yield "" - yield "event: completion" - yield 'data: {"bar":false}' - yield "" + await assert_empty_iter(iterator) - it = SSEDecoder().iter(body()) - sse = next(it) +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events_with_data(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo":true}\n' + yield b"\n" + yield b"event: completion\n" + yield b'data: {"bar":false}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.json() == {"foo": True} - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.json() == {"bar": False} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines_with_empty_line(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: \n" + yield b"data:\n" + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + assert sse.data == '{\n"foo":\n\n\ntrue}' + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_json_escaped_double_new_line(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo": "my long\\n\\ncontent"}' + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": "my long\n\ncontent"} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines(sync: bool, client: Groq, async_client: AsyncGroq) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_special_new_line_character( + sync: bool, + client: Groq, + async_client: AsyncGroq, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":" culpa"}\n' + yield b"\n" + yield b'data: {"content":" \xe2\x80\xa8"}\n' + yield b"\n" + yield b'data: {"content":"foo"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " culpa"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " 
"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "foo"} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multi_byte_character_multiple_chunks( + sync: bool, + client: Groq, + async_client: AsyncGroq, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":"' + # bytes taken from the string 'известни' and arbitrarily split + # so that some multi-byte characters span multiple chunks + yield b"\xd0" + yield b"\xb8\xd0\xb7\xd0" + yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" + yield b'"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "известни"} + + +async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: + for chunk in iter: + yield chunk + + +async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent: + if isinstance(iter, AsyncIterator): + return await iter.__anext__() + + return next(iter) + + +async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None: + with pytest.raises((StopAsyncIteration, RuntimeError)): + await iter_next(iter) + + +def make_event_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: Groq, + async_client: AsyncGroq, +) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: + if sync: + return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events() + + return AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) + )._iter_events() diff --git a/tests/utils.py b/tests/utils.py index 5e3ba5a..96a9cba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,6 +14,8 @@ is_list, is_list_type, is_union_type, + extract_type_arg, + is_annotated_type, ) from groq._compat import PYDANTIC_V2, field_outer_type, get_model_fields from groq._models import BaseModel @@ -49,6 +51,10 @@ def assert_matches_type( path: list[str], allow_none: bool = False, ) -> None: + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + type_ = extract_type_arg(type_, 0) + if allow_none and value is None: return @@ -91,7 +97,22 @@ def assert_matches_type( assert_matches_type(key_type, key, path=[*path, ""]) assert_matches_type(items_type, item, path=[*path, ""]) elif is_union_type(type_): - for i, variant in enumerate(get_args(type_)): + variants = get_args(type_) + + try: + none_index = variants.index(type(None)) + except ValueError: + pass + else: + # special case Optional[T] for better error messages + if len(variants) == 2: + if value is None: + # valid + return + + return assert_matches_type(type_=variants[not none_index], value=value, path=path) + + for i, variant in enumerate(variants): try: assert_matches_type(variant, value, path=[*path, f"variant {i}"]) return