Skip to content

Commit

Permalink
move marshaller and unmarshaler into internal pkg (#304) (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
vvatanabe committed May 28, 2023
1 parent 980504b commit 62eb4be
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 46 deletions.
4 changes: 3 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bufio"
"context"
"net/http"

utils "github.com/sashabaranov/go-openai/internal"
)

type ChatCompletionStreamChoiceDelta struct {
Expand Down Expand Up @@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream(
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
},
}
return
Expand Down
8 changes: 5 additions & 3 deletions error_accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"fmt"
"io"

utils "github.com/sashabaranov/go-openai/internal"
)

type errorAccumulator interface {
Expand All @@ -19,13 +21,13 @@ type errorBuffer interface {

type defaultErrorAccumulator struct {
buffer errorBuffer
unmarshaler unmarshaler
unmarshaler utils.Unmarshaler
}

func newErrorAccumulator() errorAccumulator {
return &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}
}

Expand All @@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
return
}

err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
if err != nil {
errResp = nil
}
Expand Down
7 changes: 4 additions & 3 deletions error_accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
Expand All @@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte {
return []byte{}
}

func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}

Expand Down Expand Up @@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}
err := accumulator.write([]byte("{"))
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
Expand Down Expand Up @@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {

stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}

_, err = stream.Recv()
Expand Down
4 changes: 2 additions & 2 deletions files_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package openai //nolint:testpackage // testing private field

import (
. "github.com/sashabaranov/go-openai/internal"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"

Expand Down Expand Up @@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) FormBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}

Expand Down
15 changes: 15 additions & 0 deletions internal/marshaller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package openai

import (
"encoding/json"
)

type Marshaller interface {
Marshal(value any) ([]byte, error)
}

type JSONMarshaller struct{}

func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
}
15 changes: 15 additions & 0 deletions internal/unmarshaler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package openai

import (
"encoding/json"
)

type Unmarshaler interface {
Unmarshal(data []byte, v any) error
}

type JSONUnmarshaler struct{}

func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}
15 changes: 0 additions & 15 deletions marshaller.go

This file was deleted.

8 changes: 5 additions & 3 deletions request_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@ import (
"bytes"
"context"
"net/http"

utils "github.com/sashabaranov/go-openai/internal"
)

type requestBuilder interface {
build(ctx context.Context, method, url string, request any) (*http.Request, error)
}

type httpRequestBuilder struct {
marshaller marshaller
marshaller utils.Marshaller
}

func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{
marshaller: &jsonMarshaller{},
marshaller: &utils.JSONMarshaller{},
}
}

Expand All @@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ
}

var reqBytes []byte
reqBytes, err := b.marshaller.marshal(request)
reqBytes, err := b.marshaller.Marshal(request)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion request_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type (
failingMarshaller struct{}
)

func (*failingMarshaller) marshal(_ any) ([]byte, error) {
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

Expand Down
4 changes: 3 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"errors"
"net/http"

utils "github.com/sashabaranov/go-openai/internal"
)

var (
Expand Down Expand Up @@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream(
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
},
}
return
Expand Down
6 changes: 4 additions & 2 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"io"
"net/http"

utils "github.com/sashabaranov/go-openai/internal"
)

type streamable interface {
Expand All @@ -19,7 +21,7 @@ type streamReader[T streamable] struct {
reader *bufio.Reader
response *http.Response
errAccumulator errorAccumulator
unmarshaler unmarshaler
unmarshaler utils.Unmarshaler
}

func (stream *streamReader[T]) Recv() (response T, err error) {
Expand Down Expand Up @@ -63,7 +65,7 @@ waitForData:
return
}

err = stream.unmarshaler.unmarshal(line, &response)
err = stream.unmarshaler.Unmarshal(line, &response)
return
}

Expand Down
15 changes: 0 additions & 15 deletions unmarshaler.go

This file was deleted.

0 comments on commit 62eb4be

Please sign in to comment.