From a2ca01bb6dae1a7d58860a5b2d5d5273667e089e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 29 Aug 2023 14:04:27 +0200 Subject: [PATCH 01/98] feat: implement new fine tuning job API (#479) * feat: implement new fine tuning job API * fix: export ListFineTuningJobEventsParameter * fix: lint errors * fix: test errors * fix: code test coverage * fix: code test coverage * fix: use any * chore: use url.Values --- client_test.go | 12 ++++ fine_tuning_job.go | 153 ++++++++++++++++++++++++++++++++++++++++ fine_tuning_job_test.go | 90 +++++++++++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 fine_tuning_job.go create mode 100644 fine_tuning_job_test.go diff --git a/client_test.go b/client_test.go index 29d84edf..9b504689 100644 --- a/client_test.go +++ b/client_test.go @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 00000000..a840b7ec --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,153 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` +} + +type Hyperparameters struct { + Epochs int `json:"n_epochs"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. +func (c *Client) CreateFineTuningJob( + ctx context.Context, + request FineTuningJobRequest, +) (response FineTuningJob, err error) { + urlSuffix := "/fine_tuning/jobs" + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTuningJob cancel a fine tuning job. +func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveFineTuningJob retrieve a fine tuning job. +func (c *Client) RetrieveFineTuningJob( + ctx context.Context, + fineTuningJobID string, +) (response FineTuningJob, err error) { + urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type listFineTuningJobEventsParameters struct { + after *string + limit *int +} + +type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) + +func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.after = &after + } +} + +func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.limit = &limit + } +} + +// ListFineTuningJobs list fine tuning jobs events. +func (c *Client) ListFineTuningJobEvents( + ctx context.Context, + fineTuningJobID string, + setters ...ListFineTuningJobEventsParameter, +) (response FineTuningJobEventList, err error) { + parameters := &listFineTuningJobEventsParameters{ + after: nil, + limit: nil, + } + + for _, setter := range setters { + setter(parameters) + } + + urlValues := url.Values{} + if parameters.after != nil { + urlValues.Add("after", *parameters.after) + } + if parameters.limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go new file mode 100644 index 00000000..519c6cd2 --- /dev/null +++ b/fine_tuning_job_test.go @@ -0,0 +1,90 @@ +package openai_test + +import ( + "context" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuninigJobID = "fine-tuning-job-id" + +// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +func TestFineTuningJob(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler( + "/v1/fine_tuning/jobs", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJobEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + ctx := context.Background() + + _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + checks.NoError(t, err, "CreateFineTuningJob error") + + _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "CancelFineTuningJob error") + + _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "RetrieveFineTuningJob error") + + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") +} From 25da859c189c62c2454717fb2214da079017ff8e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 31 Aug 2023 12:14:39 +0200 Subject: [PATCH 02/98] Chore Deprecate legacy fine tunes API (#484) * chore: add deprecation message * chore: use new fine tuning API in README example --- README.md | 21 +++++++++++++-------- fine_tunes.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9714d89f..440c4096 100644 --- a/README.md +++ b/README.md @@ -631,11 +631,16 @@ func main() { client := openai.NewClient("your token") ctx := context.Background() - // create a .jsonl file with your training data + // create a .jsonl file with your training data for conversational model // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} + // chat models are trained using the following file format: + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} + // you can use openai cli tool to validate the data // For more info - https://platform.openai.com/docs/guides/fine-tuning @@ -648,29 +653,29 @@ func main() { return } - // create a fine tune job + // create a fine tuning job // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) // use below get method to know the status of your model - tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + fineTuningJob, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{ TrainingFile: file.ID, - Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + Model: "davinci-002", // gpt-3.5-turbo-0613, babbage-002. }) if err != nil { fmt.Printf("Creating new fine tune model error: %v\n", err) return } - getTune, err := client.GetFineTune(ctx, tune.ID) + fineTuningJob, err = client.RetrieveFineTuningJob(ctx, fineTuningJob.ID) if err != nil { fmt.Printf("Getting fine tune model error: %v\n", err) return } - fmt.Println(getTune.FineTunedModel) + fmt.Println(fineTuningJob.FineTunedModel) - // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + // once the status of fineTuningJob is `succeeded`, you can use your fine tune model in Completion Request or Chat Completion Request // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ - // Model: getTune.FineTunedModel, + // Model: fineTuningJob.FineTunedModel, // Prompt: "your prompt", // }) // if err != nil { diff --git a/fine_tunes.go b/fine_tunes.go index 96e731d5..7d3b59db 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,9 @@ import ( "net/http" ) +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneRequest struct { TrainingFile string `json:"training_file"` ValidationFile string `json:"validation_file,omitempty"` @@ -21,6 +24,9 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTune struct { ID string `json:"id"` Object string `json:"object"` @@ -37,6 +43,9 @@ type FineTune struct { UpdatedAt int64 `json:"updated_at"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEvent struct { Object string `json:"object"` CreatedAt int64 `json:"created_at"` @@ -44,6 +53,9 @@ type FineTuneEvent struct { Message string `json:"message"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneHyperParams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -51,21 +63,34 @@ type FineTuneHyperParams struct { PromptLossWeight float64 `json:"prompt_loss_weight"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` } + +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) @@ -78,6 +103,9 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r } // CancelFineTune cancel a fine-tune job. +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { @@ -88,6 +116,9 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { @@ -98,6 +129,9 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) @@ -109,6 +143,9 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { @@ -119,6 +156,9 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { From 3589837b229aeace205f312aa839bf73154e2820 Mon Sep 17 00:00:00 2001 From: NullpointerW <58949721+NullpointerW@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:52:47 +0800 Subject: [PATCH 03/98] Update OpenAPI file return struct (#486) * completionBatchingRequestSupport * lint fix * fix Run test fail * fix TestClientReturnsRequestBuilderErrors fail * fix Codecov check * ignore TestClientReturnsRequestBuilderErrors lint * fix lint again * lint again*2 * replace checkPromptType implementation * remove nil check * update file return struct --------- Co-authored-by: W <825708370@qq.com> --- files.go | 15 ++++++++------- files_api_test.go | 1 - 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/files.go b/files.go index ea1f50a7..8b933c36 100644 --- a/files.go +++ b/files.go @@ -17,13 +17,14 @@ type FileRequest struct { // File struct represents an OpenAPI file. type File struct { - Bytes int `json:"bytes"` - CreatedAt int64 `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object string `json:"object"` + Status string `json:"status"` + Purpose string `json:"purpose"` + StatusDetails string `json:"status_details"` } // FilesList is a list of files that belong to the user or organization. diff --git a/files_api_test.go b/files_api_test.go index f0a08764..1cbc7289 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -64,7 +64,6 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { Purpose: purpose, CreatedAt: time.Now().Unix(), Object: "test-objecct", - Owner: "test-owner", } resBytes, _ = json.Marshal(fileReq) From 8e4b7963a3f378332bd512a5040d75d8504505c8 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 11 Sep 2023 15:44:46 +0200 Subject: [PATCH 04/98] Chore Support base64 embedding format (#485) * chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix --- embeddings.go | 116 +++++++++++++++++++++++++++++++++++---- embeddings_test.go | 131 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 229 insertions(+), 18 deletions(-) diff --git a/embeddings.go b/embeddings.go index 1d319959..5ba91f23 100644 --- a/embeddings.go +++ b/embeddings.go @@ -2,6 +2,9 @@ package openai import ( "context" + "encoding/base64" + "encoding/binary" + "math" "net/http" ) @@ -129,15 +132,83 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } +type base64String string + +func (b base64String) Decode() ([]float32, error) { + decodedData, err := base64.StdEncoding.DecodeString(string(b)) + if err != nil { + return nil, err + } + + const sizeOfFloat32 = 4 + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + + return floats, nil +} + +// Base64Embedding is a container for base64 encoded embeddings. +type Base64Embedding struct { + Object string `json:"object"` + Embedding base64String `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format. +type EmbeddingResponseBase64 struct { + Object string `json:"object"` + Data []Base64Embedding `json:"data"` + Model EmbeddingModel `json:"model"` + Usage Usage `json:"usage"` +} + +// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. +func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) { + data := make([]Embedding, len(r.Data)) + + for i, base64Embedding := range r.Data { + embedding, err := base64Embedding.Embedding.Decode() + if err != nil { + return EmbeddingResponse{}, err + } + + data[i] = Embedding{ + Object: base64Embedding.Object, + Embedding: embedding, + Index: base64Embedding.Index, + } + } + + return EmbeddingResponse{ + Object: r.Object, + Model: r.Model, + Data: data, + Usage: r.Usage, + }, nil +} + type EmbeddingRequestConverter interface { // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens Convert() EmbeddingRequest } +// EmbeddingEncodingFormat is the format of the embeddings data. +// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. +// If not specified OpenAI will use "float". +type EmbeddingEncodingFormat string + +const ( + EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float" + EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64" +) + type EmbeddingRequest struct { - Input any `json:"input"` - Model EmbeddingModel `json:"model"` - User string `json:"user"` + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { // // Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens // for embedding groups of text already converted to tokens. -func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll +func (c *Client) CreateEmbeddings( + ctx context.Context, + conv EmbeddingRequestConverter, +) (res EmbeddingResponse, err error) { baseReq := conv.Convert() req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &res) + if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 { + err = c.sendRequest(req, &res) + return + } + + base64Response := &EmbeddingResponseBase64{} + err = c.sendRequest(req, base64Response) + if err != nil { + return + } + res, err = base64Response.ToEmbeddingResponse() return } diff --git a/embeddings_test.go b/embeddings_test.go index 47c4f510..9c48c5b8 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,15 +1,16 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" "fmt" "net/http" + "reflect" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { @@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() + + sampleEmbeddings := []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + sampleBase64Embeddings := []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + } + server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EmbeddingResponse{}) + var req struct { + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + + var resBytes []byte + switch { + case req.User == "invalid": + w.WriteHeader(http.StatusBadRequest) + return + case req.EncodingFormat == EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + default: + resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (simple embedding request) + res, err = client.CreateEmbeddings( + context.Background(), + EmbeddingRequest{ + EncodingFormat: EmbeddingEncodingFormatBase64, + }, + ) checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with strings - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with tokens - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test failed sendRequest + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + User: "invalid", + EncodingFormat: EmbeddingEncodingFormatBase64, + }) + checks.HasError(t, err, "CreateEmbeddings error") +} + +func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { + type fields struct { + Object string + Data []Base64Embedding + Model EmbeddingModel + Usage Usage + } + tests := []struct { + name string + fields fields + want EmbeddingResponse + wantErr bool + }{ + { + name: "test embedding response base64 to embedding response", + fields: fields{ + Data: []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + }, + }, + want: EmbeddingResponse{ + Data: []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + }, + }, + wantErr: false, + }, + { + name: "Invalid embedding", + fields: fields{ + Data: []Base64Embedding{ + { + Embedding: "----", + }, + }, + }, + want: EmbeddingResponse{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EmbeddingResponseBase64{ + Object: tt.fields.Object, + Data: tt.fields.Data, + Model: tt.fields.Model, + Usage: tt.fields.Usage, + } + got, err := r.ToEmbeddingResponse() + if (err != nil) != tt.wantErr { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want) + } + }) + } } From 0d5256fb820a34a95b8944b9410a1e562087cd8f Mon Sep 17 00:00:00 2001 From: Brendan Martin Date: Mon, 25 Sep 2023 04:08:45 -0400 Subject: [PATCH 05/98] added delete fine tune model endpoint (#497) --- client_test.go | 3 +++ models.go | 20 ++++++++++++++++++++ models_test.go | 15 +++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/client_test.go b/client_test.go index 9b504689..2c1d749e 100644 --- a/client_test.go +++ b/client_test.go @@ -271,6 +271,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetModel", func() (any, error) { return client.GetModel(ctx, "text-davinci-003") }}, + {"DeleteFineTuneModel", func() (any, error) { + return client.DeleteFineTuneModel(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/models.go b/models.go index 560402e3..c207f0a8 100644 --- a/models.go +++ b/models.go @@ -33,6 +33,13 @@ type Permission struct { IsBlocking bool `json:"is_blocking"` } +// FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model. +type FineTuneModelDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` @@ -62,3 +69,16 @@ func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err err = c.sendRequest(req, &model) return } + +// DeleteFineTuneModel Deletes a fine-tune model. You must have the Owner +// role in your organization to delete a model. +func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( + response FineTuneModelDeleteResponse, err error) { + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/models_test.go b/models_test.go index 59b4f5ef..9ff73042 100644 --- a/models_test.go +++ b/models_test.go @@ -14,6 +14,8 @@ import ( "testing" ) +const testFineTuneModelID = "fine-tune-model-id" + // TestListModels Tests the list models endpoint of the API using the mocked server. func TestListModels(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -78,3 +80,16 @@ func TestGetModelReturnTimeoutError(t *testing.T) { t.Fatal("Did not return timeout error") } } + +func TestDeleteFineTuneModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint) + _, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID) + checks.NoError(t, err, "DeleteFineTuneModel error") +} + +func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + fmt.Fprintln(w, string(resBytes)) +} From 84f77a0acda6eb541f3312ed8f7711c89e661443 Mon Sep 17 00:00:00 2001 From: "e. alvarez" <55966724+ealvar3z@users.noreply.github.com> Date: Mon, 2 Oct 2023 07:39:10 -0700 Subject: [PATCH 06/98] Add DotProduct Method and README Example for Embedding Similarity Search (#492) * Add DotProduct Method and README Example for Embedding Similarity Search - Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings. - Add a custom error type for vector length mismatch. - Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries. - Add unit tests to validate the new DotProduct() method and error handling. * Update README to focus on Embedding Semantic Similarity --- README.md | 56 ++++++++++++++++++++++++++++++++++++++++++++++ embeddings.go | 20 +++++++++++++++++ embeddings_test.go | 38 +++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+) diff --git a/README.md b/README.md index 440c4096..c618cd7f 100644 --- a/README.md +++ b/README.md @@ -483,6 +483,62 @@ func main() { ``` + +Embedding Semantic Similarity + +```go +package main + +import ( + "context" + "log" + openai "github.com/sashabaranov/go-openai" + +) + +func main() { + client := openai.NewClient("your-token") + + // Create an EmbeddingRequest for the user query + queryReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the user query + queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq) + if err != nil { + log.Fatal("Error creating query embedding:", err) + } + + // Create an EmbeddingRequest for the target text + targetReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the target text + targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq) + if err != nil { + log.Fatal("Error creating target embedding:", err) + } + + // Now that we have the embeddings for the user query and the target text, we + // can calculate their similarity. + queryEmbedding := queryResponse.Data[0] + targetEmbedding := targetResponse.Data[0] + + similarity, err := queryEmbedding.DotProduct(&targetEmbedding) + if err != nil { + log.Fatal("Error calculating dot product:", err) + } + + log.Printf("The similarity score between the query and the target is %f", similarity) +} + +``` + +
Azure OpenAI Embeddings diff --git a/embeddings.go b/embeddings.go index 5ba91f23..660bc24c 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,10 +4,13 @@ import ( "context" "encoding/base64" "encoding/binary" + "errors" "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. type EmbeddingModel int @@ -124,6 +127,23 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { Object string `json:"object"` diff --git a/embeddings_test.go b/embeddings_test.go index 9c48c5b8..72e8c245 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" "reflect" "testing" @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { }) } } + +func TestDotProduct(t *testing.T) { + v1 := &Embedding{Embedding: []float32{1, 2, 3}} + v2 := &Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } +} From 533935e4fc31f2542ef77d3e545a527c756b641c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 6 Oct 2023 11:32:21 +0200 Subject: [PATCH 07/98] fix: use any for n_epochs (#499) * fix: use custom marshaler for n_epochs * chore: use any for n_epochs --- fine_tuning_job.go | 2 +- fine_tuning_job_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index a840b7ec..07b0c337 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -24,7 +24,7 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs int `json:"n_epochs"` + Epochs any `json:"n_epochs,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 519c6cd2..f6d41c33 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -21,8 +21,23 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/v1/fine_tuning/jobs", func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ := json.Marshal(FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: Hyperparameters{ + Epochs: "auto", + }, + TrainedTokens: 5768, + }) fmt.Fprintln(w, string(resBytes)) }, ) From 8e165dc9aadc9f7045b91dd1b02d6404940dc023 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 9 Oct 2023 17:41:54 +0200 Subject: [PATCH 08/98] Feat Add headers to openai responses (#506) * feat: add headers to http response * chore: add test * fix: rename to httpHeader --- audio.go | 19 ++++++++++++++++++- chat.go | 2 ++ chat_test.go | 30 ++++++++++++++++++++++++++++++ client.go | 20 +++++++++++++++++++- completion.go | 2 ++ edits.go | 2 ++ embeddings.go | 4 ++++ engines.go | 4 ++++ files.go | 4 ++++ fine_tunes.go | 8 ++++++++ fine_tuning_job.go | 4 ++++ image.go | 2 ++ models.go | 6 ++++++ moderation.go | 2 ++ 14 files changed, 107 insertions(+), 2 deletions(-) diff --git a/audio.go b/audio.go index 9f469159..4cbe4fe6 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI( if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8d29b323..df0e5f97 100644 --- a/chat.go +++ b/chat.go @@ -142,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. diff --git a/chat_test.go b/chat_test.go index 38d66fa6..52cd0bde 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,11 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 5779a8e1..19902285 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,20 @@ type Client struct { createFormBuilder func(io.Writer) utils.FormBuilder } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h httpHeader) Header() http.Header { + return http.Header(h) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... return req, nil } -func (c *Client) sendRequest(req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return c.handleErrorResp(res) } + if v != nil { + v.SetHeader(res.Header) + } + return decodeResponse(res.Body, v) } diff --git a/completion.go b/completion.go index 7b9ae89e..c7ff94af 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well diff --git a/edits.go b/edits.go index 831aade2..97d02602 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. diff --git a/embeddings.go b/embeddings.go index 660bc24c..7e2aa7eb 100644 --- a/embeddings.go +++ b/embeddings.go @@ -150,6 +150,8 @@ type EmbeddingResponse struct { Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } type base64String string @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct { Data []Base64Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } // ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. diff --git a/engines.go b/engines.go index adf6025c..5a0dba85 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic diff --git a/files.go b/files.go index 8b933c36..9e521fbb 100644 --- a/files.go +++ b/files.go @@ -25,11 +25,15 @@ type File struct { Status string `json:"status"` Purpose string `json:"purpose"` StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 diff --git a/fine_tunes.go b/fine_tunes.go index 7d3b59db..ca840781 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -41,6 +41,8 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -69,6 +71,8 @@ type FineTuneHyperParams struct { type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -77,6 +81,8 @@ type FineTuneList struct { type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 07b0c337..9dcb49de 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -21,6 +21,8 @@ type FineTuningJob struct { ValidationFile string `json:"validation_file,omitempty"` ResultFiles []string `json:"result_files"` TrainedTokens int `json:"trained_tokens"` + + httpHeader } type Hyperparameters struct { @@ -39,6 +41,8 @@ type FineTuningJobEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` HasMore bool `json:"has_more"` + + httpHeader } type FineTuningJobEvent struct { diff --git a/image.go b/image.go index cb96f4f5..4addcdb1 100644 --- a/image.go +++ b/image.go @@ -33,6 +33,8 @@ type ImageRequest struct { type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + + httpHeader } // ImageResponseDataInner represents a response data structure for image API. diff --git a/models.go b/models.go index c207f0a8..d94f9883 100644 --- a/models.go +++ b/models.go @@ -15,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, diff --git a/moderation.go b/moderation.go index a32f123f..f8d20ee5 100644 --- a/moderation.go +++ b/moderation.go @@ -69,6 +69,8 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string. From b77d01edca43500f267c4b43333f645b84a4fcf0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 10 Oct 2023 10:29:41 -0500 Subject: [PATCH 09/98] Support get http header and x-ratelimit-* headers (#507) * feat: add headers to http response * feat: support rate limit headers * fix: go lint * fix: test coverage * refactor streamReader * refactor streamReader * refactor: NewRateLimitHeaders to newRateLimitHeaders * refactor: RateLimitHeaders Resets filed * refactor: move RateLimitHeaders struct --- chat_stream_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++-- chat_test.go | 53 +++++++++++++++++++++++++++ client.go | 9 ++++- ratelimit.go | 43 ++++++++++++++++++++++ stream_reader.go | 2 + 5 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 ratelimit.go diff --git a/chat_stream_test.go b/chat_stream_test.go index 5fc70b03..2c109d45 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 52cd0bde..329b2b9c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -21,6 +21,17 @@ const ( xCustomHeaderValue = "test" ) +var ( + rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", + } +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) { } } +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + } + + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } resBytes, _ = json.Marshal(res) w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 19902285..65ece812 100644 --- a/client.go +++ b/client.go @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) { *h = httpHeader(header) } -func (h httpHeader) Header() http.Header { - return http.Header(h) +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) } // NewClient creates new OpenAI API client. @@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), }, nil } diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 00000000..e8953f71 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} diff --git a/stream_reader.go b/stream_reader.go index 87e59e0c..d1741259 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -27,6 +27,8 @@ type streamReader[T streamable] struct { response *http.Response errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) { From c47ddfc1a13b850115a80b03f3f9dd1822733bf7 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 10 Oct 2023 21:22:45 +0400 Subject: [PATCH 10/98] Update README.md (#511) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c618cd7f..b41947be 100644 --- a/README.md +++ b/README.md @@ -483,7 +483,7 @@ func main() { ```
- +
Embedding Semantic Similarity ```go @@ -537,7 +537,7 @@ func main() { } ``` - +
Azure OpenAI Embeddings From 6c52952b691ec294b7987689a5292a87a9acdbcb Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Mon, 6 Nov 2023 21:22:48 +0100 Subject: [PATCH 11/98] feat(completion): add constants for new GPT models (#520) Added constants for new GPT models including `gpt-4-1106-preview`, `gpt-4-vision-preview` and `gpt-3.5-turbo-1106`. The models were announced in the following blog post: https://openai.com/blog/new-models-and-developer-products-announced-at-devday --- completion.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/completion.go b/completion.go index c7ff94af..2709c8b0 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,10 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4TurboPreview = "gpt-4-1106-preview" + GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" @@ -69,9 +72,12 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, GPT40314: true, GPT40613: true, GPT432K: true, From 9e0232f941a0f2c1780bf20743effd051a39e4d3 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Mon, 6 Nov 2023 12:27:08 -0800 Subject: [PATCH 12/98] Fix typo in README: AdaEmbeddingV2 (#516) Copy-pasting the old sample caused compilation errors --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b41947be..f0b60908 100644 --- a/README.md +++ b/README.md @@ -502,7 +502,7 @@ func main() { // Create an EmbeddingRequest for the user query queryReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the user query @@ -514,7 +514,7 @@ func main() { // Create an EmbeddingRequest for the target text targetReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the target text From 0664105387f52c99b13bb40fcbf966a8b8c8d838 Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Tue, 7 Nov 2023 10:23:06 +0100 Subject: [PATCH 13/98] lint: fix linter warnings reported by golangci-lint (#522) - Fix #519 --- api_integration_test.go | 1 - audio_api_test.go | 14 ++-- audio_test.go | 2 +- chat_stream_test.go | 110 ++++++++++++++-------------- chat_test.go | 154 ++++++++++++++++++++-------------------- completion_test.go | 42 +++++------ config_test.go | 4 +- edits_test.go | 24 +++---- embeddings_test.go | 110 ++++++++++++++-------------- engines_test.go | 12 ++-- error_test.go | 60 ++++++++-------- example_test.go | 2 - files_api_test.go | 12 ++-- files_test.go | 6 +- fine_tunes.go | 1 + fine_tunes_test.go | 24 +++---- fine_tuning_job_test.go | 35 +++++---- image_api_test.go | 52 +++++++------- jsonschema/json_test.go | 62 ++++++++-------- models_test.go | 17 +++-- moderation_test.go | 52 +++++++------- openai_test.go | 14 ++-- stream_test.go | 46 ++++++------ 23 files changed, 425 insertions(+), 431 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 254fbeb0..6be188bc 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,7 +9,6 @@ import ( "os" "testing" - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) diff --git a/audio_api_test.go b/audio_api_test.go index aad7a225..a0efc792 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -12,7 +12,7 @@ import ( "strings" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -26,7 +26,7 @@ func TestAudio(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -48,7 +48,7 @@ func TestAudio(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", } @@ -57,7 +57,7 @@ func TestAudio(t *testing.T) { }) t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: "fake.webm", Reader: bytes.NewBuffer([]byte(`some webm binary data`)), Model: "whisper-3", @@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", Prompt: "用简体中文", Temperature: 0.5, Language: "zh", - Format: AudioResponseFormatSRT, + Format: openai.AudioResponseFormatSRT, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index e19a873f..5346244c 100644 --- a/audio_test.go +++ b/audio_test.go @@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/chat_stream_test.go b/chat_stream_test.go index 2c109d45..bd571cb4 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -10,28 +10,28 @@ import ( "strconv" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { + if !errors.Is(err, openai.ErrChatCompletionInvalidModel) { t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } @@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []ChatCompletionStreamResponse{ + expectedResponses := []openai.ChatCompletionStreamResponse{ { ID: "1", Object: "completion", Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response1", }, FinishReason: "max_tokens", @@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response2", }, FinishReason: "max_tokens", @@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) { func TestCreateChatCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set(xCustomHeader, xCustomHeaderValue) @@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") for k, v := range rateLimitHeaders { switch val := v.(type) { @@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -276,12 +276,12 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, }) - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) // Send test responses @@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - apiErr := &APIError{} - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + apiErr := &openai.APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } // Helper funcs. -func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { +func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { return true } -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { +func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false } diff --git a/chat_test.go b/chat_test.go index 329b2b9c..5bf1eaf6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -21,49 +21,47 @@ const ( xCustomHeaderValue = "test" ) -var ( - rateLimitHeaders = map[string]any{ - "x-ratelimit-limit-requests": 60, - "x-ratelimit-limit-tokens": 150000, - "x-ratelimit-remaining-requests": 59, - "x-ratelimit-remaining-tokens": 149984, - "x-ratelimit-reset-requests": "1s", - "x-ratelimit-reset-tokens": "6m0s", - } -) +var rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", +} func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletion(ctx, req) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) - checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) + checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") + checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. @@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) { t.Run("bytes", func(t *testing.T) { //nolint:lll msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) { Count: 2, Words: []string{"hello", "world"}, } - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -192,16 +190,16 @@ func TestChatCompletionsFunctions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion with functions error") }) t.Run("JSONSchemaDefinition", func(t *testing.T) { - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) { }) t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []openai.FunctionDefine{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq ChatCompletionRequest + var completionReq openai.ChatCompletionRequest if completionReq, err = getChatCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ChatCompletionResponse{ + res := openai.ChatCompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { return } - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleFunction, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleFunction, // this is valid json so it should be fine - FunctionCall: &FunctionCall{ + FunctionCall: &openai.FunctionCall{ Name: completionReq.Functions[0].Name, Arguments: string(fcb), }, @@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleAssistant, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, Content: completionStr, }, Index: i, @@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens := numTokens(completionReq.Messages[0].Content) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} +func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { + completion := openai.ChatCompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } return completion, nil } func TestFinishReason(t *testing.T) { - c := &ChatCompletionChoice{ - FinishReason: FinishReasonNull, + c := &openai.ChatCompletionChoice{ + FinishReason: openai.FinishReasonNull, } resBytes, _ := json.Marshal(c) if !strings.Contains(string(resBytes), `"finish_reason":null`) { @@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) { t.Error("null should not be quoted") } - otherReasons := []FinishReason{ - FinishReasonStop, - FinishReasonLength, - FinishReasonFunctionCall, - FinishReasonContentFilter, + otherReasons := []openai.FinishReason{ + openai.FinishReasonStop, + openai.FinishReasonLength, + openai.FinishReasonFunctionCall, + openai.FinishReasonContentFilter, } for _, r := range otherReasons { c.FinishReason = r diff --git a/completion_test.go b/completion_test.go index 844ef484..89950bf9 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" @@ -14,33 +11,36 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletion( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + config := openai.DefaultConfig("whatever") + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := CompletionRequest{Stream: true} + req := openai.CompletionRequest{Stream: true} _, err := client.CreateCompletion(ctx, req) - if !errors.Is(err, ErrCompletionStreamNotSupported) { + if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") } } @@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - req := CompletionRequest{ + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", Prompt: "Lorem ipsum", @@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt.(string) + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, openai.CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt.(string)) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } diff --git a/config_test.go b/config_test.go index 488511b1..3e528c3e 100644 --- a/config_test.go +++ b/config_test.go @@ -3,7 +3,7 @@ package openai_test import ( "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestGetAzureDeploymentByModel(t *testing.T) { @@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) { for _, c := range cases { t.Run(c.Model, func(t *testing.T) { - conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") if c.AzureModelMapperFunc != nil { conf.AzureModelMapperFunc = c.AzureModelMapperFunc } diff --git a/edits_test.go b/edits_test.go index c0bb8439..d2a6db40 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -11,6 +8,9 @@ import ( "net/http" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. @@ -20,7 +20,7 @@ func TestEdits(t *testing.T) { server.RegisterHandler("/v1/edits", handleEditEndpoint) // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: time.Now().Unix(), } @@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { } // getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } diff --git a/embeddings_test.go b/embeddings_test.go index 72e8c245..af04d96b 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -11,32 +11,32 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { // test embedding request with strings (simple embedding request) - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with strings - embeddingReqStrings := EmbeddingRequestStrings{ + embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with tokens - embeddingReqTokens := EmbeddingRequestTokens{ + embeddingReqTokens := openai.EmbeddingRequestTokens{ Input: [][]int{ {464, 2057, 373, 12625, 290, 262, 46612}, {6395, 6096, 286, 11525, 12083, 2581}, @@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) { } func TestEmbeddingModel(t *testing.T) { - var em EmbeddingModel + var em openai.EmbeddingModel err := em.UnmarshalText([]byte("text-similarity-ada-001")) checks.NoError(t, err, "Could not marshal embedding model") - if em != AdaSimilarity { + if em != openai.AdaSimilarity { t.Errorf("Model is not equal to AdaSimilarity") } err = em.UnmarshalText([]byte("some-non-existent-model")) checks.NoError(t, err, "Could not marshal embedding model") - if em != Unknown { + if em != openai.Unknown { t.Errorf("Model is not equal to Unknown") } } @@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - sampleEmbeddings := []Embedding{ + sampleEmbeddings := []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, } - sampleBase64Embeddings := []Base64Embedding{ + sampleBase64Embeddings := []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, } @@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) { "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { var req struct { - EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` - User string `json:"user"` + EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` } _ = json.NewDecoder(r.Body).Decode(&req) @@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) { case req.User == "invalid": w.WriteHeader(http.StatusBadRequest) return - case req.EncodingFormat == EmbeddingEncodingFormatBase64: - resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings}) default: - resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) @@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) { // test create embeddings with strings (simple embedding request) res, err = client.CreateEmbeddings( context.Background(), - EmbeddingRequest{ - EncodingFormat: EmbeddingEncodingFormatBase64, + openai.EmbeddingRequest{ + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }, ) checks.NoError(t, err, "CreateEmbeddings error") @@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) { } // test create embeddings with strings - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test create embeddings with tokens - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test failed sendRequest - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ User: "invalid", - EncodingFormat: EmbeddingEncodingFormatBase64, + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }) checks.HasError(t, err, "CreateEmbeddings error") } @@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string - Data []Base64Embedding - Model EmbeddingModel - Usage Usage + Data []openai.Base64Embedding + Model openai.EmbeddingModel + Usage openai.Usage } tests := []struct { name string fields fields - want EmbeddingResponse + want openai.EmbeddingResponse wantErr bool }{ { name: "test embedding response base64 to embedding response", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, }, }, - want: EmbeddingResponse{ - Data: []Embedding{ + want: openai.EmbeddingResponse{ + Data: []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, }, @@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { { name: "Invalid embedding", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ { Embedding: "----", }, }, }, - want: EmbeddingResponse{}, + want: openai.EmbeddingResponse{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &EmbeddingResponseBase64{ + r := &openai.EmbeddingResponseBase64{ Object: tt.fields.Object, Data: tt.fields.Data, Model: tt.fields.Model, @@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { } func TestDotProduct(t *testing.T) { - v1 := &Embedding{Embedding: []float32{1, 2, 3}} - v2 := &Embedding{Embedding: []float32{2, 4, 6}} + v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}} + v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}} expected := float32(28.0) result, err := v1.DotProduct(v2) @@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) { t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) } - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1, 0}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}} expected = float32(0.0) result, err = v1.DotProduct(v2) @@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) { } // Test for VectorLengthMismatchError - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1}} _, err = v1.DotProduct(v2) - if !errors.Is(err, ErrVectorLengthMismatch) { + if !errors.Is(err, openai.ErrVectorLengthMismatch) { t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) } } diff --git a/engines_test.go b/engines_test.go index 31e7ec8b..d26aa554 100644 --- a/engines_test.go +++ b/engines_test.go @@ -7,7 +7,7 @@ import ( "net/http" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -15,8 +15,8 @@ import ( func TestGetEngine(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(Engine{}) + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Engine{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetEngine(context.Background(), "text-davinci-003") @@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) { func TestListEngines(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EnginesList{}) + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EnginesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListEngines(context.Background()) @@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) { func TestListEnginesReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }) diff --git a/error_test.go b/error_test.go index a0806b7e..48cbe4f2 100644 --- a/error_test.go +++ b/error_test.go @@ -6,7 +6,7 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestAPIErrorUnmarshalJSON(t *testing.T) { @@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name string response string hasError bool - checkFunc func(t *testing.T, apiErr APIError) + checkFunc func(t *testing.T, apiErr openai.APIError) } testCases := []testCase{ // testcase for message field @@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is string", response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with single item", response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with multiple items", response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo, bar, baz") }, }, @@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is empty array", response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is null", response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } }`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{ + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{ Code: "ResponsibleAIPolicyViolation", - ContentFilterResults: ContentFilterResults{ - Hate: Hate{ + ContentFilterResults: openai.ContentFilterResults{ + Hate: openai.Hate{ Filtered: false, Severity: "safe", }, - SelfHarm: SelfHarm{ + SelfHarm: openai.SelfHarm{ Filtered: false, Severity: "safe", }, - Sexual: Sexual{ + Sexual: openai.Sexual{ Filtered: true, Severity: "medium", }, - Violence: Violence{ + Violence: openai.Violence{ Filtered: false, Severity: "safe", }, @@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the innerError is empty (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { @@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is int", response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, 418) }, }, @@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is string", response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, "teapot") }, }, @@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is not exists", response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) }, }, @@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse failed when the response is invalid json", response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) assertAPIErrorMessage(t, apiErr, "") assertAPIErrorParam(t, apiErr, nil) @@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var apiErr APIError + var apiErr openai.APIError err := apiErr.UnmarshalJSON([]byte(tc.response)) if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) @@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } } -func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { +func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) { if apiErr.Message != expected { t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) } } -func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) { if !reflect.DeepEqual(apiErr.InnerError, expected) { t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) } } -func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: if v != expected { @@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { } } -func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { +func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) { if apiErr.Param != expected { t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) } } -func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { +func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) { if apiErr.Type != typ { t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) } } func TestRequestError(t *testing.T) { - var err error = &RequestError{ + var err error = &openai.RequestError{ HTTPStatusCode: http.StatusTeapot, Err: errors.New("i am a teapot"), } - var reqErr *RequestError + var reqErr *openai.RequestError if !errors.As(err, &reqErr) { t.Fatalf("Error is not a RequestError: %+v", err) } diff --git a/example_test.go b/example_test.go index b5dfafea..de67c57c 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,6 @@ func Example() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return diff --git a/files_api_test.go b/files_api_test.go index 1cbc7289..330b8815 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", handleCreateFile) - req := FileRequest{ + req := openai.FileRequest{ FileName: "test.go", FilePath: "client.go", Purpose: "fine-tune", @@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { } defer file.Close() - var fileReq = File{ + fileReq := openai.File{ Bytes: int(header.Size), ID: strconv.Itoa(int(time.Now().Unix())), FileName: header.Filename, @@ -82,7 +82,7 @@ func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FilesList{}) + resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListFiles(context.Background()) @@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(File{}) + resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetFile(context.Background(), "deadbeef") @@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) { t.Fatal("Did not return error") } - apiErr := &APIError{} + apiErr := &openai.APIError{} if !errors.As(err, &apiErr) { t.Fatalf("Did not return APIError: %+v\n", apiErr) } diff --git a/files_test.go b/files_test.go index df6eaef7..f588b30d 100644 --- a/files_test.go +++ b/files_test.go @@ -1,14 +1,14 @@ package openai //nolint:testpackage // testing private field import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "fmt" "io" "os" "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestFileUploadWithFailingFormBuilder(t *testing.T) { diff --git a/fine_tunes.go b/fine_tunes.go index ca840781..46f89f16 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 67f681d9..2ab6817f 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,14 +1,14 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" @@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) + resBytes, _ = json.Marshal(openai.FineTuneList{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTune{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) { _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index f6d41c33..c892ef77 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -2,14 +2,13 @@ package openai_test import ( "context" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuninigJobID = "fine-tuning-job-id" @@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) { defer teardown() server.RegisterHandler( "/v1/fine_tuning/jobs", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{ + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{ Object: "fine_tuning.job", ID: testFineTuninigJobID, Model: "davinci-002", @@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) { Status: "succeeded", ValidationFile: "", TrainingFile: "file-abc123", - Hyperparameters: Hyperparameters{ + Hyperparameters: openai.Hyperparameters{ Epochs: "auto", }, TrainedTokens: 5768, @@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID, - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ = json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJobEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) ctx := context.Background() - _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) checks.NoError(t, err, "CreateFineTuningJob error") _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) @@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) { _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") } diff --git a/image_api_test.go b/image_api_test.go index b472eb04..422f831f 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -12,13 +9,16 @@ import ( "os" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestImages(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - _, err := client.CreateImage(context.Background(), ImageRequest{ + _, err := client.CreateImage(context.Background(), openai.ImageRequest{ Prompt: "Lorem ipsum", }) checks.NoError(t, err, "CreateImage error") @@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest + var imageReq openai.ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + res := openai.ImageResponse{ Created: time.Now().Unix(), } for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} + imageData := openai.ImageResponseDataInner{} switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": + case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + case openai.CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: @@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } return image, nil } @@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", @@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index c8d0c1d9..74470608 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -5,28 +5,28 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai/jsonschema" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestDefinition_MarshalJSON(t *testing.T) { tests := []struct { name string - def Definition + def jsonschema.Definition want string }{ { name: "Test with empty Definition", - def: Definition{}, + def: jsonschema.Definition{}, want: `{"properties":{}}`, }, { name: "Test with Definition properties set", - def: Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.String, Description: "A string type", - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with nested Definition properties", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, }, }, @@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with complex nested Definition", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, "address": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "city": { - Type: String, + Type: jsonschema.String, }, "country": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with Array type Definition", - def: Definition{ - Type: Array, - Items: &Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.Array, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, diff --git a/models_test.go b/models_test.go index 9ff73042..4a4c759d 100644 --- a/models_test.go +++ b/models_test.go @@ -1,17 +1,16 @@ package openai_test import ( - "os" - "time" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" + "os" "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneModelID = "fine-tune-model-id" @@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) { // handleListModelsEndpoint Handles the list models endpoint by the test server. func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) + resBytes, _ := json.Marshal(openai.ModelsList{}) fmt.Fprintln(w, string(resBytes)) } @@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) { // handleGetModelsEndpoint Handles the get model endpoint by the test server. func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(Model{}) + resBytes, _ := json.Marshal(openai.Model{}) fmt.Fprintln(w, string(resBytes)) } @@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) { } func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/moderation_test.go b/moderation_test.go index 68f9565e..059f0d1c 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,9 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. @@ -20,8 +20,8 @@ func TestModerations(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - _, err := client.Moderations(context.Background(), ModerationRequest{ - Model: ModerationTextStable, + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: openai.ModerationTextStable, Input: "I want to kill them.", }) checks.NoError(t, err, "Moderation error") @@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { expect error } modelOptions = append(modelOptions, - getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), - getModerationModelTestOption(ModerationTextStable, nil), - getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), + getModerationModelTestOption(openai.ModerationTextStable, nil), + getModerationModelTestOption(openai.ModerationTextLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) for _, modelTest := range modelOptions { - _, err := client.Moderations(context.Background(), ModerationRequest{ + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ Model: modelTest.model, Input: "I want to kill them.", }) @@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var moderationReq ModerationRequest + var moderationReq openai.ModerationRequest if moderationReq, err = getModerationBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} + resCat := openai.ResultCategories{} + resCatScore := openai.ResultCategoryScores{} switch { case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} + resCat = openai.ResultCategories{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} + resCat = openai.ResultCategories{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} + resCat = openai.ResultCategories{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} } - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - res := ModerationResponse{ + res := openai.ModerationResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Model: moderationReq.Model, } @@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { } // getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} +func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { + moderation := openai.ModerationRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } return moderation, nil } diff --git a/openai_test.go b/openai_test.go index 4fc41ecc..729d8880 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,29 +1,29 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" ) -func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultConfig(test.GetTestToken()) + config := openai.DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } -func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } diff --git a/stream_test.go b/stream_test.go index f3f8f85c..35c52ae3 100644 --- a/stream_test.go +++ b/stream_test.go @@ -10,23 +10,23 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletionStream( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []CompletionResponse{ + expectedResponses := []openai.CompletionResponse{ { ID: "1", Object: "completion", Created: 1598069254, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, }, { ID: "2", Object: "completion", Created: 1598069255, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, } @@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3TextDavinci003, + Model: openai.GPT3TextDavinci003, Prompt: "Hello!", Stream: true, }) @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - var apiErr *APIError - _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + var apiErr *openai.APIError + _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Ada, + Model: openai.GPT3Ada, Prompt: "Hello!", Stream: true, }) @@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -220,7 +220,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, _ = stream.Recv() _, streamErr := stream.Recv() - if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) { t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") } } @@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() - _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { } // Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { +func compareResponses(r1, r2 openai.CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool { return true } -func compareResponseChoices(c1, c2 CompletionChoice) bool { +func compareResponseChoices(c1, c2 openai.CompletionChoice) bool { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { return false } From d07833e19bfbb2f26011c8881f7fb61366c07e75 Mon Sep 17 00:00:00 2001 From: Carson Kahn Date: Tue, 7 Nov 2023 04:27:29 -0500 Subject: [PATCH 14/98] Doc ways to improve reproducability besides Temp (#532) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0b60908..4cb77db6 100644 --- a/README.md +++ b/README.md @@ -757,8 +757,9 @@ Even when specifying a temperature field of 0, it doesn't guarantee that you'll Due to the factors mentioned above, different answers may be returned even for the same question. **Workarounds:** -1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. -2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. +1. As of November 2023, use [the new `seed` parameter](https://platform.openai.com/docs/guides/text-generation/reproducible-outputs) in conjunction with the `system_fingerprint` response field, alongside Temperature management. +2. Try using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +3. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. By adopting these strategies, you can expect more consistent results. From 6d9c3a6365643d02692ecc6f0b34a5fa3e7fea45 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 15:25:21 +0100 Subject: [PATCH 15/98] Feat Support chat completion response format and seed new fields (#525) * feat: support chat completion response format * fix linting error * fix * fix linting * Revert "fix linting" This reverts commit 015c6ad62aad561218b693225f58670b5619dba8. * Revert "fix" This reverts commit 7b2ffe28c3e586b629d23479ec1728bf52f0c66f. * Revert "fix linting error" This reverts commit 29960423784e296cb6d22c5db8f8ccf00cac59fd. * chore: add seed new parameter * fix --- chat.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index df0e5f97..88db8cf1 100644 --- a/chat.go +++ b/chat.go @@ -69,18 +69,31 @@ type FunctionCall struct { Arguments string `json:"arguments,omitempty"` } +type ChatCompletionResponseFormatType string + +const ( + ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" +) + +type ChatCompletionResponseFormat struct { + Type ChatCompletionResponseFormatType `json:"type"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From 3063e676bf5932024d76be8e8d9e41df06d4e8cc Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:20:59 +0100 Subject: [PATCH 16/98] Feat Implement assistants API (#535) * chore: implement assistants API * fix * fix * chore: add tests * fix tests * fix linting --- assistant.go | 260 ++++++++++++++++++++++++++++++++++++++++++++++ assistant_test.go | 202 +++++++++++++++++++++++++++++++++++ client_test.go | 27 +++++ 3 files changed, 489 insertions(+) create mode 100644 assistant.go create mode 100644 assistant_test.go diff --git a/assistant.go b/assistant.go new file mode 100644 index 00000000..d75eebef --- /dev/null +++ b/assistant.go @@ -0,0 +1,260 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + assistantsSuffix = "/assistants" + assistantsFilesSuffix = "/files" +) + +type Assistant struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + + httpHeader +} + +type AssistantTool struct { + Type string `json:"type"` +} + +type AssistantToolCodeInterpreter struct { + AssistantTool +} + +type AssistantToolRetrieval struct { + AssistantTool +} + +type AssistantToolFunction struct { + AssistantTool + Function FunctionDefinition `json:"function"` +} + +type AssistantRequest struct { + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// AssistantsList is a list of assistants. +type AssistantsList struct { + Assistants []Assistant `json:"data"` + + httpHeader +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + + httpHeader +} + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type AssistantFilesList struct { + AssistantFiles []AssistantFile `json:"data"` + + httpHeader +} + +// CreateAssistant creates a new assistant. +func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistant retrieves an assistant. +func (c *Client) RetrieveAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyAssistant modifies an assistant. +func (c *Client) ModifyAssistant( + ctx context.Context, + assistantID string, + request AssistantRequest, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistant deletes an assistant. +func (c *Client) DeleteAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListAssistants Lists the currently available assistants. +func (c *Client) ListAssistants( + ctx context.Context, + limit *int, + order *string, + after *string, + before *string, +) (reponse AssistantsList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &reponse) + return +} + +// CreateAssistantFile creates a new assistant file. +func (c *Client) CreateAssistantFile( + ctx context.Context, + assistantID string, + request AssistantFileRequest, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistantFile retrieves an assistant file. +func (c *Client) RetrieveAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistantFile deletes an existing file. +func (c *Client) DeleteAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, nil) + return +} + +// ListAssistantFiles Lists the currently available files for an assistant. +func (c *Client) ListAssistantFiles( + ctx context.Context, + assistantID string, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantFilesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/assistant_test.go b/assistant_test.go new file mode 100644 index 00000000..eb6f4245 --- /dev/null +++ b/assistant_test.go @@ -0,0 +1,202 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assitantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client_test.go b/client_test.go index 2c1d749e..bff2597c 100644 --- a/client_test.go +++ b/client_test.go @@ -274,6 +274,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteFineTuneModel", func() (any, error) { return client.DeleteFineTuneModel(ctx, "") }}, + {"CreateAssistant", func() (any, error) { + return client.CreateAssistant(ctx, AssistantRequest{}) + }}, + {"RetrieveAssistant", func() (any, error) { + return client.RetrieveAssistant(ctx, "") + }}, + {"ModifyAssistant", func() (any, error) { + return client.ModifyAssistant(ctx, "", AssistantRequest{}) + }}, + {"DeleteAssistant", func() (any, error) { + return client.DeleteAssistant(ctx, "") + }}, + {"ListAssistants", func() (any, error) { + return client.ListAssistants(ctx, nil, nil, nil, nil) + }}, + {"CreateAssistantFile", func() (any, error) { + return client.CreateAssistantFile(ctx, "", AssistantFileRequest{}) + }}, + {"ListAssistantFiles", func() (any, error) { + return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveAssistantFile", func() (any, error) { + return client.RetrieveAssistantFile(ctx, "", "") + }}, + {"DeleteAssistantFile", func() (any, error) { + return nil, client.DeleteAssistantFile(ctx, "", "") + }}, } for _, testCase := range testCases { From 1ad6b6f53dcd9abfaf56e8adb02b5b599936580c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:53:24 +0100 Subject: [PATCH 17/98] Feat Support tools and tools choice new fileds (#526) * feat: support tools and tools choice new fileds * fix: use value not pointers --- chat.go | 41 +++++++++++++++++++++++++++++++++++++---- chat_stream.go | 1 + 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/chat.go b/chat.go index 88db8cf1..04303184 100644 --- a/chat.go +++ b/chat.go @@ -12,6 +12,7 @@ const ( ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" ChatMessageRoleFunction = "function" + ChatMessageRoleTool = "tool" ) const chatCompletionsSuffix = "/chat/completions" @@ -61,6 +62,12 @@ type ChatCompletionMessage struct { Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Function FunctionCall `json:"function"` } type FunctionCall struct { @@ -97,10 +104,35 @@ type ChatCompletionRequest struct { // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` - Functions []FunctionDefinition `json:"functions,omitempty"` - FunctionCall any `json:"function_call,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` + // Deprecated: use Tools instead. + Functions []FunctionDefinition `json:"functions,omitempty"` + // Deprecated: use ToolChoice instead. + FunctionCall any `json:"function_call,omitempty"` + Tools []Tool `json:"tools,omitempty"` + // This can be either a string or an ToolChoice object. + ToolChoiche any `json:"tool_choice,omitempty"` +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type Tool struct { + Type ToolType `json:"type"` + Function FunctionDefinition `json:"function,omitempty"` +} + +type ToolChoiche struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +type ToolFunction struct { + Name string `json:"name"` } type FunctionDefinition struct { @@ -123,6 +155,7 @@ const ( FinishReasonStop FinishReason = "stop" FinishReasonLength FinishReason = "length" FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonToolCalls FinishReason = "tool_calls" FinishReasonContentFilter FinishReason = "content_filter" FinishReasonNull FinishReason = "null" ) diff --git a/chat_stream.go b/chat_stream.go index f1faa396..57cfa789 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -9,6 +9,7 @@ type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type ChatCompletionStreamChoice struct { From a20eb08b79e5c34882888a401020b47c145357ff Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 22:30:05 +0100 Subject: [PATCH 18/98] fix: use pointer for ChatCompletionResponseFormat (#544) --- chat.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index 04303184..609e0c31 100644 --- a/chat.go +++ b/chat.go @@ -89,18 +89,18 @@ type ChatCompletionResponseFormat struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From a0159ad2b00e4f127222814694bec68863395543 Mon Sep 17 00:00:00 2001 From: Mike Cutalo Date: Tue, 7 Nov 2023 23:16:22 -0800 Subject: [PATCH 19/98] Support new fields for /v1/images/generation API (#530) * add support for new image/generation api * fix one lint * add revised_prompt to response * fix lints * add CreateImageQualityStandard --- image.go | 26 ++++++++++++++++++++++++-- image_api_test.go | 9 ++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/image.go b/image.go index 4addcdb1..4fe8b3a3 100644 --- a/image.go +++ b/image.go @@ -13,6 +13,9 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. + CreateImageSize1792x1024 = "1792x1024" + CreateImageSize1024x1792 = "1024x1792" ) const ( @@ -20,11 +23,29 @@ const ( CreateImageResponseFormatB64JSON = "b64_json" ) +const ( + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" +) + +const ( + CreateImageQualityHD = "hd" + CreateImageQualityStandard = "standard" +) + +const ( + CreateImageStyleVivid = "vivid" + CreateImageStyleNatural = "natural" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` ResponseFormat string `json:"response_format,omitempty"` User string `json:"user,omitempty"` } @@ -39,8 +60,9 @@ type ImageResponse struct { // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` } // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. diff --git a/image_api_test.go b/image_api_test.go index 422f831f..2eb46f2b 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -19,7 +19,14 @@ func TestImages(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) _, err := client.CreateImage(context.Background(), openai.ImageRequest{ - Prompt: "Lorem ipsum", + Prompt: "Lorem ipsum", + Model: openai.CreateImageModelDallE3, + N: 1, + Quality: openai.CreateImageQualityHD, + Size: openai.CreateImageSize1024x1024, + Style: openai.CreateImageStyleVivid, + ResponseFormat: openai.CreateImageResponseFormatURL, + User: "user", }) checks.NoError(t, err, "CreateImage error") } From a2d2bf685122fd51d768f2a828787cae587d9ad6 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 8 Nov 2023 10:20:20 +0100 Subject: [PATCH 20/98] Fix Refactor assistant api (#545) * fix: refactor assistant API * fix * trigger build * fix: use AssistantDeleteResponse --- assistant.go | 90 ++++++++++++++++++++++++++++++---------------------- client.go | 6 ++++ 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/assistant.go b/assistant.go index d75eebef..de49be68 100644 --- a/assistant.go +++ b/assistant.go @@ -10,46 +10,43 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" + openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` httpHeader } -type AssistantTool struct { - Type string `json:"type"` -} - -type AssistantToolCodeInterpreter struct { - AssistantTool -} +type AssistantToolType string -type AssistantToolRetrieval struct { - AssistantTool -} +const ( + AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" + AssistantToolTypeRetrieval AssistantToolType = "retrieval" + AssistantToolTypeFunction AssistantToolType = "function" +) -type AssistantToolFunction struct { - AssistantTool - Function FunctionDefinition `json:"function"` +type AssistantTool struct { + Type AssistantToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // AssistantsList is a list of assistants. @@ -59,6 +56,14 @@ type AssistantsList struct { httpHeader } +type AssistantDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + type AssistantFile struct { ID string `json:"id"` Object string `json:"object"` @@ -80,7 +85,8 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -95,7 +101,8 @@ func (c *Client) RetrieveAssistant( assistantID string, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -111,7 +118,8 @@ func (c *Client) ModifyAssistant( request AssistantRequest, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -124,9 +132,10 @@ func (c *Client) ModifyAssistant( func (c *Client) DeleteAssistant( ctx context.Context, assistantID string, -) (response Assistant, err error) { +) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -163,7 +172,8 @@ func (c *Client) ListAssistants( } urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -180,7 +190,8 @@ func (c *Client) CreateAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(request)) + withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -196,7 +207,8 @@ func (c *Client) RetrieveAssistantFile( fileID string, ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -212,7 +224,8 @@ func (c *Client) DeleteAssistantFile( fileID string, ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -250,7 +263,8 @@ func (c *Client) ListAssistantFiles( } urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } diff --git a/client.go b/client.go index 65ece812..056226c6 100644 --- a/client.go +++ b/client.go @@ -83,6 +83,12 @@ func withContentType(contentType string) requestOption { } } +func withBetaAssistantV1() requestOption { + return func(args *requestOptions) { + args.header.Set("OpenAI-Beta", "assistants=v1") + } +} + func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { // Default Options args := &requestOptions{ From 08c167fecf6953619d1905ab2959ed341bfb063d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 8 Nov 2023 18:21:51 +0900 Subject: [PATCH 21/98] test: fix compile error in api integration test (#548) --- api_integration_test.go | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 6be188bc..736040c5 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,6 +9,7 @@ import ( "os" "testing" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -20,7 +21,7 @@ func TestAPI(t *testing.T) { } var err error - c := NewClient(apiToken) + c := openai.NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") @@ -36,23 +37,23 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "GetFile error") } // else skip - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: AdaSearchQuery, + Model: openai.AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -63,11 +64,11 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Name: "John_Doe", Content: "Hello!", }, @@ -76,9 +77,9 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", - Model: GPT3Ada, + Model: openai.GPT3Ada, MaxTokens: 5, Stream: true, }) @@ -103,15 +104,15 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( context.Background(), - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "What is the weather like in Boston?", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "get_current_weather", Parameters: jsonschema.Definition{ Type: jsonschema.Object, @@ -140,12 +141,12 @@ func TestAPIError(t *testing.T) { } var err error - c := NewClient(apiToken + "_invalid") + c := openai.NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) checks.HasError(t, err, "ListEngines should fail with an invalid key") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Fatalf("Error is not an APIError: %+v", err) } From bc89139c1ddcc4f6d5b15b7e8d0491c69dda402c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 09:05:44 +0100 Subject: [PATCH 22/98] Feat Implement threads API (#536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement threads API * fix * add tests * fix * trigger£ * trigger * chore: add beta header --- client_test.go | 12 ++++++ thread.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++ thread_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 thread.go create mode 100644 thread_test.go diff --git a/client_test.go b/client_test.go index bff2597c..b2f28f90 100644 --- a/client_test.go +++ b/client_test.go @@ -301,6 +301,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteAssistantFile", func() (any, error) { return nil, client.DeleteAssistantFile(ctx, "", "") }}, + {"CreateThread", func() (any, error) { + return client.CreateThread(ctx, ThreadRequest{}) + }}, + {"RetrieveThread", func() (any, error) { + return client.RetrieveThread(ctx, "") + }}, + {"ModifyThread", func() (any, error) { + return client.ModifyThread(ctx, "", ModifyThreadRequest{}) + }}, + {"DeleteThread", func() (any, error) { + return client.DeleteThread(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/thread.go b/thread.go new file mode 100644 index 00000000..291f3dca --- /dev/null +++ b/thread.go @@ -0,0 +1,107 @@ +package openai + +import ( + "context" + "net/http" +) + +const ( + threadsSuffix = "/threads" +) + +type Thread struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type ThreadRequest struct { + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ModifyThreadRequest struct { + Metadata map[string]any `json:"metadata"` +} + +type ThreadMessageRole string + +const ( + ThreadMessageRoleUser ThreadMessageRole = "user" +) + +type ThreadMessage struct { + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +// CreateThread creates a new thread. +func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveThread retrieves a thread. +func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyThread modifies a thread. +func (c *Client) ModifyThread( + ctx context.Context, + threadID string, + request ModifyThreadRequest, +) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteThread deletes a thread. +func (c *Client) DeleteThread( + ctx context.Context, + threadID string, +) (response ThreadDeleteResponse, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/thread_test.go b/thread_test.go new file mode 100644 index 00000000..227ab633 --- /dev/null +++ b/thread_test.go @@ -0,0 +1,95 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestThread Tests the thread endpoint of the API using the mocked server. +func TestThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From e3e065deb0a190e2d3c3bbf9caf54471b32f675e Mon Sep 17 00:00:00 2001 From: Gabriel Burt Date: Thu, 9 Nov 2023 03:08:43 -0500 Subject: [PATCH 23/98] Add SystemFingerprint and chatMsg.ToolCallID field (#543) * fix ToolChoiche typo * add tool_call_id to ChatCompletionMessage * add /chat system_fingerprint response field * check empty ToolCallID JSON marshaling and add omitempty for tool_call_id * messages also required; don't omitempty * add Type to ToolCall, required by the API * fix test, omitempty for response_format ptr * fix casing of role values in comments --- chat.go | 27 +++++++++++++++++---------- chat_test.go | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index 609e0c31..9ad31c46 100644 --- a/chat.go +++ b/chat.go @@ -62,11 +62,17 @@ type ChatCompletionMessage struct { Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + ToolCallID string `json:"tool_call_id,omitempty"` } type ToolCall struct { ID string `json:"id"` + Type ToolType `json:"type"` Function FunctionCall `json:"function"` } @@ -84,7 +90,7 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -112,7 +118,7 @@ type ChatCompletionRequest struct { FunctionCall any `json:"function_call,omitempty"` Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. - ToolChoiche any `json:"tool_choice,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ToolType string @@ -126,7 +132,7 @@ type Tool struct { Function FunctionDefinition `json:"function,omitempty"` } -type ToolChoiche struct { +type ToolChoice struct { Type ToolType `json:"type"` Function ToolFunction `json:"function,omitempty"` } @@ -182,12 +188,13 @@ type ChatCompletionChoice struct { // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` httpHeader } diff --git a/chat_test.go b/chat_test.go index 5bf1eaf6..a8155edf 100644 --- a/chat_test.go +++ b/chat_test.go @@ -51,6 +51,20 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestChatRequestOmitEmpty(t *testing.T) { + data, err := json.Marshal(openai.ChatCompletionRequest{ + // We set model b/c it's required, so omitempty doesn't make sense + Model: "gpt-4", + }) + checks.NoError(t, err) + + // messages is also required so isn't omitted + const expected = `{"model":"gpt-4","messages":null}` + if string(data) != expected { + t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data)) + } +} + func TestChatCompletionsWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" From 81270725539980d202829528054f3fda346970db Mon Sep 17 00:00:00 2001 From: Urjit Singh Bhatia Date: Thu, 9 Nov 2023 00:20:39 -0800 Subject: [PATCH 24/98] fix test server setup: (#549) * fix test server setup: - go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration go through /foo/bar first since the regex match wasn't bound to start and end anchors - registering handlers now converts * in routes to .* for proper regex matching - test server route handling now tries to fully match the handler route * add missing /v1 prefix to fine-tuning job cancel test server handler --- fine_tuning_job_test.go | 2 +- internal/test/server.go | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index c892ef77..d2fbcd4c 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -42,7 +42,7 @@ func TestFineTuningJob(t *testing.T) { ) server.RegisterHandler( - "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) diff --git a/internal/test/server.go b/internal/test/server.go index 3813ff86..127d4c16 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "strings" ) const testAPI = "this-is-my-secure-token-do-not-steal!!" @@ -23,13 +24,16 @@ func NewTestServer() *ServerTest { } func (ts *ServerTest) RegisterHandler(path string, handler handler) { + // to make the registered paths friendlier to a regex match in the route handler + // in OpenAITestServer + path = strings.ReplaceAll(path, "*", ".*") ts.handlers[path] = handler } // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. func (ts *ServerTest) OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("received request at path %q\n", r.URL.Path) + log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) // check auth if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { @@ -38,8 +42,10 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { } // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling for route, handler := range ts.handlers { - pattern, _ := regexp.Compile(route) + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") if pattern.MatchString(r.URL.Path) { handler(w, r) return From 78862a2798df46f6ca8bb73350b720f9c8d4a592 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 15:05:03 +0100 Subject: [PATCH 25/98] fix: add missing fields in tool_calls (#558) --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index 9ad31c46..ebdc0e24 100644 --- a/chat.go +++ b/chat.go @@ -71,6 +71,8 @@ type ChatCompletionMessage struct { } type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` ID string `json:"id"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` From d6f3bdcdac9172ab5248d6be8c3e1761446a434c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 20:17:30 +0100 Subject: [PATCH 26/98] Feat implement Run APIs (#560) * chore: first commit * add apis * chore: add tests * feat add apis * chore: add api and tests * chore: add tests * fix * trigger build * fix * chore: formatting code * chore: add pagination type --- client_test.go | 27 ++++ run.go | 399 +++++++++++++++++++++++++++++++++++++++++++++++++ run_test.go | 237 +++++++++++++++++++++++++++++ 3 files changed, 663 insertions(+) create mode 100644 run.go create mode 100644 run_test.go diff --git a/client_test.go b/client_test.go index b2f28f90..d5d3e264 100644 --- a/client_test.go +++ b/client_test.go @@ -313,6 +313,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteThread", func() (any, error) { return client.DeleteThread(ctx, "") }}, + {"CreateRun", func() (any, error) { + return client.CreateRun(ctx, "", RunRequest{}) + }}, + {"RetrieveRun", func() (any, error) { + return client.RetrieveRun(ctx, "", "") + }}, + {"ModifyRun", func() (any, error) { + return client.ModifyRun(ctx, "", "", RunModifyRequest{}) + }}, + {"ListRuns", func() (any, error) { + return client.ListRuns(ctx, "", Pagination{}) + }}, + {"SubmitToolOutputs", func() (any, error) { + return client.SubmitToolOutputs(ctx, "", "", SubmitToolOutputsRequest{}) + }}, + {"CancelRun", func() (any, error) { + return client.CancelRun(ctx, "", "") + }}, + {"CreateThreadAndRun", func() (any, error) { + return client.CreateThreadAndRun(ctx, CreateThreadAndRunRequest{}) + }}, + {"RetrieveRunStep", func() (any, error) { + return client.RetrieveRunStep(ctx, "", "", "") + }}, + {"ListRunSteps", func() (any, error) { + return client.ListRunSteps(ctx, "", "", Pagination{}) + }}, } for _, testCase := range testCases { diff --git a/run.go b/run.go new file mode 100644 index 00000000..5d6ea58d --- /dev/null +++ b/run.go @@ -0,0 +1,399 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type Run struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + ThreadID string `json:"thread_id"` + AssistantID string `json:"assistant_id"` + Status RunStatus `json:"status"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools"` + FileIDS []string `json:"file_ids"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStatus string + +const ( + RunStatusQueued RunStatus = "queued" + RunStatusInProgress RunStatus = "in_progress" + RunStatusRequiresAction RunStatus = "requires_action" + RunStatusCancelling RunStatus = "cancelling" + RunStatusFailed RunStatus = "failed" + RunStatusCompleted RunStatus = "completed" + RunStatusExpired RunStatus = "expired" +) + +type RunRequiredAction struct { + Type RequiredActionType `json:"type"` + SubmitToolOutputs *SubmitToolOutputs `json:"submit_tool_outputs,omitempty"` +} + +type RequiredActionType string + +const ( + RequiredActionTypeSubmitToolOutputs RequiredActionType = "submit_tool_outputs" +) + +type SubmitToolOutputs struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +type RunLastError struct { + Code RunError `json:"code"` + Message string `json:"message"` +} + +type RunError string + +const ( + RunErrorServerError RunError = "server_error" + RunErrorRateLimitExceeded RunError = "rate_limit_exceeded" +) + +type RunRequest struct { + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any +} + +type RunModifyRequest struct { + Metadata map[string]any `json:"metadata,omitempty"` +} + +// RunList is a list of runs. +type RunList struct { + Runs []Run `json:"data"` + + httpHeader +} + +type SubmitToolOutputsRequest struct { + ToolOutputs []ToolOutput `json:"tool_outputs"` +} + +type ToolOutput struct { + ToolCallID string `json:"tool_call_id"` + Output any `json:"output"` +} + +type CreateThreadAndRunRequest struct { + RunRequest + Thread ThreadRequest `json:"thread"` +} + +type RunStep struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + ThreadID string `json:"thread_id"` + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + Status RunStepStatus `json:"status"` + StepDetails StepDetails `json:"step_details"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiredAt *int64 `json:"expired_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStepStatus string + +const ( + RunStepStatusInProgress RunStepStatus = "in_progress" + RunStepStatusCancelling RunStepStatus = "cancelled" + RunStepStatusFailed RunStepStatus = "failed" + RunStepStatusCompleted RunStepStatus = "completed" + RunStepStatusExpired RunStepStatus = "expired" +) + +type RunStepType string + +const ( + RunStepTypeMessageCreation RunStepType = "message_creation" + RunStepTypeToolCalls RunStepType = "tool_calls" +) + +type StepDetails struct { + Type RunStepType `json:"type"` + MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` + ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` +} + +type StepDetailsMessageCreation struct { + MessageID string `json:"message_id"` +} + +type StepDetailsToolCalls struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +// RunStepList is a list of steps. +type RunStepList struct { + RunSteps []RunStep `json:"data"` + + httpHeader +} + +type Pagination struct { + Limit *int + Order *string + After *string + Before *string +} + +// CreateRun creates a new run. +func (c *Client) CreateRun( + ctx context.Context, + threadID string, + request RunRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRun retrieves a run. +func (c *Client) RetrieveRun( + ctx context.Context, + threadID string, + runID string, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyRun modifies a run. +func (c *Client) ModifyRun( + ctx context.Context, + threadID string, + runID string, + request RunModifyRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRuns lists runs. +func (c *Client) ListRuns( + ctx context.Context, + threadID string, + pagination Pagination, +) (response RunList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs%s", threadID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// SubmitToolOutputs submits tool outputs. +func (c *Client) SubmitToolOutputs( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelRun cancels a run. +func (c *Client) CancelRun( + ctx context.Context, + threadID string, + runID string) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/cancel", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CreateThreadAndRun submits tool outputs. +func (c *Client) CreateThreadAndRun( + ctx context.Context, + request CreateThreadAndRunRequest) (response Run, err error) { + urlSuffix := "/threads/runs" + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRunStep retrieves a run step. +func (c *Client) RetrieveRunStep( + ctx context.Context, + threadID string, + runID string, + stepID string, +) (response RunStep, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps/%s", threadID, runID, stepID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRunSteps lists run steps. +func (c *Client) ListRunSteps( + ctx context.Context, + threadID string, + runID string, + pagination Pagination, +) (response RunStepList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps%s", threadID, runID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/run_test.go b/run_test.go new file mode 100644 index 00000000..cdf99db0 --- /dev/null +++ b/run_test.go @@ -0,0 +1,237 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestRun(t *testing.T) { + assistantID := "asst_abc123" + threadID := "thread_abc123" + runID := "run_abc123" + stepID := "step_abc123" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps/"+stepID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStep{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStepList{ + RunSteps: []openai.RunStep{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/submit_tool_outputs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.RunModifyRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.RunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunList{ + Runs: []openai.Run{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.CreateThreadAndRunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateRun(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRun error") + + _, err = client.RetrieveRun(ctx, threadID, runID) + checks.NoError(t, err, "RetrieveRun error") + + _, err = client.ModifyRun(ctx, threadID, runID, openai.RunModifyRequest{ + Metadata: map[string]any{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyRun error") + + _, err = client.ListRuns( + ctx, + threadID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRuns error") + + _, err = client.SubmitToolOutputs(ctx, threadID, runID, + openai.SubmitToolOutputsRequest{}) + checks.NoError(t, err, "SubmitToolOutputs error") + + _, err = client.CancelRun(ctx, threadID, runID) + checks.NoError(t, err, "CancelRun error") + + _, err = client.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndRun error") + + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) + checks.NoError(t, err, "RetrieveRunStep error") + + _, err = client.ListRunSteps( + ctx, + threadID, + runID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRunSteps error") +} From 35495ccd364265f37800a6fa72fed7f05705eb82 Mon Sep 17 00:00:00 2001 From: Kyle Bolton Date: Sun, 12 Nov 2023 06:09:40 -0500 Subject: [PATCH 27/98] Add `json:"metadata,omitempty"` to RunRequest struct (#561) Metadata is an optional field per the api spec https://platform.openai.com/docs/api-reference/runs/createRun --- run.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index 5d6ea58d..7ff730fe 100644 --- a/run.go +++ b/run.go @@ -70,11 +70,11 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 9fefd50e12ad138efa3f38756be5dd2ed5fefadd Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 12 Nov 2023 20:10:00 +0900 Subject: [PATCH 28/98] Fix typo in chat_test.go (#564) requetsts -> requests --- chat_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat_test.go b/chat_test.go index a8155edf..8377809d 100644 --- a/chat_test.go +++ b/chat_test.go @@ -144,7 +144,7 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { } resetRequestsTime := headers.ResetRequests.Time() if resetRequestsTime.Before(time.Now()) { - t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + t.Errorf("unexpected reset requests: %v", resetRequestsTime) } bs1, _ := json.Marshal(headers) From b7cac703acb1a8be0e803c81ad3236be66be969a Mon Sep 17 00:00:00 2001 From: Urjit Singh Bhatia Date: Mon, 13 Nov 2023 08:33:26 -0600 Subject: [PATCH 29/98] Feat/messages api (#546) * fix test server setup: - go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration go through /foo/bar first since the regex match wasn't bound to start and end anchors - registering handlers now converts * in routes to .* for proper regex matching - test server route handling now tries to fully match the handler route * add missing /v1 prefix to fine-tuning job cancel test server handler * add create message call * add messages list call * add get message call * add modify message call, fix return types for other message calls * add message file retrieve call * add list message files call * code style fixes * add test for list messages with pagination options * add beta header to msg calls now that #545 is merged * Update messages.go Co-authored-by: Simone Vellei * Update messages.go Co-authored-by: Simone Vellei * add missing object details for message, fix tests * fix merge formatting * minor style fixes --------- Co-authored-by: Simone Vellei --- client_test.go | 18 ++++ messages.go | 178 +++++++++++++++++++++++++++++++++++ messages_test.go | 235 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 431 insertions(+) create mode 100644 messages.go create mode 100644 messages_test.go diff --git a/client_test.go b/client_test.go index d5d3e264..24cb5ffa 100644 --- a/client_test.go +++ b/client_test.go @@ -301,6 +301,24 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteAssistantFile", func() (any, error) { return nil, client.DeleteAssistantFile(ctx, "", "") }}, + {"CreateMessage", func() (any, error) { + return client.CreateMessage(ctx, "", MessageRequest{}) + }}, + {"ListMessage", func() (any, error) { + return client.ListMessage(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveMessage", func() (any, error) { + return client.RetrieveMessage(ctx, "", "") + }}, + {"ModifyMessage", func() (any, error) { + return client.ModifyMessage(ctx, "", "", nil) + }}, + {"RetrieveMessageFile", func() (any, error) { + return client.RetrieveMessageFile(ctx, "", "", "") + }}, + {"ListMessageFiles", func() (any, error) { + return client.ListMessageFiles(ctx, "", "") + }}, {"CreateThread", func() (any, error) { return client.CreateThread(ctx, ThreadRequest{}) }}, diff --git a/messages.go b/messages.go new file mode 100644 index 00000000..4e691a8b --- /dev/null +++ b/messages.go @@ -0,0 +1,178 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + messagesSuffix = "messages" +) + +type Message struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + ThreadID string `json:"thread_id"` + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIds []string `json:"file_ids"` + AssistantID *string `json:"assistant_id,omitempty"` + RunID *string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type MessagesList struct { + Messages []Message `json:"data"` + + httpHeader +} + +type MessageContent struct { + Type string `json:"type"` + Text *MessageText `json:"text,omitempty"` + ImageFile *ImageFile `json:"image_file,omitempty"` +} +type MessageText struct { + Value string `json:"value"` + Annotations []any `json:"annotations"` +} + +type ImageFile struct { + FileID string `json:"file_id"` +} + +type MessageRequest struct { + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type MessageFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + MessageID string `json:"message_id"` + + httpHeader +} + +type MessageFilesList struct { + MessageFiles []MessageFile `json:"data"` + + httpHeader +} + +// CreateMessage creates a new message. +func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ListMessage fetches all messages in the thread. +func (c *Client) ListMessage(ctx context.Context, threadID string, + limit *int, + order *string, + after *string, + before *string, +) (messages MessagesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &messages) + return +} + +// RetrieveMessage retrieves a Message. +func (c *Client) RetrieveMessage( + ctx context.Context, + threadID, messageID string, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ModifyMessage modifies a message. +func (c *Client) ModifyMessage( + ctx context.Context, + threadID, messageID string, + metadata map[string]any, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(metadata), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// RetrieveMessageFile fetches a message file. +func (c *Client) RetrieveMessageFile( + ctx context.Context, + threadID, messageID, fileID string, +) (file MessageFile, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + +// ListMessageFiles fetches all files attached to a message. +func (c *Client) ListMessageFiles( + ctx context.Context, + threadID, messageID string, +) (files MessageFilesList, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &files) + return +} diff --git a/messages_test.go b/messages_test.go new file mode 100644 index 00000000..282b1cc9 --- /dev/null +++ b/messages_test.go @@ -0,0 +1,235 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +var emptyStr = "" + +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFile{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 1699061776, + MessageID: messageID, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFilesList{MessageFiles: []openai.MessageFile{{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 0, + MessageID: messageID, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + metadata := map[string]any{} + err := json.NewDecoder(r.Body).Decode(&metadata) + checks.NoError(t, err, "unable to decode metadata in modify message call") + + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: metadata, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + resBytes, _ := json.Marshal(openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal(openai.MessagesList{ + Messages: []openai.Message{{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + ctx := context.Background() + + // static assertion of return type + var msg openai.Message + msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{ + Role: "user", + Content: "How does AI work?", + FileIds: nil, + Metadata: nil, + }) + checks.NoError(t, err, "CreateMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + var msgs openai.MessagesList + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + // with pagination options set + limit := 1 + order := "desc" + after := "obj_foo" + before := "obj_bar" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + msg, err = client.RetrieveMessage(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + msg, err = client.ModifyMessage(ctx, threadID, messageID, + map[string]any{ + "foo": "bar", + }) + checks.NoError(t, err, "ModifyMessage error") + if msg.Metadata["foo"] != "bar" { + t.Fatalf("expected message metadata to get modified") + } + + // message files + var msgFile openai.MessageFile + msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) + checks.NoError(t, err, "RetrieveMessageFile error") + if msgFile.ID != fileID { + t.Fatalf("unexpected message file id: '%s'", msgFile.ID) + } + + var msgFiles openai.MessageFilesList + msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessageFile error") + if len(msgFiles.MessageFiles) != 1 { + t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles)) + } + if msgFiles.MessageFiles[0].ID != fileID { + t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID) + } +} From 515de0219d3b4d30351d44d8a0f508599de6c053 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Mon, 13 Nov 2023 09:35:34 -0500 Subject: [PATCH 30/98] feat: initial TTS support (#528) * feat: initial TTS support * chore: lint, omitempty * chore: dont use pointer in struct * fix: add mocked server tests to speech_test.go Co-authored-by: Lachlan Laycock * chore: update imports * chore: fix lint * chore: add an error check * chore: ignore lint * chore: add error checks in package * chore: add test * chore: fix test --------- Co-authored-by: Lachlan Laycock --- client_test.go | 3 ++ speech.go | 87 +++++++++++++++++++++++++++++++++++++ speech_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 speech.go create mode 100644 speech_test.go diff --git a/client_test.go b/client_test.go index 24cb5ffa..1c908458 100644 --- a/client_test.go +++ b/client_test.go @@ -358,6 +358,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListRunSteps", func() (any, error) { return client.ListRunSteps(ctx, "", "", Pagination{}) }}, + {"CreateSpeech", func() (any, error) { + return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) + }}, } for _, testCase := range testCases { diff --git a/speech.go b/speech.go new file mode 100644 index 00000000..a3d5f5dc --- /dev/null +++ b/speech.go @@ -0,0 +1,87 @@ +package openai + +import ( + "context" + "errors" + "io" + "net/http" +) + +type SpeechModel string + +const ( + TTSModel1 SpeechModel = "tts-1" + TTsModel1HD SpeechModel = "tts-1-hd" +) + +type SpeechVoice string + +const ( + VoiceAlloy SpeechVoice = "alloy" + VoiceEcho SpeechVoice = "echo" + VoiceFable SpeechVoice = "fable" + VoiceOnyx SpeechVoice = "onyx" + VoiceNova SpeechVoice = "nova" + VoiceShimmer SpeechVoice = "shimmer" +) + +type SpeechResponseFormat string + +const ( + SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" + SpeechResponseFormatOpus SpeechResponseFormat = "opus" + SpeechResponseFormatAac SpeechResponseFormat = "aac" + SpeechResponseFormatFlac SpeechResponseFormat = "flac" +) + +var ( + ErrInvalidSpeechModel = errors.New("invalid speech model") + ErrInvalidVoice = errors.New("invalid voice") +) + +type CreateSpeechRequest struct { + Model SpeechModel `json:"model"` + Input string `json:"input"` + Voice SpeechVoice `json:"voice"` + ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 + Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 +} + +func contains[T comparable](s []T, e T) bool { + for _, v := range s { + if v == e { + return true + } + } + return false +} + +func isValidSpeechModel(model SpeechModel) bool { + return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) +} + +func isValidVoice(voice SpeechVoice) bool { + return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) +} + +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { + if !isValidSpeechModel(request.Model) { + err = ErrInvalidSpeechModel + return + } + if !isValidVoice(request.Voice) { + err = ErrInvalidVoice + return + } + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + withBody(request), + withContentType("application/json; charset=utf-8"), + ) + if err != nil { + return + } + + response, err = c.sendRequestRaw(req) + + return +} diff --git a/speech_test.go b/speech_test.go new file mode 100644 index 00000000..d9ba58b1 --- /dev/null +++ b/speech_test.go @@ -0,0 +1,115 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestSpeechIntegration(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + dir, cleanup := test.CreateTestDirectory(t) + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + defer cleanup() + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if mediaType != "application/json" { + http.Error(w, "request is not json", http.StatusBadRequest) + return + } + + // Parse the JSON body of the request + var params map[string]interface{} + err = json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + http.Error(w, "failed to parse request body", http.StatusBadRequest) + return + } + + // Check if each required field is present in the parsed JSON object + reqParams := []string{"model", "input", "voice"} + for _, param := range reqParams { + _, ok := params[param] + if !ok { + http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) + return + } + } + + // read audio file content + audioFile, err := os.ReadFile(path) + if err != nil { + http.Error(w, "failed to read audio file", http.StatusInternalServerError) + return + } + + // write audio file content to response + w.Header().Set("Content-Type", "audio/mpeg") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Connection", "keep-alive") + _, err = w.Write(audioFile) + if err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } + }) + + t.Run("happy path", func(t *testing.T) { + res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.NoError(t, err, "CreateSpeech error") + defer res.Close() + + buf, err := io.ReadAll(res) + checks.NoError(t, err, "ReadAll error") + + // save buf to file as mp3 + err = os.WriteFile("test.mp3", buf, 0644) + checks.NoError(t, err, "Create error") + }) + t.Run("invalid model", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: "invalid_model", + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") + }) + + t.Run("invalid voice", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: "invalid_voice", + }) + checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") + }) +} From fe67abb97ed472bad359cc606c2d63289277cabf Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:06:57 -0700 Subject: [PATCH 31/98] fix: add beta assistant header to CreateMessage call (#566) --- messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 4e691a8b..3fd377fc 100644 --- a/messages.go +++ b/messages.go @@ -71,7 +71,7 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) if err != nil { return } From 71848ccf6928157d1487c5bbd5029ceaf3af53ed Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:08:48 -0700 Subject: [PATCH 32/98] feat: support direct bytes for file upload (#568) * feat: support direct bytes for file upload * add test for errors * add coverage --- client_test.go | 3 +++ files.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ files_api_test.go | 13 ++++++++++++ files_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/client_test.go b/client_test.go index 1c908458..664f9fb9 100644 --- a/client_test.go +++ b/client_test.go @@ -247,6 +247,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateImage", func() (any, error) { return client.CreateImage(ctx, ImageRequest{}) }}, + {"CreateFileBytes", func() (any, error) { + return client.CreateFileBytes(ctx, FileBytesRequest{}) + }}, {"DeleteFile", func() (any, error) { return nil, client.DeleteFile(ctx, "") }}, diff --git a/files.go b/files.go index 9e521fbb..371d06c6 100644 --- a/files.go +++ b/files.go @@ -15,6 +15,24 @@ type FileRequest struct { Purpose string `json:"purpose"` } +// PurposeType represents the purpose of the file when uploading. +type PurposeType string + +const ( + PurposeFineTune PurposeType = "fine-tune" + PurposeAssistants PurposeType = "assistants" +) + +// FileBytesRequest represents a file upload request. +type FileBytesRequest struct { + // the name of the uploaded file in OpenAI + Name string + // the bytes of the file + Bytes []byte + // the purpose of the file + Purpose PurposeType +} + // File struct represents an OpenAPI file. type File struct { Bytes int `json:"bytes"` @@ -36,6 +54,37 @@ type FilesList struct { httpHeader } +// CreateFileBytes uploads bytes directly to OpenAI without requiring a local file. +func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { + var b bytes.Buffer + reader := bytes.NewReader(request.Bytes) + builder := c.createFormBuilder(&b) + + err = builder.WriteField("purpose", string(request.Purpose)) + if err != nil { + return + } + + err = builder.CreateFormFileReader("file", reader, request.Name) + if err != nil { + return + } + + err = builder.Close() + if err != nil { + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + // CreateFile uploads a jsonl file to GPT3 // FilePath must be a local file path. func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { diff --git a/files_api_test.go b/files_api_test.go index 330b8815..6f62a3fb 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -16,6 +16,19 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: openai.PurposeFineTune, + } + _, err := client.CreateFileBytes(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/files_test.go b/files_test.go index f588b30d..3c1b99fb 100644 --- a/files_test.go +++ b/files_test.go @@ -11,6 +11,53 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return mockBuilder + } + + ctx := context.Background() + req := FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: PurposeAssistants, + } + + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return nil + } + mockBuilder.mockClose = func() error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") +} + func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -55,6 +102,9 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { return mockError } _, err = client.CreateFile(ctx, req) + if err == nil { + t.Fatal("CreateFile should return error if form builder fails") + } checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") } From 464b85b6d766a53c922a15dd1138570e31ec661b Mon Sep 17 00:00:00 2001 From: Liron Levin Date: Wed, 15 Nov 2023 18:22:39 +0200 Subject: [PATCH 33/98] Pagination fields are missing from assistants list beta API (#571) curl "https://api.openai.com/v1/assistants?order=desc&limit=20" \ -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -H "OpenAI-Beta: assistants=v1" { "object": "list", "data": [], "first_id": null, "last_id": null, "has_more": false } --- assistant.go | 4 +++- assistant_test.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index de49be68..59f78284 100644 --- a/assistant.go +++ b/assistant.go @@ -52,7 +52,9 @@ type AssistantRequest struct { // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` - + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` httpHeader } diff --git a/assistant_test.go b/assistant_test.go index eb6f4245..30daec2b 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -142,6 +142,8 @@ When asked a question, write and run Python code to answer the question.` fmt.Fprintln(w, string(resBytes)) } else if r.Method == http.MethodGet { resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, Assistants: []openai.Assistant{ { ID: assistantID, From 3220f19ee209de5e4bbc6db44261adcd4bbf1df1 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 16 Nov 2023 00:23:41 +0800 Subject: [PATCH 34/98] feat(runapi): add RunStepList response args https://platform.openai.com/docs/api-reference/runs/listRunSteps (#573) --- run.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/run.go b/run.go index 7ff730fe..f95bf0e3 100644 --- a/run.go +++ b/run.go @@ -157,6 +157,10 @@ type StepDetailsToolCalls struct { type RunStepList struct { RunSteps []RunStep `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } From 18465723f7d96587045ce0a450d6874128b870cd Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Wed, 15 Nov 2023 16:25:18 +0000 Subject: [PATCH 35/98] Add missing struct properties. (#579) --- assistant.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/assistant.go b/assistant.go index 59f78284..bd335833 100644 --- a/assistant.go +++ b/assistant.go @@ -22,6 +22,8 @@ type Assistant struct { Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` httpHeader } From 4fd904c2927c421cdbff89249979bc6a8a371d11 Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Sat, 18 Nov 2023 06:55:58 +0000 Subject: [PATCH 36/98] Add File purposes as constants (#577) * Add purposes. * Formatting. --- files.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/files.go b/files.go index 371d06c6..a37d45f1 100644 --- a/files.go +++ b/files.go @@ -19,8 +19,10 @@ type FileRequest struct { type PurposeType string const ( - PurposeFineTune PurposeType = "fine-tune" - PurposeAssistants PurposeType = "assistants" + PurposeFineTune PurposeType = "fine-tune" + PurposeFineTuneResults PurposeType = "fine-tune-results" + PurposeAssistants PurposeType = "assistants" + PurposeAssistantsOutput PurposeType = "assistants_output" ) // FileBytesRequest represents a file upload request. From 9efad284d02d90b2de3eeefc67a966743e47a2ac Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:59:01 -0800 Subject: [PATCH 37/98] Updates the tool call struct (#595) --- run.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/run.go b/run.go index f95bf0e3..dbb708a1 100644 --- a/run.go +++ b/run.go @@ -142,17 +142,13 @@ const ( type StepDetails struct { Type RunStepType `json:"type"` MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` - ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type StepDetailsMessageCreation struct { MessageID string `json:"message_id"` } -type StepDetailsToolCalls struct { - ToolCalls []ToolCall `json:"tool_calls"` -} - // RunStepList is a list of steps. type RunStepList struct { RunSteps []RunStep `json:"data"` From a130cfee26427b99ae0bf957be74e32ca8a7f567 Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 23:01:06 -0800 Subject: [PATCH 38/98] Add missing response fields for pagination (#584) --- messages.go | 5 +++++ messages_test.go | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 3fd377fc..ead247f5 100644 --- a/messages.go +++ b/messages.go @@ -29,6 +29,11 @@ type Message struct { type MessagesList struct { Messages []Message `json:"data"` + Object string `json:"object"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } diff --git a/messages_test.go b/messages_test.go index 282b1cc9..9168d6cc 100644 --- a/messages_test.go +++ b/messages_test.go @@ -142,6 +142,7 @@ func TestMessages(t *testing.T) { fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal(openai.MessagesList{ + Object: "list", Messages: []openai.Message{{ ID: messageID, Object: "thread.message", @@ -159,7 +160,11 @@ func TestMessages(t *testing.T) { AssistantID: &emptyStr, RunID: &emptyStr, Metadata: nil, - }}}) + }}, + FirstID: &messageID, + LastID: &messageID, + HasMore: false, + }) fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) From f87909596f8b0d293142ca00c4d4adc872c52ded Mon Sep 17 00:00:00 2001 From: pjuhasz Date: Fri, 24 Nov 2023 07:34:25 +0000 Subject: [PATCH 39/98] Add canary-tts to speech models (#603) Co-authored-by: Peter Juhasz --- speech.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index a3d5f5dc..f2442b92 100644 --- a/speech.go +++ b/speech.go @@ -10,8 +10,9 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTsModel1HD SpeechModel = "tts-1-hd" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" ) type SpeechVoice string @@ -57,7 +58,7 @@ func contains[T comparable](s []T, e T) bool { } func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) + return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) } func isValidVoice(voice SpeechVoice) bool { From 726099132704fd5ebc1680166f45bbd280bdb546 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 24 Nov 2023 13:36:10 +0400 Subject: [PATCH 40/98] Update PULL_REQUEST_TEMPLATE.md (#606) --- .github/PULL_REQUEST_TEMPLATE.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 44bf697e..222c065c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,11 +8,14 @@ Thanks for submitting a pull request! Please provide enough information so that **Describe the change** Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. +**Provide OpenAI documentation link** +Provide a relevant API doc from https://platform.openai.com/docs/api-reference + **Describe your solution** Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. **Tests** -Briefly describe how you have tested these changes. +Briefly describe how you have tested these changes. If possible — please add integration tests. **Additional context** Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. From 03caea89b75c4e6a5ac32f6e60e69e309d852e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Kintzi?= Date: Fri, 24 Nov 2023 13:17:00 +0000 Subject: [PATCH 41/98] Add support for multi part chat messages (and gpt-4-vision-preview model) (#580) * Add support for multi part chat messages OpenAI has recently introduced a new model called gpt-4-visual-preview, which now supports images as input. The chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent different types of content parts. It also implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage. * Add ImageURLDetail and ChatMessagePartType types * Optimize ChatCompletionMessage deserialization * Add ErrContentFieldsMisused error --- chat.go | 91 ++++++++++++++++++++++++++++++++++++++++++++- chat_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index ebdc0e24..5b87b6bd 100644 --- a/chat.go +++ b/chat.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "errors" "net/http" ) @@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions" var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") ) type Hate struct { @@ -51,9 +53,36 @@ type PromptAnnotation struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type ImageURLDetail string + +const ( + ImageURLDetailHigh ImageURLDetail = "high" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailAuto ImageURLDetail = "auto" +) + +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` + Detail ImageURLDetail `json:"detail,omitempty"` +} + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in // the documentation for the official library for python: @@ -70,6 +99,64 @@ type ChatCompletionMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/chat_test.go b/chat_test.go index 8377809d..520bf5ca 100644 --- a/chat_test.go +++ b/chat_test.go @@ -3,6 +3,7 @@ package openai_test import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateAzureChatCompletion error") } +func TestMultipartChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello!", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "URL", + Detail: openai.ImageURLDetailLow, + }, + }, + }, + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatMessageSerialization(t *testing.T) { + jsonText := `[{"role":"system","content":"system-message"},` + + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + + var msgs []openai.ChatCompletionMessage + err := json.Unmarshal([]byte(jsonText), &msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(msgs) != 2 { + t.Errorf("unexpected number of messages") + } + if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { + t.Errorf("invalid user message: %v", msgs[0]) + } + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + t.Errorf("invalid user message") + } + parts := msgs[1].MultiContent + if parts[0].Type != "text" || parts[0].Text != "nice-text" { + t.Errorf("invalid text part: %v", parts[0]) + } + if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { + t.Errorf("invalid image_url part") + } + + s, err := json.Marshal(msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } + + invalidMsg := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "some-text", + MultiContent: []openai.ChatMessagePart{ + { + Type: "text", + Text: "nice-text", + }, + }, + }, + } + _, err = json.Marshal(invalidMsg) + if !errors.Is(err, openai.ErrContentFieldsMisused) { + t.Fatalf("Expected error: %s", err) + } + + err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs) + if err == nil { + t.Fatalf("Expected error") + } + + emptyMultiContentMsg := openai.ChatCompletionMessage{ + Role: "user", + MultiContent: []openai.ChatMessagePart{}, + } + s, err = json.Marshal(emptyMultiContentMsg) + if err != nil { + t.Fatalf("Unexpected error") + } + res = strings.ReplaceAll(string(s), " ", "") + if res != `{"role":"user","content":""}` { + t.Fatalf("invalid message: %s", string(s)) + } +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error From a09cb0c528c110a6955a9ee9a5d021a57ed44b90 Mon Sep 17 00:00:00 2001 From: mikeb26 <83850730+mikeb26@users.noreply.github.com> Date: Sun, 26 Nov 2023 08:45:28 +0000 Subject: [PATCH 42/98] Add completion-with-tool example (#598) As a user of this go SDK it was not immediately intuitive to me how to correctly utilize the function calling capability of GPT4 (https://platform.openai.com/docs/guides/function-calling). While the aformentioned link provides a helpful example written in python, I initially tripped over how to correclty translate the specification of function arguments when usingthis go SDK. To make it easier for others in the future this commit adds a completion-with-tool example showing how to correctly utilize the function calling capability of GPT4 using this SDK end-to-end in a CreateChatCompletion() sequence. --- examples/completion-with-tool/main.go | 94 +++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/completion-with-tool/main.go diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go new file mode 100644 index 00000000..2c7fedc5 --- /dev/null +++ b/examples/completion-with-tool/main.go @@ -0,0 +1,94 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + ctx := context.Background() + client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) + + // describe the function & its inputs + params := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + } + f := openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather in a given location", + Parameters: params, + } + t := openai.Tool{ + Type: openai.ToolTypeFunction, + Function: f, + } + + // simulate user asking a question that requires the function + dialogue := []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, + } + fmt.Printf("Asking OpenAI '%v' and providing it a '%v()' function...\n", + dialogue[0].Content, f.Name) + resp, err := client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("Completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) + return + } + + // simulate calling the function & responding to OpenAI + dialogue = append(dialogue, msg) + fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", + msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) + dialogue = append(dialogue, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: "Sunny and 80 degrees.", + Name: msg.ToolCalls[0].Function.Name, + ToolCallID: msg.ToolCalls[0].ID, + }) + fmt.Printf("Sending OpenAI our '%v()' function's response and requesting the reply to the original question...\n", + f.Name) + resp, err = client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + // display OpenAI's response to the original question utilizing our function + msg = resp.Choices[0].Message + fmt.Printf("OpenAI answered the original request with: %v\n", + msg.Content) +} From c9615e0cbe3b68088ee04221acdfde63d6d20766 Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Wed, 3 Jan 2024 19:42:57 +0800 Subject: [PATCH 43/98] Added support for createImage Azure models (#608) --- image.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image.go b/image.go index 4fe8b3a3..afd4e196 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } From f10955ce090c7b0d8f38458c753c01cd9b88aca5 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:50:56 +0000 Subject: [PATCH 44/98] Log probabilities for chat completion output tokens (#625) * Add logprobs * Logprobs pointer * Move toplogporbs * Create toplogprobs struct * Remove pointers --- chat.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 5b87b6bd..33b8755c 100644 --- a/chat.go +++ b/chat.go @@ -200,7 +200,15 @@ type ChatCompletionRequest struct { // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + // LogProbs indicates whether to return log probabilities of the output tokens or not. + // If true, returns the log probabilities of each output token returned in the content of message. + // This option is currently not available on the gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each + // token position, each with an associated log probability. + // logprobs must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + User string `json:"user,omitempty"` // Deprecated: use Tools instead. Functions []FunctionDefinition `json:"functions,omitempty"` // Deprecated: use ToolChoice instead. @@ -244,6 +252,28 @@ type FunctionDefinition struct { // Deprecated: use FunctionDefinition instead. type FunctionDefine = FunctionDefinition +type TopLogProbs struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` +} + +// LogProb represents the probability information for a token. +type LogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. + // In rare cases, there may be fewer than the number of requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` +} + +// LogProbs is the top-level structure containing the log probability information. +type LogProbs struct { + // Content is a list of message content tokens with log probability information. + Content []LogProb `json:"content"` +} + type FinishReason string const ( @@ -273,6 +303,7 @@ type ChatCompletionChoice struct { // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. From 682b7adb0bd645f290031fbca6028feb5c22ab9c Mon Sep 17 00:00:00 2001 From: Alexander Kledal Date: Thu, 11 Jan 2024 11:45:15 +0100 Subject: [PATCH 45/98] Update README.md (#631) Ensure variables in examples are valid --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4cb77db6..9a479c0a 100644 --- a/README.md +++ b/README.md @@ -453,7 +453,7 @@ func main() { config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function // config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo": "your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] @@ -559,7 +559,7 @@ func main() { //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function //config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] From e01a2d7231fafec2c1cbdd176806e3be767df965 Mon Sep 17 00:00:00 2001 From: Matthew Jaffee Date: Mon, 15 Jan 2024 03:33:02 -0600 Subject: [PATCH 46/98] convert EmbeddingModel to string type (#629) This gives the user the ability to pass in models for embeddings that are not already defined in the library. Also more closely matches how the completions API works. --- embeddings.go | 120 ++++++++------------------------------------- embeddings_test.go | 22 ++------- 2 files changed, 24 insertions(+), 118 deletions(-) diff --git a/embeddings.go b/embeddings.go index 7e2aa7eb..f79df9df 100644 --- a/embeddings.go +++ b/embeddings.go @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} +type EmbeddingModel string const ( - Unknown EmbeddingModel = iota - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchText - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchText - AdaEmbeddingV2 + // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + AdaSimilarity EmbeddingModel = "text-similarity-ada-001" + BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" + CurieSimilarity EmbeddingModel = "text-similarity-curie-001" + DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001" + AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001" + AdaSearchQuery EmbeddingModel = "text-search-ada-query-001" + BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001" + BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001" + CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001" + CurieSearchQuery EmbeddingModel = "text-search-curie-query-001" + DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001" + DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001" + AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001" + AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001" + BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" + BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" + + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" ) -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index af04d96b..846d1995 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) { // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqStrings) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqTokens) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } } } -func TestEmbeddingModel(t *testing.T) { - var em openai.EmbeddingModel - err := em.UnmarshalText([]byte("text-similarity-ada-001")) - checks.NoError(t, err, "Could not marshal embedding model") - - if em != openai.AdaSimilarity { - t.Errorf("Model is not equal to AdaSimilarity") - } - - err = em.UnmarshalText([]byte("some-non-existent-model")) - checks.NoError(t, err, "Could not marshal embedding model") - if em != openai.Unknown { - t.Errorf("Model is not equal to Unknown") - } -} - func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() From 09f6920ad04666f65dd86ed542e5ebf8bffc93a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E5=AE=8F=E6=95=8F?= Date: Mon, 15 Jan 2024 20:01:49 +0800 Subject: [PATCH 47/98] fixed #594 (#609) APITypeAzure dall-e3 model url Co-authored-by: HanHongmin --- image.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/image.go b/image.go index afd4e196..665de1a7 100644 --- a/image.go +++ b/image.go @@ -82,6 +82,7 @@ type ImageEditRequest struct { Image *os.File `json:"image,omitempty"` Mask *os.File `json:"mask,omitempty"` Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -131,7 +132,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return @@ -144,6 +145,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { Image *os.File `json:"image,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -181,7 +183,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return From 4ce03a919ae9fdcb62e8098a03500ef77eafe348 Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Tue, 16 Jan 2024 04:32:48 -0500 Subject: [PATCH 48/98] Fix Azure embeddings model detection by passing string to `fullURL` (#637) --- embeddings.go | 2 +- embeddings_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/embeddings.go b/embeddings.go index f79df9df..c144119f 100644 --- a/embeddings.go +++ b/embeddings.go @@ -228,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index 846d1995..ed6384f3 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -158,6 +158,32 @@ func TestEmbeddingEndpoint(t *testing.T) { checks.HasError(t, err, "CreateEmbeddings error") } +func TestAzureEmbeddingEndpoint(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + + sampleEmbeddings := []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + server.RegisterHandler( + "/openai/deployments/text-embedding-ada-002/embeddings", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + // test create embeddings with strings (simple embedding request) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + }) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } +} + func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string From eff8dc1118ea82a1b50ee316608e24d83df74d6b Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Thu, 18 Jan 2024 01:42:07 +0800 Subject: [PATCH 49/98] fix(audio): fix audioTextResponse decode (#638) * fix(audio): fix audioTextResponse decode * test(audio): add audioTextResponse decode test * test(audio): simplify code --- client.go | 10 +++++++--- client_test.go | 48 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 056226c6..8bbbb875 100644 --- a/client.go +++ b/client.go @@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error { return nil } - if result, ok := v.(*string); ok { - return decodeString(body, result) + switch o := v.(type) { + case *string: + return decodeString(body, o) + case *audioTextResponse: + return decodeString(body, &o.Text) + default: + return json.NewDecoder(body).Decode(v) } - return json.NewDecoder(body).Decode(v) } func decodeString(body io.Reader, output *string) error { diff --git a/client_test.go b/client_test.go index 664f9fb9..bc5133ed 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestRequestBuilderFailed = errors.New("test request builder failed") @@ -43,23 +45,29 @@ func TestDecodeResponse(t *testing.T) { testCases := []struct { name string value interface{} + expected interface{} body io.Reader hasError bool }{ { - name: "nil input", - value: nil, - body: bytes.NewReader([]byte("")), + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + expected: nil, }, { - name: "string input", - value: &stringInput, - body: bytes.NewReader([]byte("test")), + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + expected: "test", }, { name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), + expected: map[string]interface{}{ + "test": "test", + }, }, { name: "reader return error", @@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) { body: &errorReader{err: errors.New("dummy")}, hasError: true, }, + { + name: "audio text input", + value: &audioTextResponse{}, + body: bytes.NewReader([]byte("test")), + expected: audioTextResponse{ + Text: "test", + }, + }, + } + + assertEqual := func(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + v := reflect.ValueOf(actual).Elem().Interface() + if !reflect.DeepEqual(v, expected) { + t.Fatalf("Unexpected value: %v, expected: %v", v, expected) + } } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) + if tc.hasError { + checks.HasError(t, err, "Unexpected nil error") + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) } + assertEqual(t, tc.expected, tc.value) }) } } From 4c41f24a99ad56f707df7c25b8833fb0a374c8c5 Mon Sep 17 00:00:00 2001 From: Daniil <7709243+bazuker@users.noreply.github.com> Date: Fri, 26 Jan 2024 00:41:48 -0800 Subject: [PATCH 50/98] Support January 25, 2024, models update. (#644) --- completion.go | 6 +++++- embeddings.go | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 2709c8b0..6326a72a 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,9 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" - GPT4TurboPreview = "gpt-4-1106-preview" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" @@ -78,6 +80,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4TurboPreview: true, GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, GPT40314: true, GPT40613: true, GPT432K: true, diff --git a/embeddings.go b/embeddings.go index c144119f..517027f5 100644 --- a/embeddings.go +++ b/embeddings.go @@ -34,7 +34,9 @@ const ( BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" - AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + SmallEmbedding3 EmbeddingModel = "text-embedding-3-small" + LargeEmbedding3 EmbeddingModel = "text-embedding-3-large" ) // Embedding is a special format of data representation that can be easily utilized by machine From 06ff541559eaf66482a89202da946644b6c96510 Mon Sep 17 00:00:00 2001 From: chenhhA <463474838@qq.com> Date: Mon, 29 Jan 2024 15:09:56 +0800 Subject: [PATCH 51/98] Add new struct filed dimensions for embedding API (#645) * add new struct filed dimensions for embedding API * docs: remove long single-line comments * change embedding request param Dimensions type to int --- embeddings.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/embeddings.go b/embeddings.go index 517027f5..c5633a31 100644 --- a/embeddings.go +++ b/embeddings.go @@ -157,6 +157,9 @@ type EmbeddingRequest struct { Model EmbeddingModel `json:"model"` User string `json:"user"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -181,6 +184,9 @@ type EmbeddingRequestStrings struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -189,6 +195,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } @@ -209,6 +216,9 @@ type EmbeddingRequestTokens struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -217,6 +227,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } From bc8cdd33d158ea165fcecde4a64fc5f1580f0192 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Fri, 2 Feb 2024 18:30:24 +0800 Subject: [PATCH 52/98] add GPT3Dot5Turbo0125 model (#648) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 6326a72a..ab1dbd6c 100644 --- a/completion.go +++ b/completion.go @@ -27,6 +27,7 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" @@ -75,6 +76,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, From bb6ed545306ba56b99d297a77da0a93b0bcfb80e Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:41:39 +0800 Subject: [PATCH 53/98] Fix: Add RunStatusCancelled (#650) Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index dbb708a1..d0675657 100644 --- a/run.go +++ b/run.go @@ -40,6 +40,7 @@ const ( RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" RunStatusExpired RunStatus = "expired" + RunStatusCancelled RunStatus = "cancelled" ) type RunRequiredAction struct { From 69e3fcbc2726d208d34e9d89089b47ebebdff01b Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:04:40 +0000 Subject: [PATCH 54/98] Fix typo assitantInstructions (#655) --- assistant_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/assistant_test.go b/assistant_test.go index 30daec2b..9e1e3f38 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -17,7 +17,7 @@ func TestAssistant(t *testing.T) { assistantID := "asst_abc123" assistantName := "Ambrogio" assistantDescription := "Ambrogio is a friendly assistant." - assitantInstructions := `You are a personal math tutor. + assistantInstructions := `You are a personal math tutor. When asked a question, write and run Python code to answer the question.` assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" limit := 20 @@ -92,7 +92,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: @@ -152,7 +152,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }, }, }) @@ -167,7 +167,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "CreateAssistant error") @@ -178,7 +178,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "ModifyAssistant error") From 6c2e3162dfe3b32cbd1d026043957f8e589e987c Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Thu, 8 Feb 2024 15:40:39 +0800 Subject: [PATCH 55/98] Added support for CreateSpeech Azure models (#657) --- speech.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speech.go b/speech.go index f2442b92..b9344ac6 100644 --- a/speech.go +++ b/speech.go @@ -74,7 +74,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) err = ErrInvalidVoice return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json; charset=utf-8"), ) From a7954c854c89f45d3f5df62aab8df688b4c20b20 Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:08:30 +0800 Subject: [PATCH 56/98] Feat: Add assistant usage (#649) * Feat: Add assistant usage --------- Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index d0675657..4befe0b4 100644 --- a/run.go +++ b/run.go @@ -26,6 +26,7 @@ type Run struct { Tools []Tool `json:"tools"` FileIDS []string `json:"file_ids"` Metadata map[string]any `json:"metadata"` + Usage Usage `json:"usage,omitempty"` httpHeader } From 11ad4b69d0f0dc61ed8777ac2d54a6787c8d2fea Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:02:48 +0400 Subject: [PATCH 57/98] make linter happy (#661) --- embeddings_test.go | 2 +- files_api_test.go | 10 +++++----- image_test.go | 10 +++++----- messages.go | 4 ++-- models_test.go | 2 +- run.go | 2 +- stream_test.go | 14 +++++++------- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/embeddings_test.go b/embeddings_test.go index ed6384f3..43897816 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -169,7 +169,7 @@ func TestAzureEmbeddingEndpoint(t *testing.T) { server.RegisterHandler( "/openai/deployments/text-embedding-ada-002/embeddings", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) fmt.Fprintln(w, string(resBytes)) }, diff --git a/files_api_test.go b/files_api_test.go index 6f62a3fb..c92162a8 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -86,7 +86,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { func TestDeleteFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {}) + server.RegisterHandler("/v1/files/deadbeef", func(http.ResponseWriter, *http.Request) {}) err := client.DeleteFile(context.Background(), "deadbeef") checks.NoError(t, err, "DeleteFile error") } @@ -94,7 +94,7 @@ func TestDeleteFile(t *testing.T) { func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) @@ -105,7 +105,7 @@ func TestListFile(t *testing.T) { func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) @@ -151,7 +151,7 @@ func TestGetFileContentReturnError(t *testing.T) { }` client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) @@ -178,7 +178,7 @@ func TestGetFileContentReturnError(t *testing.T) { func TestGetFileContentReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 81fff6cb..9332dd5c 100644 --- a/image_test.go +++ b/image_test.go @@ -60,7 +60,7 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { if name == "mask" { return mockFailedErr } @@ -69,12 +69,12 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err = client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } @@ -125,12 +125,12 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/messages.go b/messages.go index ead247f5..86146323 100644 --- a/messages.go +++ b/messages.go @@ -18,7 +18,7 @@ type Message struct { ThreadID string `json:"thread_id"` Role string `json:"role"` Content []MessageContent `json:"content"` - FileIds []string `json:"file_ids"` + FileIds []string `json:"file_ids"` //nolint:revive //backwards-compatibility AssistantID *string `json:"assistant_id,omitempty"` RunID *string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata"` @@ -54,7 +54,7 @@ type ImageFile struct { type MessageRequest struct { Role string `json:"role"` Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/models_test.go b/models_test.go index 4a4c759d..24a28ed2 100644 --- a/models_test.go +++ b/models_test.go @@ -64,7 +64,7 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { func TestGetModelReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/models/text-davinci-003", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/run.go b/run.go index 4befe0b4..ba09366c 100644 --- a/run.go +++ b/run.go @@ -24,7 +24,7 @@ type Run struct { Model string `json:"model"` Instructions string `json:"instructions,omitempty"` Tools []Tool `json:"tools"` - FileIDS []string `json:"file_ids"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` diff --git a/stream_test.go b/stream_test.go index 35c52ae3..2822a353 100644 --- a/stream_test.go +++ b/stream_test.go @@ -34,7 +34,7 @@ func TestCompletionsStreamWrongModel(t *testing.T) { func TestCreateCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -106,7 +106,7 @@ func TestCreateCompletionStream(t *testing.T) { func TestCreateCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -151,7 +151,7 @@ func TestCreateCompletionStreamError(t *testing.T) { func TestCreateCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -182,7 +182,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -228,7 +228,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -263,7 +263,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -305,7 +305,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() From 66bae3ee7329619b27ba8bcb185e0d333e9b3e26 Mon Sep 17 00:00:00 2001 From: grulex Date: Thu, 15 Feb 2024 16:11:58 +0000 Subject: [PATCH 58/98] Content-type fix (#659) * charset fixes * make linter happy (#661) --------- Co-authored-by: grulex Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- client.go | 4 ++-- speech.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 8bbbb875..55c48bd4 100644 --- a/client.go +++ b/client.go @@ -107,13 +107,13 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... } func (c *Client) sendRequest(req *http.Request, v Response) error { - req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Accept", "application/json") // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data contentType := req.Header.Get("Content-Type") if contentType == "" { - req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Content-Type", "application/json") } res, err := c.config.HTTPClient.Do(req) diff --git a/speech.go b/speech.go index b9344ac6..be895021 100644 --- a/speech.go +++ b/speech.go @@ -76,7 +76,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), - withContentType("application/json; charset=utf-8"), + withContentType("application/json"), ) if err != nil { return From ff61bbb32253aad84c6cc96bf9be3884aa8cde88 Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:12:22 +0000 Subject: [PATCH 59/98] Add RunRequest field AdditionalInstructions (#656) AdditionalInstructions is an optional string field used to append additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. Also, change the Model and Instructions *string fields to string. --- run.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index ba09366c..1f3cb7eb 100644 --- a/run.go +++ b/run.go @@ -72,11 +72,12 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 69e3bbb1eb05a5c1b27a29fc9a83d02d0d040e27 Mon Sep 17 00:00:00 2001 From: Igor Berlenko Date: Fri, 16 Feb 2024 18:22:38 +0800 Subject: [PATCH 60/98] Update client.go - allow to skip Authorization header (#658) * Update client.go - allow to skip Authorization header * Update client.go --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 55c48bd4..7fdc36ca 100644 --- a/client.go +++ b/client.go @@ -175,7 +175,7 @@ func (c *Client) setCommonHeaders(req *http.Request) { // Azure API Key authentication if c.config.APIType == APITypeAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { + } else if c.config.authToken != "" { // OpenAI or Azure AD authentication req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) } From e8b347891b21187740d594409b1c11fb0846577e Mon Sep 17 00:00:00 2001 From: CaoPengFlying Date: Mon, 19 Feb 2024 20:26:04 +0800 Subject: [PATCH 61/98] fix:fix open ai original validation. modify Tool's Function to pointer (#664) Co-authored-by: caopengfei1 --- chat.go | 4 ++-- examples/completion-with-tool/main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chat.go b/chat.go index 33b8755c..efb14fd4 100644 --- a/chat.go +++ b/chat.go @@ -225,8 +225,8 @@ const ( ) type Tool struct { - Type ToolType `json:"type"` - Function FunctionDefinition `json:"function,omitempty"` + Type ToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type ToolChoice struct { diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go index 2c7fedc5..26126e41 100644 --- a/examples/completion-with-tool/main.go +++ b/examples/completion-with-tool/main.go @@ -35,7 +35,7 @@ func main() { } t := openai.Tool{ Type: openai.ToolTypeFunction, - Function: f, + Function: &f, } // simulate user asking a question that requires the function From 7381d18a75a673d569c7dc7657407381e5c84dd5 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Wed, 21 Feb 2024 07:45:15 -0500 Subject: [PATCH 62/98] Fix for broken Azure Assistants url (#665) * fix:fix url for Azure assistants api * test:add unit tests for Azure Assistants api * fix:minor liniting issue --- assistant_test.go | 190 ++++++++++++++++++++++++++++++++++++++++++++++ client.go | 2 +- 2 files changed, 191 insertions(+), 1 deletion(-) diff --git a/assistant_test.go b/assistant_test.go index 9e1e3f38..48bc6f91 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -202,3 +202,193 @@ When asked a question, write and run Python code to answer the question.` err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) checks.NoError(t, err, "DeleteAssistantFile error") } + +func TestAzureAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assistantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client.go b/client.go index 7fdc36ca..e7a4d5be 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") { + if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" From c5401e9e6417ac2b5374993ccff1f40010e03f52 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:46:35 -0500 Subject: [PATCH 63/98] Fix for broken Azure Threads url (#668) --- client.go | 11 ++++++- thread_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index e7a4d5be..7b1a313a 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { + if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" @@ -258,3 +258,12 @@ func (c *Client) handleErrorResp(resp *http.Response) error { errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } + +func containsSubstr(s []string, e string) bool { + for _, v := range s { + if strings.Contains(e, v) { + return true + } + } + return false +} diff --git a/thread_test.go b/thread_test.go index 227ab633..1ac0f3c0 100644 --- a/thread_test.go +++ b/thread_test.go @@ -93,3 +93,86 @@ func TestThread(t *testing.T) { _, err = client.DeleteThread(ctx, threadID) checks.NoError(t, err, "DeleteThread error") } + +// TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. +func TestAzureThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From f2204439857a1085207e74c8f05abf6c8248d336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Oester?= <56402078+raphoester@users.noreply.github.com> Date: Mon, 26 Feb 2024 10:48:09 +0200 Subject: [PATCH 64/98] Added fields for moderation (#662) --- moderation.go | 36 ++++++++++++++++++++++-------------- moderation_test.go | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/moderation.go b/moderation.go index f8d20ee5..45d05248 100644 --- a/moderation.go +++ b/moderation.go @@ -44,24 +44,32 @@ type Result struct { // ResultCategories represents Categories of Result. type ResultCategories struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - SelfHarm bool `json:"self-harm"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate float32 `json:"hate"` - HateThreatening float32 `json:"hate/threatening"` - SelfHarm float32 `json:"self-harm"` - Sexual float32 `json:"sexual"` - SexualMinors float32 `json:"sexual/minors"` - Violence float32 `json:"violence"` - ViolenceGraphic float32 `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 059f0d1c..7fdeb9ba 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -80,18 +80,49 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { resCat := openai.ResultCategories{} resCatScore := openai.ResultCategoryScores{} switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: 1} + resCatScore = openai.ResultCategoryScores{Hate: true} + + case strings.Contains(moderationReq.Input, "hate more"): + resCat = openai.ResultCategories{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: true} + + case strings.Contains(moderationReq.Input, "harass"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: true} + + case strings.Contains(moderationReq.Input, "harass hard"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: 1} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "wanna suicide"): + resCat = openai.ResultCategories{SelfHarmIntent: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "drink bleach"): + resCat = openai.ResultCategories{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + case strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: 1} + resCatScore = openai.ResultCategoryScores{Sexual: true} + + case strings.Contains(moderationReq.Input, "child porn"): + resCat = openai.ResultCategories{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: true} + + case strings.Contains(moderationReq.Input, "kill"): + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: true} + + case strings.Contains(moderationReq.Input, "corpse"): + resCat = openai.ResultCategories{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 41037783bc7668998900248ed697b90ec36c3f09 Mon Sep 17 00:00:00 2001 From: Guillaume Dussault <146769929+guillaume-dussault@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:48:53 -0500 Subject: [PATCH 65/98] fix: when no Assistant Tools are specified, an empty list should be sent (#669) --- assistant.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index bd335833..7a7a7652 100644 --- a/assistant.go +++ b/assistant.go @@ -46,7 +46,7 @@ type AssistantRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } From bb6149f64fcb22381b2ef0b5c7d8287a520dc110 Mon Sep 17 00:00:00 2001 From: Martin Heck Date: Wed, 28 Feb 2024 10:25:47 +0100 Subject: [PATCH 66/98] fix: repair json decoding of moderation response (#670) --- moderation.go | 22 +++++++++++----------- moderation_test.go | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/moderation.go b/moderation.go index 45d05248..ae285ef8 100644 --- a/moderation.go +++ b/moderation.go @@ -59,17 +59,17 @@ type ResultCategories struct { // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - Harassment bool `json:"harassment"` - HarassmentThreatening bool `json:"harassment/threatening"` - SelfHarm bool `json:"self-harm"` - SelfHarmIntent bool `json:"self-harm/intent"` - SelfHarmInstructions bool `json:"self-harm/instructions"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate float32 `json:"hate"` + HateThreatening float32 `json:"hate/threatening"` + Harassment float32 `json:"harassment"` + HarassmentThreatening float32 `json:"harassment/threatening"` + SelfHarm float32 `json:"self-harm"` + SelfHarmIntent float32 `json:"self-harm/intent"` + SelfHarmInstructions float32 `json:"self-harm/instructions"` + Sexual float32 `json:"sexual"` + SexualMinors float32 `json:"sexual/minors"` + Violence float32 `json:"violence"` + ViolenceGraphic float32 `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 7fdeb9ba..61171c38 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -82,47 +82,47 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "hate more"): resCat = openai.ResultCategories{HateThreatening: true} - resCatScore = openai.ResultCategoryScores{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: 1} case strings.Contains(moderationReq.Input, "harass"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: 1} case strings.Contains(moderationReq.Input, "harass hard"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: 1} case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "wanna suicide"): resCat = openai.ResultCategories{SelfHarmIntent: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "drink bleach"): resCat = openai.ResultCategories{SelfHarmInstructions: true} - resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: 1} case strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} case strings.Contains(moderationReq.Input, "child porn"): resCat = openai.ResultCategories{SexualMinors: true} - resCatScore = openai.ResultCategoryScores{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: 1} case strings.Contains(moderationReq.Input, "kill"): resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "corpse"): resCat = openai.ResultCategories{ViolenceGraphic: true} - resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: 1} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 38b16a3c413a3ea076cf4082ea5cd1754b72c70f Mon Sep 17 00:00:00 2001 From: Bilal Hameed <68427058+LinuxSploit@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:56:50 +0500 Subject: [PATCH 67/98] Added 'wav' and 'pcm' Audio Formats (#671) * Added 'wav' and 'pcm' Audio Formats Added "wav" and "pcm" audio formats as per OpenAI API documentation for createSpeech endpoint. Ref: https://platform.openai.com/docs/api-reference/audio/createSpeech Supported formats are mp3, opus, aac, flac, wav, and pcm. * Removed Extra Newline for Sanity Check * fix: run goimports to get accepted by the linter --- speech.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/speech.go b/speech.go index be895021..92b30b55 100644 --- a/speech.go +++ b/speech.go @@ -33,6 +33,8 @@ const ( SpeechResponseFormatOpus SpeechResponseFormat = "opus" SpeechResponseFormatAac SpeechResponseFormat = "aac" SpeechResponseFormatFlac SpeechResponseFormat = "flac" + SpeechResponseFormatWav SpeechResponseFormat = "wav" + SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) var ( From 699f397c36d05e42210f65456436a447885cc07a Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 11 Mar 2024 15:27:48 +0800 Subject: [PATCH 68/98] Update streamReader Close() method to return error (#681) --- stream_reader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index d1741259..4210a194 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -108,6 +108,6 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { return } -func (stream *streamReader[T]) Close() { - stream.response.Body.Close() +func (stream *streamReader[T]) Close() error { + return stream.response.Body.Close() } From 0925563e86c2fdc5011310aa616ba493989cfe0a Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Fri, 15 Mar 2024 18:59:16 +0800 Subject: [PATCH 69/98] Fix broken implementation AssistantModify implementation (#685) * add custom marshaller, documentation and isolate tests * fix linter --- assistant.go | 30 ++++++++++++- assistant_test.go | 109 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 30 deletions(-) diff --git a/assistant.go b/assistant.go index 7a7a7652..4ca2dda6 100644 --- a/assistant.go +++ b/assistant.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -21,7 +22,7 @@ type Assistant struct { Description *string `json:"description,omitempty"` Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` @@ -41,16 +42,41 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +// AssistantRequest provides the assistant request parameters. +// When modifying the tools the API functions as the following: +// If Tools is undefined, no changes are made to the Assistant's tools. +// If Tools is empty slice it will effectively delete all of the Assistant's tools. +// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { Model string `json:"model"` Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` + Tools []AssistantTool `json:"-"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases +// If Tools is nil, the field is omitted from the JSON. +// If Tools is an empty slice, it's included in the JSON as an empty array ([]). +// If Tools is populated, it's included in the JSON with the elements. +func (a AssistantRequest) MarshalJSON() ([]byte, error) { + type Alias AssistantRequest + assistantAlias := &struct { + Tools *[]AssistantTool `json:"tools,omitempty"` + *Alias + }{ + Alias: (*Alias)(&a), + } + + if a.Tools != nil { + assistantAlias.Tools = &a.Tools + } + + return json.Marshal(assistantAlias) +} + // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` diff --git a/assistant_test.go b/assistant_test.go index 48bc6f91..40de0e50 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.` }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: - var request openai.AssistantRequest + var request openai.Assistant err := json.NewDecoder(r.Body).Decode(&request) checks.NoError(t, err, "Decode error") @@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.` ctx := context.Background() - _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("create_assistant", func(t *testing.T) { + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") }) - checks.NoError(t, err, "CreateAssistant error") - _, err = client.RetrieveAssistant(ctx, assistantID) - checks.NoError(t, err, "RetrieveAssistant error") + t.Run("retrieve_assistant", func(t *testing.T) { + _, err := client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + }) - _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("delete_assistant", func(t *testing.T) { + _, err := client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") }) - checks.NoError(t, err, "ModifyAssistant error") - _, err = client.DeleteAssistant(ctx, assistantID) - checks.NoError(t, err, "DeleteAssistant error") + t.Run("list_assistant", func(t *testing.T) { + _, err := client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + }) - _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistants error") + t.Run("create_assistant_file", func(t *testing.T) { + _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + }) - _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ - FileID: assistantFileID, + t.Run("list_assistant_files", func(t *testing.T) { + _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") }) - checks.NoError(t, err, "CreateAssistantFile error") - _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistantFiles error") + t.Run("retrieve_assistant_file", func(t *testing.T) { + _, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + }) - _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "RetrieveAssistantFile error") + t.Run("delete_assistant_file", func(t *testing.T) { + err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") + }) - err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "DeleteAssistantFile error") + t.Run("modify_assistant_no_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools != nil { + t.Errorf("expected nil got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_with_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}}, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil || len(assistant.Tools) != 1 { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_empty_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: make([]openai.AssistantTool, 0), + }) + + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) } func TestAzureAssistant(t *testing.T) { From 2646bce71c0cc907e2a3d050130b712c1e5688db Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 6 Apr 2024 03:15:54 +0800 Subject: [PATCH 70/98] feat: get header from sendRequestRaw (#694) * feat: get header from sendRequestRaw * Fix ci lint --- client.go | 15 ++++++++++++--- files.go | 6 ++---- speech.go | 7 ++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 7b1a313a..9a1c8958 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,12 @@ func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { return newRateLimitHeaders(h.Header()) } +type RawResponse struct { + io.ReadCloser + + httpHeader +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -134,8 +140,8 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { return decodeResponse(res.Body, v) } -func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { - resp, err := c.config.HTTPClient.Do(req) +func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) { + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function if err != nil { return } @@ -144,7 +150,10 @@ func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err erro err = c.handleErrorResp(resp) return } - return resp.Body, nil + + response.SetHeader(resp.Header) + response.ReadCloser = resp.Body + return } func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { diff --git a/files.go b/files.go index a37d45f1..b40a44f1 100644 --- a/files.go +++ b/files.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "io" "net/http" "os" ) @@ -159,13 +158,12 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err return } -func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - content, err = c.sendRequestRaw(req) - return + return c.sendRequestRaw(req) } diff --git a/speech.go b/speech.go index 92b30b55..7e22e755 100644 --- a/speech.go +++ b/speech.go @@ -3,7 +3,6 @@ package openai import ( "context" "errors" - "io" "net/http" ) @@ -67,7 +66,7 @@ func isValidVoice(voice SpeechVoice) bool { return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) } -func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { if !isValidSpeechModel(request.Model) { err = ErrInvalidSpeechModel return @@ -84,7 +83,5 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) return } - response, err = c.sendRequestRaw(req) - - return + return c.sendRequestRaw(req) } From 774fc9dd12ed60c10a9f9f03319ddb9cd5f8780c Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 5 Apr 2024 23:24:30 +0400 Subject: [PATCH 71/98] make linter happy (#701) --- fine_tunes.go | 1 - 1 file changed, 1 deletion(-) diff --git a/fine_tunes.go b/fine_tunes.go index 46f89f16..ca840781 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,6 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return From 187f4169f8898d78716f7944d87e5d95aa9a7c41 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Tue, 9 Apr 2024 16:22:31 +0800 Subject: [PATCH 72/98] [BREAKING_CHANGES] Fix update message payload (#699) * add custom marshaller, documentation and isolate tests * fix linter * wrap payload as expected from the API and update test * modify input to accept map[string]string only --- messages.go | 4 ++-- messages_test.go | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index 86146323..6fd0adbc 100644 --- a/messages.go +++ b/messages.go @@ -139,11 +139,11 @@ func (c *Client) RetrieveMessage( func (c *Client) ModifyMessage( ctx context.Context, threadID, messageID string, - metadata map[string]any, + metadata map[string]string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(metadata), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) if err != nil { return } diff --git a/messages_test.go b/messages_test.go index 9168d6cc..a18be20b 100644 --- a/messages_test.go +++ b/messages_test.go @@ -68,6 +68,10 @@ func TestMessages(t *testing.T) { metadata := map[string]any{} err := json.NewDecoder(r.Body).Decode(&metadata) checks.NoError(t, err, "unable to decode metadata in modify message call") + payload, ok := metadata["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata payload improperly wrapped %+v", metadata) + } resBytes, _ := json.Marshal( openai.Message{ @@ -86,8 +90,9 @@ func TestMessages(t *testing.T) { FileIds: nil, AssistantID: &emptyStr, RunID: &emptyStr, - Metadata: metadata, + Metadata: payload, }) + fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal( @@ -212,7 +217,7 @@ func TestMessages(t *testing.T) { } msg, err = client.ModifyMessage(ctx, threadID, messageID, - map[string]any{ + map[string]string{ "foo": "bar", }) checks.NoError(t, err, "ModifyMessage error") From e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 11 Apr 2024 16:39:10 +0800 Subject: [PATCH 73/98] feat: add GPT4Turbo and GPT4Turbo20240409 (#703) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index ab1dbd6c..00f43ff1 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" GPT4Turbo1106 = "gpt-4-1106-preview" GPT4TurboPreview = "gpt-4-turbo-preview" @@ -84,6 +86,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4VisionPreview: true, GPT4Turbo1106: true, GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, GPT40314: true, GPT40613: true, GPT432K: true, From ea551f422e5f38a0afc7d938eea5cff1f69494c5 Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:32:38 +0200 Subject: [PATCH 74/98] Fixing typos (#706) --- README.md | 2 +- assistant.go | 4 ++-- client_test.go | 2 +- error.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9a479c0a..7946f4d9 100644 --- a/README.md +++ b/README.md @@ -636,7 +636,7 @@ FunctionDefinition{ }, "unit": { Type: jsonschema.String, - Enum: []string{"celcius", "fahrenheit"}, + Enum: []string{"celsius", "fahrenheit"}, }, }, Required: []string{"location"}, diff --git a/assistant.go b/assistant.go index 4ca2dda6..9415325f 100644 --- a/assistant.go +++ b/assistant.go @@ -181,7 +181,7 @@ func (c *Client) ListAssistants( order *string, after *string, before *string, -) (reponse AssistantsList, err error) { +) (response AssistantsList, err error) { urlValues := url.Values{} if limit != nil { urlValues.Add("limit", fmt.Sprintf("%d", *limit)) @@ -208,7 +208,7 @@ func (c *Client) ListAssistants( return } - err = c.sendRequest(req, &reponse) + err = c.sendRequest(req, &response) return } diff --git a/client_test.go b/client_test.go index bc5133ed..a08d10f2 100644 --- a/client_test.go +++ b/client_test.go @@ -406,7 +406,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } } -func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { config := DefaultConfig(test.GetTestToken()) client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/error.go b/error.go index b2d01e22..37959a27 100644 --- a/error.go +++ b/error.go @@ -23,7 +23,7 @@ type InnerError struct { ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } -// RequestError provides informations about generic request errors. +// RequestError provides information about generic request errors. type RequestError struct { HTTPStatusCode int Err error From 2446f08f94b2750287c40bb9593377f349f5578e Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:34:23 +0200 Subject: [PATCH 75/98] Bump GitHub workflow actions to latest versions (#707) --- .github/workflows/close-inactive-issues.yml | 2 +- .github/workflows/pr.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index bfe9b5c9..32723c4e 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -10,7 +10,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v9 with: days-before-issue-stale: 30 days-before-issue-close: 14 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0..a41fff92 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,19 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.21' - name: Run vet run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 From a42f51967f5c2f8462f8d8dfd25f7d6a8d7a46fc Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 17 Apr 2024 03:26:14 +0800 Subject: [PATCH 76/98] [New_Features] Adds recently added Assistant cost saving parameters (#710) * add cost saving parameters * add periods at the end of comments * shorten commnet * further lower comment length * fix type --- run.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/run.go b/run.go index 1f3cb7eb..7c14779c 100644 --- a/run.go +++ b/run.go @@ -28,6 +28,16 @@ type Run struct { Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + httpHeader } @@ -78,8 +88,42 @@ type RunRequest struct { AdditionalInstructions string `json:"additional_instructions,omitempty"` Tools []Tool `json:"tools,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` + + // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. + // lower values are more focused and deterministic. + Temperature *float32 `json:"temperature,omitempty"` + + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` } +// ThreadTruncationStrategy defines the truncation strategy to use for the thread. +// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. +type ThreadTruncationStrategy struct { + // default 'auto'. + Type TruncationStrategy `json:"type,omitempty"` + // this field should be set if the truncation strategy is set to LastMessages. + LastMessages *int `json:"last_messages,omitempty"` +} + +// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant. +type TruncationStrategy string + +const ( + // TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model. + TruncationStrategyAuto = TruncationStrategy("auto") + // TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread. + TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From c6a63ed19aeb0e91facc5409c5a08612db550fb2 Mon Sep 17 00:00:00 2001 From: Mike Chaykowsky Date: Tue, 16 Apr 2024 12:28:06 -0700 Subject: [PATCH 77/98] Add PromptFilterResult (#702) --- chat_stream.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 57cfa789..6ff7078e 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -19,13 +19,19 @@ type ChatCompletionStreamChoice struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } // ChatCompletionStream From 8d15a377ec4fa3aaf2e706cd1e2ad986dd6b8242 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Wed, 24 Apr 2024 12:59:50 +0100 Subject: [PATCH 78/98] Remove hardcoded assistants version (#719) --- assistant.go | 19 +++++++++---------- client.go | 4 ++-- config.go | 14 +++++++++----- messages.go | 17 +++++++++++------ run.go | 27 +++++++++------------------ thread.go | 8 ++++---- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/assistant.go b/assistant.go index 9415325f..661681e8 100644 --- a/assistant.go +++ b/assistant.go @@ -11,7 +11,6 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { @@ -116,7 +115,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +131,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +148,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +164,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,7 +202,7 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -221,7 +220,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +237,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +254,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +293,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/client.go b/client.go index 9a1c8958..77d69322 100644 --- a/client.go +++ b/client.go @@ -89,9 +89,9 @@ func withContentType(contentType string) requestOption { } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } } diff --git a/config.go b/config.go index c58b71ec..599fa89c 100644 --- a/config.go +++ b/config.go @@ -23,6 +23,8 @@ const ( const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,7 +32,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient *http.Client @@ -39,10 +42,11 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, diff --git a/messages.go b/messages.go index 6fd0adbc..6af11844 100644 --- a/messages.go +++ b/messages.go @@ -76,7 +76,8 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -111,7 +112,8 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +128,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +146,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +161,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +177,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/run.go b/run.go index 7c14779c..094b0a4d 100644 --- a/run.go +++ b/run.go @@ -226,8 +226,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -247,8 +246,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -270,8 +268,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -310,8 +307,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -332,8 +328,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -352,8 +347,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -372,8 +366,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -394,8 +387,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -435,8 +427,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/thread.go b/thread.go index 291f3dca..900e3f2e 100644 --- a/thread.go +++ b/thread.go @@ -51,7 +51,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +64,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +81,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +97,7 @@ func (c *Client) DeleteThread( ) (response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } From 2d58f8f4b87be26dc0b7ba2b1f0c9496ecf1dfa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=80=E6=97=A5=E3=80=82?= Date: Wed, 24 Apr 2024 20:02:03 +0800 Subject: [PATCH 79/98] chore: add SystemFingerprint for chat completion stream response (#716) * chore: add SystemFingerprint for stream response * chore: add test * lint: format for test --- chat_stream.go | 1 + chat_stream_test.go | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 6ff7078e..159f9f47 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -30,6 +30,7 @@ type ChatCompletionStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb4..bd1c737d 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ From c84ab5f6ae8da3a78826ed2c8dc4c5cf93e30589 Mon Sep 17 00:00:00 2001 From: wurui <1009479218@qq.com> Date: Wed, 24 Apr 2024 20:08:58 +0800 Subject: [PATCH 80/98] feat: support cloudflare AI Gateway flavored azure openai (#715) * feat: support cloudflare AI Gateway flavored azure openai Signed-off-by: STRRL * test: add test for cloudflare azure fullURL --------- Signed-off-by: STRRL Co-authored-by: STRRL --- api_internal_test.go | 36 ++++++++++++++++++++++++++++++++++++ client.go | 10 ++++++++-- config.go | 7 ++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f899..a590ec9a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/client.go b/client.go index 77d69322..c57ba17c 100644 --- a/client.go +++ b/client.go @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index 599fa89c..bb437c97 100644 --- a/config.go +++ b/config.go @@ -16,9 +16,10 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key" From c9953a7b051bd661254fb071029553e61c78f8bd Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Sat, 27 Apr 2024 12:55:49 +0330 Subject: [PATCH 81/98] Fixup minor copy-pasta comment typo (#728) imagess -> images --- image_api_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/image_api_test.go b/image_api_test.go index 2eb46f2b..48416b1e 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -36,7 +36,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -146,7 +146,7 @@ func TestImageEditWithoutMask(t *testing.T) { func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -202,7 +202,7 @@ func TestImageVariation(t *testing.T) { func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } From 3334a9c78a9d594934e33af184e4e6313c4a942b Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Tue, 7 May 2024 16:10:07 +0330 Subject: [PATCH 82/98] Add support for word-level audio transcription timestamp granularity (#733) * Add support for audio transcription timestamp_granularities word * Fixup multiple timestamp granularities --- audio.go | 31 ++++++++++++++++++++++++++----- audio_api_test.go | 4 ++++ audio_test.go | 6 +++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/audio.go b/audio.go index 4cbe4fe6..dbc26d15 100644 --- a/audio.go +++ b/audio.go @@ -27,8 +27,14 @@ const ( AudioResponseFormatVTT AudioResponseFormat = "vtt" ) +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" +) + // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string @@ -38,10 +44,11 @@ type AudioRequest struct { // Reader is an optional io.Reader when you do not want to use an existing file. Reader io.Reader - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. @@ -62,6 +69,11 @@ type AudioResponse struct { NoSpeechProb float64 `json:"no_speech_prob"` Transient bool `json:"transient"` } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` httpHeader @@ -179,6 +191,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer return b.Close() } diff --git a/audio_api_test.go b/audio_api_test.go index a0efc792..c2459844 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -105,6 +105,10 @@ func TestAudioWithOptionalArgs(t *testing.T) { Temperature: 0.5, Language: "zh", Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index 5346244c..235931f3 100644 --- a/audio_test.go +++ b/audio_test.go @@ -24,6 +24,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Temperature: 0.5, Language: "en", Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, + }, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -47,7 +51,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} for _, failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) From 6af32202d1ce469674050600efa07c90ec286d03 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 7 May 2024 20:42:24 +0800 Subject: [PATCH 83/98] feat: support stream_options (#736) * feat: support stream_options * fix lint * fix lint --- chat.go | 10 ++++ chat_stream.go | 4 ++ chat_stream_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/chat.go b/chat.go index efb14fd4..a1eb1172 100644 --- a/chat.go +++ b/chat.go @@ -216,6 +216,16 @@ type ChatCompletionRequest struct { Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` } type ToolType string diff --git a/chat_stream.go b/chat_stream.go index 159f9f47..ffd512ff 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -33,6 +33,10 @@ type ChatCompletionStreamResponse struct { SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream diff --git a/chat_stream_test.go b/chat_stream_test.go index bd1c737d..63e45ee2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -388,6 +388,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + // Helper funcs. func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { @@ -401,6 +515,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return false } } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } return true } From 3b25e09da90715681fe4049955d7c7ce645e218c Mon Sep 17 00:00:00 2001 From: Kevin Mesiab Date: Mon, 13 May 2024 11:48:14 -0700 Subject: [PATCH 84/98] enhancement: Add new GPT4-o and alias to completion enums (#744) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 00f43ff1..3b4f8952 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" From 9f19d1c93bf986f2a8925be62f35aa5c413a706a Mon Sep 17 00:00:00 2001 From: nullswan Date: Mon, 13 May 2024 21:07:07 +0200 Subject: [PATCH 85/98] Add gpt4o (#742) * Add gpt4o * disabled model for endpoint seen in https://github.com/sashabaranov/go-openai/commit/e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 * Update completion.go --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 3b4f8952..ced8e060 100644 --- a/completion.go +++ b/completion.go @@ -84,6 +84,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4o: true, + GPT4o20240513: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, From 4f4a85687be31607536997e924b27693f5e5211a Mon Sep 17 00:00:00 2001 From: Kshirodra Meher Date: Tue, 14 May 2024 00:38:14 +0530 Subject: [PATCH 86/98] Added DALL.E 3 to readme.md (#741) * Added DALL.E 3 to readme.md Added DALL.E 3 to readme.md as its supported now as per issue https://github.com/sashabaranov/go-openai/issues/494 * Update README.md --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7946f4d9..799dc602 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT * GPT-3, GPT-4 -* DALL·E 2 +* DALL·E 2, DALL·E 3 * Whisper ## Installation From 211cb49fc22766f4174fef15301c4d39aef609d3 Mon Sep 17 00:00:00 2001 From: ando-masaki Date: Fri, 24 May 2024 16:18:47 +0900 Subject: [PATCH 87/98] Update client.go to get response header whether there is an error or not. (#751) Update client.go to get response header whether there is an error or not. Because 429 Too Many Requests error response has "Retry-After" header. --- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c57ba17c..7bc28e98 100644 --- a/client.go +++ b/client.go @@ -129,14 +129,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } From 30cf7b879cff5eb56f06fda19c51c9e92fce8b13 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:22 -0700 Subject: [PATCH 88/98] feat: add params to RunRequest (#754) --- run.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/run.go b/run.go index 094b0a4d..6bd3933b 100644 --- a/run.go +++ b/run.go @@ -92,6 +92,7 @@ type RunRequest struct { // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. @@ -103,6 +104,11 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. @@ -124,6 +130,13 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From 8618492b98bb91edbb43f8080b3a68275e183663 Mon Sep 17 00:00:00 2001 From: shosato0306 <38198918+shosato0306@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:03:57 +0900 Subject: [PATCH 89/98] feat: add incomplete run status (#763) --- run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 6bd3933b..5598f1df 100644 --- a/run.go +++ b/run.go @@ -30,10 +30,10 @@ type Run struct { Temperature *float32 `json:"temperature,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` @@ -50,6 +50,7 @@ const ( RunStatusCancelling RunStatus = "cancelling" RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" RunStatusExpired RunStatus = "expired" RunStatusCancelled RunStatus = "cancelled" ) @@ -95,11 +96,11 @@ type RunRequest struct { TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. From fd41f7a5f49e6723d97642c186e5e090abaebfe2 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Thu, 13 Jun 2024 06:23:07 -0700 Subject: [PATCH 90/98] Fix integration test (#762) * added TestCompletionStream test moved completion stream testing to seperate function added NoErrorF fixes nil pointer reference on stream object * update integration test models --- api_integration_test.go | 64 ++++++++++++++++++++-------------- completion.go | 31 ++++++++-------- embeddings.go | 2 +- internal/test/checks/checks.go | 7 ++++ 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 736040c5..f3468518 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -26,7 +26,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +77,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +109,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { diff --git a/completion.go b/completion.go index ced8e060..024f09b1 100644 --- a/completion.go +++ b/completion.go @@ -39,30 +39,33 @@ const ( GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/embeddings.go b/embeddings.go index c5633a31..b513ba6a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 71336915..6bd0964c 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { From 7e96c712cbdad50b9cf67324b1ca5ef6541b6235 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:15:27 +0400 Subject: [PATCH 91/98] run integration tests (#769) --- .github/workflows/integration-tests.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 00000000..19f158e4 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,19 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + run: go test -v -tags=integration ./api_integration_test.go From c69c3bb1d259375d5de801f890aca40c0b2a8867 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:21:25 +0400 Subject: [PATCH 92/98] integration tests: pass openai secret (#770) * pass openai secret * only run in master branch --- .github/workflows/integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 19f158e4..7260b00b 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,4 +16,6 @@ jobs: with: go-version: '1.21' - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} run: go test -v -tags=integration ./api_integration_test.go From 99cc170b5414bd21fc1c55bccba1d6c1bad04516 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 13 Jun 2024 23:24:37 +0800 Subject: [PATCH 93/98] feat: support batches api (#746) * feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const --- batch.go | 275 ++++++++++++++++++++++++++++++++++++ batch_test.go | 368 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 11 ++ files.go | 1 + 4 files changed, 655 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..4aba966b --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..4b2261e0 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f2..e49da9b3 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/files.go b/files.go index b40a44f1..26ad6bd7 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. From 68acf22a43903c1b460006e7c4b883ce73e35857 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 13 Jun 2024 17:26:37 +0200 Subject: [PATCH 94/98] Support Tool Resources properties for Threads (#760) * Support Tool Resources properties for Threads * Add Chunking Strategy for Threads vector stores --- thread.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/thread.go b/thread.go index 900e3f2e..6f752145 100644 --- a/thread.go +++ b/thread.go @@ -10,21 +10,74 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string From 0a421308993425afed7796da8f8e0e1abafd4582 Mon Sep 17 00:00:00 2001 From: Peng Guan-Cheng Date: Wed, 19 Jun 2024 16:37:21 +0800 Subject: [PATCH 95/98] feat: provide vector store (#772) * implement vectore store feature * fix after integration testing * fix golint error * improve test to increare code coverage * fix golint anc code coverage problem * add tool_resource in assistant response * chore: code style * feat: use pagination param * feat: use pagination param * test: use pagination param * test: rm unused code --------- Co-authored-by: Denny Depok <61371551+kodernubie@users.noreply.github.com> Co-authored-by: eric.p --- assistant.go | 50 ++++--- config.go | 2 +- vector_store.go | 345 ++++++++++++++++++++++++++++++++++++++++++ vector_store_test.go | 349 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 18 deletions(-) create mode 100644 vector_store.go create mode 100644 vector_store_test.go diff --git a/assistant.go b/assistant.go index 661681e8..cc13a302 100644 --- a/assistant.go +++ b/assistant.go @@ -14,16 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +35,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +43,33 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/config.go b/config.go index bb437c97..1347567d 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,7 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store // ClientConfig is a configuration of a client. type ClientConfig struct { diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 00000000..5c364362 --- /dev/null +++ b/vector_store.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. +func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. +func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 00000000..58b9a857 --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +} From e31185974c45949cc58c24a6cbf5ca969fb0f622 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:06:52 +0100 Subject: [PATCH 96/98] remove errors.Join (#778) --- batch.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/batch.go b/batch.go index 4aba966b..a43d401a 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -109,8 +108,6 @@ type BatchResponse struct { Batch } -var ErrUploadBatchFileFailed = errors.New("upload batch file failed") - // CreateBatch — API call to Create batch. func (c *Client) CreateBatch( ctx context.Context, @@ -202,7 +199,6 @@ func (c *Client) CreateBatchWithUploadFile( Lines: request.Lines, }) if err != nil { - err = errors.Join(ErrUploadBatchFileFailed, err) return } return c.CreateBatch(ctx, CreateBatchRequest{ From 03851d20327b7df5358ff9fb0ac96f476be1875a Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sun, 30 Jun 2024 17:20:10 +0200 Subject: [PATCH 97/98] allow custom voice and speech models (#691) --- speech.go | 31 ------------------------------- speech_test.go | 17 ----------------- 2 files changed, 48 deletions(-) diff --git a/speech.go b/speech.go index 7e22e755..19b21bdf 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,32 +43,7 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json"), diff --git a/speech_test.go b/speech_test.go index d9ba58b1..f1e405c3 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } From 727944c47886924800128d1c33df706b4159eb23 Mon Sep 17 00:00:00 2001 From: Luca Giannini <68999840+LGXerxes@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:31:11 +0200 Subject: [PATCH 98/98] feat: ParallelToolCalls to ChatCompletionRequest with helper functions (#787) * added ParallelToolCalls to ChatCompletionRequest with helper functions * added tests for coverage * changed ParallelToolCalls to any --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index a1eb1172..eb494f41 100644 --- a/chat.go +++ b/chat.go @@ -218,6 +218,8 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // Options for streaming response. Only set this when you set stream: true. StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } type StreamOptions struct {