diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 00000000..7260b00b --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,21 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} + run: go test -v -tags=integration ./api_integration_test.go diff --git a/api_integration_test.go b/api_integration_test.go index 736040c5..f3468518 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -26,7 +26,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +77,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +109,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..4aba966b --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..4b2261e0 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f2..e49da9b3 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/completion.go b/completion.go index ced8e060..024f09b1 100644 --- a/completion.go +++ b/completion.go @@ -39,30 +39,33 @@ const ( GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/embeddings.go b/embeddings.go index c5633a31..b513ba6a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/files.go b/files.go index b40a44f1..26ad6bd7 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 71336915..6bd0964c 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { diff --git a/thread.go b/thread.go index 900e3f2e..6f752145 100644 --- a/thread.go +++ b/thread.go @@ -10,21 +10,74 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string