Skip to content

Commit

Permalink
feat(api): Tool calling features
Browse files Browse the repository at this point in the history
Add parallel tool calling option to chat completions
Allow 'required' as a function call option
  • Loading branch information
stainless-app[bot] authored and gradenr committed Jun 11, 2024
1 parent 82d845b commit c081730
Show file tree
Hide file tree
Showing 19 changed files with 75 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT}

USER vscode

RUN curl -sSf https://rye-up.com/get | RYE_VERSION="0.24.0" RYE_INSTALL_OPTION="--yes" bash
RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.24.0" RYE_INSTALL_OPTION="--yes" bash
ENV PATH=/home/vscode/.rye/shims:$PATH

RUN echo "[[ -d .venv ]] && source .venv/bin/activate" >> /home/vscode/.bashrc
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:

- name: Install Rye
run: |
curl -sSf https://rye-up.com/get | bash
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
Expand All @@ -38,7 +38,7 @@ jobs:

- name: Install Rye
run: |
curl -sSf https://rye-up.com/get | bash
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:

- name: Install Rye
run: |
curl -sSf https://rye-up.com/get | bash
curl -sSf https://rye.astral.sh/get | bash
echo "$HOME/.rye/shims" >> $GITHUB_PATH
env:
RYE_VERSION: 0.24.0
Expand Down
2 changes: 1 addition & 1 deletion .stats.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
configured_endpoints: 7
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/groqcloud%2Fgroqcloud-36fcf453a77cbc8279361577a1e785a3a86ef7bcbde2195270a83e93cbc4b8b3.yml
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/groqcloud%2Fgroqcloud-862b73eb4a57968b79e739f345b7f4523b8273477544fb03056c9e7d4230f42d.yml
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

### With Rye

We use [Rye](https://rye-up.com/) to manage dependencies so we highly recommend [installing it](https://rye-up.com/guide/installation/) as it will automatically provision a Python environment with the expected Python version.
We use [Rye](https://rye.astral.sh/) to manage dependencies so we highly recommend [installing it](https://rye.astral.sh/guide/installation/) as it will automatically provision a Python environment with the expected Python version.

After installing Rye, you'll just have to run this command:

Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pydantic==2.7.1
# via groq
pydantic-core==2.18.2
# via pydantic
pyright==1.1.359
pyright==1.1.364
pytest==7.1.1
# via pytest-asyncio
pytest-asyncio==0.21.1
Expand Down
2 changes: 1 addition & 1 deletion scripts/bootstrap
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ echo "==> Installing Python dependencies…"
# experimental uv support makes installations significantly faster
rye config --set-bool behavior.use-uv=true

rye sync
rye sync --all-features
3 changes: 1 addition & 2 deletions src/groq/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import sniffio

from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime

_T = TypeVar("_T")
Expand Down Expand Up @@ -370,7 +370,6 @@ def file_from_path(path: str) -> FileTypes:
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if isinstance(headers, Mapping):
headers = cast(Headers, headers)
for k, v in headers.items():
if k.lower() == lower_header and isinstance(v, str):
return v
Expand Down
14 changes: 10 additions & 4 deletions src/groq/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -89,6 +90,7 @@ def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -122,6 +124,7 @@ def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -212,8 +215,8 @@ def create(
context length.
n: How many chat completion choices to generate for each input message. Note that
you will be charged based on the number of generated tokens across all of the
choices. Keep `n` as `1` to minimize costs.
the current moment, only n=1 is supported. Other values will result in a 400
response.
presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
whether they appear in the text so far, increasing the model's likelihood to
Expand Down Expand Up @@ -338,6 +341,7 @@ async def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -371,6 +375,7 @@ async def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -404,6 +409,7 @@ async def create(
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -494,8 +500,8 @@ async def create(
context length.
n: How many chat completion choices to generate for each input message. Note that
you will be charged based on the number of generated tokens across all of the
choices. Keep `n` as `1` to minimize costs.
the current moment, only n=1 is supported. Other values will result in a 400
response.
presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on
whether they appear in the text so far, increasing the model's likelihood to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@


class FunctionCall(TypedDict, total=False):
arguments: Required[str]
arguments: str
"""
The arguments to call the function with, as generated by the model in JSON
format. Note that the model does not always generate valid JSON, and may
hallucinate parameters not defined by your function schema. Validate the
arguments in your code before calling your function.
"""

name: Required[str]
name: str
"""The name of the function to call."""


Expand Down Expand Up @@ -47,5 +47,8 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False):
role.
"""

tool_call_id: Optional[str]
"""DO NOT USE. This field is present because OpenAI allows it and userssend it."""

tool_calls: Iterable[ChatCompletionMessageToolCallParam]
"""The tool calls generated by the model, such as function calls."""
3 changes: 3 additions & 0 deletions src/groq/types/chat/chat_completion_function_message_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ class ChatCompletionFunctionMessageParam(TypedDict, total=False):

role: Required[Literal["function"]]
"""The role of the messages author, in this case `function`."""

tool_call_id: Optional[str]
"""DO NOT USE. This field is present because OpenAI allows it and users send it."""
4 changes: 4 additions & 0 deletions src/groq/types/chat/chat_completion_system_message_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from typing import Optional
from typing_extensions import Literal, Required, TypedDict

__all__ = ["ChatCompletionSystemMessageParam"]
Expand All @@ -20,3 +21,6 @@ class ChatCompletionSystemMessageParam(TypedDict, total=False):
Provides the model information to differentiate between participants of the same
role.
"""

tool_call_id: Optional[str]
"""DO NOT USE. This field is present because OpenAI allows it and userssend it."""
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

__all__ = ["ChatCompletionToolChoiceOptionParam"]

ChatCompletionToolChoiceOptionParam = Union[Literal["none", "auto"], ChatCompletionNamedToolChoiceParam]
ChatCompletionToolChoiceOptionParam = Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoiceParam]
3 changes: 3 additions & 0 deletions src/groq/types/chat/chat_completion_tool_message_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ class ChatCompletionToolMessageParam(TypedDict, total=False):

tool_call_id: Required[str]
"""Tool call that this message is responding to."""

name: str
"""DO NOT USE. This field is present because OpenAI allows it and userssend it."""
5 changes: 4 additions & 1 deletion src/groq/types/chat/chat_completion_user_message_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Union, Iterable
from typing import Union, Iterable, Optional
from typing_extensions import Literal, Required, TypedDict

from .chat_completion_content_part_param import ChatCompletionContentPartParam
Expand All @@ -23,3 +23,6 @@ class ChatCompletionUserMessageParam(TypedDict, total=False):
Provides the model information to differentiate between participants of the same
role.
"""

tool_call_id: Optional[str]
"""DO NOT USE. This field is present because OpenAI allows it and userssend it."""
12 changes: 6 additions & 6 deletions src/groq/types/chat/completion_create_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class CompletionCreateParams(TypedDict, total=False):
n: Optional[int]
"""How many chat completion choices to generate for each input message.
Note that you will be charged based on the number of generated tokens across all
of the choices. Keep `n` as `1` to minimize costs.
Note that the current moment, only n=1 is supported. Other values will result in
a 400 response.
"""

presence_penalty: Optional[float]
Expand Down Expand Up @@ -170,7 +170,7 @@ class CompletionCreateParams(TypedDict, total=False):
"""


FunctionCall = Union[Literal["none", "auto"], ChatCompletionFunctionCallOptionParam]
FunctionCall = Union[Literal["none", "auto", "required"], ChatCompletionFunctionCallOptionParam]


class Function(TypedDict, total=False):
Expand All @@ -190,9 +190,9 @@ class Function(TypedDict, total=False):
parameters: shared_params.FunctionParameters
"""The parameters the functions accepts, described as a JSON Schema object.
See the [guide](/docs/guides/text-generation/function-calling) for examples, and
the [JSON Schema reference](https://json-schema.org/understanding-json-schema/)
for documentation about the format.
See the docs on [tool use](/docs/tool-use) for examples, and the
[JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
documentation about the format.
Omitting `parameters` defines a function with an empty parameter list.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/groq/types/shared/function_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class FunctionDefinition(BaseModel):
parameters: Optional[FunctionParameters] = None
"""The parameters the functions accepts, described as a JSON Schema object.
See the [guide](/docs/guides/text-generation/function-calling) for examples, and
the [JSON Schema reference](https://json-schema.org/understanding-json-schema/)
for documentation about the format.
See the docs on [tool use](/docs/tool-use) for examples, and the
[JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
documentation about the format.
Omitting `parameters` defines a function with an empty parameter list.
"""
6 changes: 3 additions & 3 deletions src/groq/types/shared_params/function_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class FunctionDefinition(TypedDict, total=False):
parameters: shared_params.FunctionParameters
"""The parameters the functions accepts, described as a JSON Schema object.
See the [guide](/docs/guides/text-generation/function-calling) for examples, and
the [JSON Schema reference](https://json-schema.org/understanding-json-schema/)
for documentation about the format.
See the docs on [tool use](/docs/tool-use) for examples, and the
[JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
documentation about the format.
Omitting `parameters` defines a function with an empty parameter list.
"""
26 changes: 24 additions & 2 deletions tests/api_resources/chat/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None:
"content": "string",
"role": "system",
"name": "string",
"tool_call_id": "string",
}
],
model="string",
Expand All @@ -48,7 +49,17 @@ def test_method_create_with_all_params(self, client: Groq) -> None:
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
}
},
{
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
},
{
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
},
],
logit_bias={"foo": 0},
logprobs=True,
Expand Down Expand Up @@ -154,6 +165,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N
"content": "string",
"role": "system",
"name": "string",
"tool_call_id": "string",
}
],
model="string",
Expand All @@ -164,7 +176,17 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
}
},
{
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
},
{
"description": "string",
"name": "string",
"parameters": {"foo": "bar"},
},
],
logit_bias={"foo": 0},
logprobs=True,
Expand Down

0 comments on commit c081730

Please sign in to comment.