diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f899..a590ec9a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/assistant.go b/assistant.go index 9415325f..661681e8 100644 --- a/assistant.go +++ b/assistant.go @@ -11,7 +11,6 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { @@ -116,7 +115,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +131,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +148,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +164,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,7 +202,7 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -221,7 +220,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +237,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +254,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +293,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 6ff7078e..159f9f47 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -30,6 +30,7 @@ type ChatCompletionStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb4..bd1c737d 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ diff --git a/client.go b/client.go index 9a1c8958..c57ba17c 100644 --- a/client.go +++ b/client.go @@ -89,9 +89,9 @@ func withContentType(contentType string) requestOption { } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } } @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index c58b71ec..bb437c97 100644 --- a/config.go +++ b/config.go @@ -16,13 +16,16 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,7 +33,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient *http.Client @@ -39,10 +43,11 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, diff --git a/messages.go b/messages.go index 6fd0adbc..6af11844 100644 --- a/messages.go +++ b/messages.go @@ -76,7 +76,8 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -111,7 +112,8 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +128,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +146,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +161,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +177,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/run.go b/run.go index 7c14779c..094b0a4d 100644 --- a/run.go +++ b/run.go @@ -226,8 +226,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -247,8 +246,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -270,8 +268,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -310,8 +307,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -332,8 +328,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -352,8 +347,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -372,8 +366,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -394,8 +387,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -435,8 +427,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/thread.go b/thread.go index 291f3dca..900e3f2e 100644 --- a/thread.go +++ b/thread.go @@ -51,7 +51,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +64,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +81,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +97,7 @@ func (c *Client) DeleteThread( ) (response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return }