package openai_test

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"net/http"
	"reflect"
	"testing"

	"github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test/checks"
)

func TestEmbedding(t *testing.T) {
	embeddedModels := []openai.EmbeddingModel{
		openai.AdaSimilarity,
		openai.BabbageSimilarity,
		openai.CurieSimilarity,
		openai.DavinciSimilarity,
		openai.AdaSearchDocument,
		openai.AdaSearchQuery,
		openai.BabbageSearchDocument,
		openai.BabbageSearchQuery,
		openai.CurieSearchDocument,
		openai.CurieSearchQuery,
		openai.DavinciSearchDocument,
		openai.DavinciSearchQuery,
		openai.AdaCodeSearchCode,
		openai.AdaCodeSearchText,
		openai.BabbageCodeSearchCode,
		openai.BabbageCodeSearchText,
	}
	for _, model := range embeddedModels {
		// test embedding request with strings (simple embedding request)
		embeddingReq := openai.EmbeddingRequest{
			Input: []string{
				"The food was delicious and the waiter",
				"Other examples of embedding request",
			},
			Model: model,
		}
		// marshal embeddingReq to JSON and confirm that the model field
		// equals the model under test
		marshaled, err := json.Marshal(embeddingReq)
		checks.NoError(t, err, "Could not marshal embedding request")
		if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
			t.Fatalf("Expected embedding request to contain model field")
		}

		// test embedding request with strings
		embeddingReqStrings := openai.EmbeddingRequestStrings{
			Input: []string{
				"The food was delicious and the waiter",
				"Other examples of embedding request",
			},
			Model: model,
		}
		marshaled, err = json.Marshal(embeddingReqStrings)
		checks.NoError(t, err, "Could not marshal embedding request")
		if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
			t.Fatalf("Expected embedding request to contain model field")
		}

		// test embedding request with tokens
		embeddingReqTokens := openai.EmbeddingRequestTokens{
			Input: [][]int{
				{464, 2057, 373, 12625, 290, 262, 46612},
				{6395, 6096, 286, 11525, 12083, 2581},
			},
			Model: model,
		}
		marshaled, err = json.Marshal(embeddingReqTokens)
		checks.NoError(t, err, "Could not marshal embedding request")
		if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
			t.Fatalf("Expected embedding request to contain model field")
		}
	}
}

func TestEmbeddingEndpoint(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()

	sampleEmbeddings := []openai.Embedding{
		{Embedding: []float32{1.23, 4.56, 7.89}},
		{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
	}
	sampleBase64Embeddings := []openai.Base64Embedding{
		{Embedding: "pHCdP4XrkUDhevxA"},
		{Embedding: "/1jku0G/rLvA/EI8"},
	}

	server.RegisterHandler(
		"/v1/embeddings",
		func(w http.ResponseWriter, r *http.Request) {
			var req struct {
				EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"`
				User           string                         `json:"user"`
			}
			_ = json.NewDecoder(r.Body).Decode(&req)

			var resBytes []byte
			switch {
			case req.User == "invalid":
				w.WriteHeader(http.StatusBadRequest)
				return
			case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64:
				resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings})
			default:
				resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings})
			}
			fmt.Fprintln(w, string(resBytes))
		},
	)

	// test create embeddings with strings (simple embedding request)
	res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{})
	checks.NoError(t, err, "CreateEmbeddings error")
	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
	}
	// test create embeddings with strings, using the base64 encoding format
	res, err = client.CreateEmbeddings(
		context.Background(),
		openai.EmbeddingRequest{
			EncodingFormat: openai.EmbeddingEncodingFormatBase64,
		},
	)
	checks.NoError(t, err, "CreateEmbeddings error")
	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
	}

	// test create embeddings with strings
	res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{})
	checks.NoError(t, err, "CreateEmbeddings strings error")
	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
	}

	// test create embeddings with tokens
	res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{})
	checks.NoError(t, err, "CreateEmbeddings tokens error")
	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
	}

	// test failed sendRequest
	_, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		User:           "invalid",
		EncodingFormat: openai.EmbeddingEncodingFormatBase64,
	})
	checks.HasError(t, err, "CreateEmbeddings error")
}

func TestAzureEmbeddingEndpoint(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()

	sampleEmbeddings := []openai.Embedding{
		{Embedding: []float32{1.23, 4.56, 7.89}},
		{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
	}

	server.RegisterHandler(
		"/openai/deployments/text-embedding-ada-002/embeddings",
		func(w http.ResponseWriter, _ *http.Request) {
			resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings})
			fmt.Fprintln(w, string(resBytes))
		},
	)

	// test create embeddings with strings (simple embedding request)
	res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
		Model: openai.AdaEmbeddingV2,
	})
	checks.NoError(t, err, "CreateEmbeddings error")
	if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
		t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
	}
}

func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
	type fields struct {
		Object string
		Data   []openai.Base64Embedding
		Model  openai.EmbeddingModel
		Usage  openai.Usage
	}
	tests := []struct {
		name    string
		fields  fields
		want    openai.EmbeddingResponse
		wantErr bool
	}{
		{
			name: "test embedding response base64 to embedding response",
			fields: fields{
				Data: []openai.Base64Embedding{
					{Embedding: "pHCdP4XrkUDhevxA"},
					{Embedding: "/1jku0G/rLvA/EI8"},
				},
			},
			want: openai.EmbeddingResponse{
				Data: []openai.Embedding{
					{Embedding: []float32{1.23, 4.56, 7.89}},
					{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
				},
			},
			wantErr: false,
		},
		{
			name: "Invalid embedding",
			fields: fields{
				Data: []openai.Base64Embedding{
					{
						Embedding: "----",
					},
				},
			},
			want:    openai.EmbeddingResponse{},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := &openai.EmbeddingResponseBase64{
				Object: tt.fields.Object,
				Data:   tt.fields.Data,
				Model:  tt.fields.Model,
				Usage:  tt.fields.Usage,
			}
			got, err := r.ToEmbeddingResponse()
			if (err != nil) != tt.wantErr {
				t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestDotProduct(t *testing.T) {
	v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}}
	v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}}
	expected := float32(28.0)

	result, err := v1.DotProduct(v2)
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if math.Abs(float64(result-expected)) > 1e-12 {
		t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
	}

	v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
	v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}}
	expected = float32(0.0)

	result, err = v1.DotProduct(v2)
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
	}
	if math.Abs(float64(result-expected)) > 1e-12 {
		t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
	}

	// Test for VectorLengthMismatchError
	v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
	v2 = &openai.Embedding{Embedding: []float32{0, 1}}
	_, err = v1.DotProduct(v2)
	if !errors.Is(err, openai.ErrVectorLengthMismatch) {
		t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
	}
}