Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/internal testing #194

Merged
merged 10 commits into from
Mar 24, 2023
42 changes: 11 additions & 31 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"errors"
Expand All @@ -20,25 +21,17 @@ func TestAPI(t *testing.T) {
c := NewClient(apiToken)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err != nil {
t.Fatalf("ListEngines error: %v", err)
}
checks.NoError(t, err, "ListEngines error")

_, err = c.GetEngine(ctx, "davinci")
if err != nil {
t.Fatalf("GetEngine error: %v", err)
}
checks.NoError(t, err, "GetEngine error")

fileRes, err := c.ListFiles(ctx)
if err != nil {
t.Fatalf("ListFiles error: %v", err)
}
checks.NoError(t, err, "ListFiles error")

if len(fileRes.Files) > 0 {
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
if err != nil {
t.Fatalf("GetFile error: %v", err)
}
checks.NoError(t, err, "GetFile error")
} // else skip

embeddingReq := EmbeddingRequest{
Expand All @@ -49,9 +42,7 @@ func TestAPI(t *testing.T) {
Model: AdaSearchQuery,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil {
t.Fatalf("Embedding error: %v", err)
}
checks.NoError(t, err, "Embedding error")

_, err = c.CreateChatCompletion(
ctx,
Expand All @@ -66,9 +57,7 @@ func TestAPI(t *testing.T) {
},
)

if err != nil {
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion (without name) returned error")

_, err = c.CreateChatCompletion(
ctx,
Expand All @@ -83,20 +72,15 @@ func TestAPI(t *testing.T) {
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")

stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: GPT3Ada,
MaxTokens: 5,
Stream: true,
})
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

counter := 0
Expand Down Expand Up @@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) {
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines did not fail")
}
checks.NoError(t, err, "ListEngines did not fail")

var apiErr *APIError
if !errors.As(err, &apiErr) {
Expand All @@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) {
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines request did not fail")
}
checks.HasError(t, err, "ListEngines did not fail")

var reqErr *RequestError
if !errors.As(err, &reqErr) {
Expand Down
18 changes: 6 additions & 12 deletions audio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"testing"
Expand Down Expand Up @@ -62,9 +63,7 @@ func TestAudio(t *testing.T) {
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
checks.NoError(t, err, "audio API error")
})
}
}
Expand Down Expand Up @@ -115,19 +114,16 @@ func TestAudioWithOptionalArgs(t *testing.T) {
Language: "zh",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
checks.NoError(t, err, "audio API error")
})
}
}

// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file %v", err)
}
checks.NoError(t, err, "failed to create file")

if _, err = file.WriteString("hello"); err != nil {
t.Fatalf("failed to write to file %v", err)
}
Expand All @@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper()

path, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
t.Fatal(err)
}
checks.NoError(t, err)

return path, func() { os.RemoveAll(path) }
}
Expand Down
28 changes: 10 additions & 18 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
checks.NoError(t, err, "Write error")
}))
defer server.Close()

Expand Down Expand Up @@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

expectedResponses := []ChatCompletionStreamResponse{
Expand Down Expand Up @@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
t.Logf("%d: %s", ix, string(b))

receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
checks.NoError(t, streamErr, "stream.Recv() failed")
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
Expand All @@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) {
}

_, streamErr = stream.Recv()

checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
Expand All @@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
}

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
checks.NoError(t, err, "Write error")
}))
defer server.Close()

Expand Down Expand Up @@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

_, streamErr := stream.Recv()
if streamErr == nil {
t.Errorf("stream.Recv() did not return error")
}
checks.HasError(t, streamErr, "stream.Recv() did not return error")

var apiErr *APIError
if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError")
Expand Down
15 changes: 5 additions & 10 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) {
},
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) {
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
}
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
}

func TestChatCompletionsWithStream(t *testing.T) {
Expand All @@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) {
Stream: true,
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
}
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
Expand Down Expand Up @@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) {
},
}
_, err = client.CreateChatCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletion error: %v", err)
}
checks.NoError(t, err, "CreateChatCompletion error")
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
Expand Down
5 changes: 2 additions & 3 deletions completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) {
}
req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
checks.NoError(t, err, "CreateCompletion error")
}

// handleCompletionEndpoint Handles the completion endpoint by the test server.
Expand Down
5 changes: 2 additions & 3 deletions edits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -40,9 +41,7 @@ func TestEdits(t *testing.T) {
N: 3,
}
response, err := client.Edits(ctx, editReq)
if err != nil {
t.Fatalf("Edits error: %v", err)
}
checks.NoError(t, err, "Edits error")
if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices")
}
Expand Down
5 changes: 2 additions & 3 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"bytes"
"encoding/json"
Expand Down Expand Up @@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) {
// marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
if err != nil {
t.Fatalf("Could not marshal embedding request: %v", err)
}
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
Expand Down
10 changes: 4 additions & 6 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

var (
Expand Down Expand Up @@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
ctx := context.Background()

stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if err != nil {
t.Fatal(err)
}
checks.NoError(t, err)

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
}

_, err = stream.Recv()
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}
5 changes: 2 additions & 3 deletions files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
Expand Down Expand Up @@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) {
Purpose: "fine-tune",
}
_, err = client.CreateFile(ctx, req)
if err != nil {
t.Fatalf("CreateFile error: %v", err)
}
checks.NoError(t, err, "CreateFile erro")
}

// handleCreateFile Handles the images endpoint by the test server.
Expand Down
Loading