Skip to content

Commit

Permalink
update SetForceNumericString
Browse files Browse the repository at this point in the history
  • Loading branch information
muroon committed Sep 7, 2021
1 parent 6df95ed commit 255dabf
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
4 changes: 2 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ const forceNumericContextKey string = "force_numeric_string_key"
// ForceNumericStringContextKey context key of force numeric string
var ForceNumericStringContextKey = contextPrefix + forceNumericContextKey

func SetForceNumericString(ctx context.Context) context.Context {
return context.WithValue(ctx, ForceNumericStringContextKey, true)
func SetForceNumericString(ctx context.Context, val bool) context.Context {
return context.WithValue(ctx, ForceNumericStringContextKey, val)
}

func getForNumericString(ctx context.Context) bool {
Expand Down
9 changes: 7 additions & 2 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ func Test_getForNumericString(t *testing.T) {
want: false,
},
{
name: "SetForceNumericString",
ctx: SetForceNumericString(context.Background()),
name: "SetForceNumericString:true",
ctx: SetForceNumericString(context.Background(), true),
want: true,
},
{
name: "SetForceNumericString:false",
ctx: SetForceNumericString(context.Background(), false),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
29 changes: 24 additions & 5 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestPrepare(t *testing.T) {
BooleanType: false,
FloatType: 3.14159,
DoubleType: 3141592653589.8,
StringType: "another string",
StringType: "123.456",
TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 20, 11, 12, 0, time.UTC)),
DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)),
DecimalType: 0.48,
Expand All @@ -202,10 +202,12 @@ func TestPrepare(t *testing.T) {
}

tests := []struct {
name string
sql string
params []interface{}
want dummyRow
name string
sql string
params []interface{}
startFunc func(ctx context.Context) context.Context
endFunc func(ctx context.Context) context.Context
want dummyRow
}{
{
name: "NoInput",
Expand All @@ -231,6 +233,14 @@ func TestPrepare(t *testing.T) {
params: []interface{}{strconv.FormatFloat(float64(data[0].FloatType), 'f', -1, 32)},
want: data[0],
},
{
name: "Numeric String",
sql: fmt.Sprintf("select * from %s where stringType = ?", harness.table),
params: []interface{}{data[2].StringType},
startFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, true) },
endFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, false) },
want: data[2],
},
}

for _, resultMode := range resultModes {
Expand All @@ -246,6 +256,15 @@ func TestPrepare(t *testing.T) {

for _, test := range tests {
t.Run(fmt.Sprintf("ResultMode:%v/%s", resultMode, test.name), func(t *testing.T) {
if startFunc := test.startFunc; startFunc != nil {
ctx = startFunc(ctx)
}
if endFunc := test.startFunc; endFunc != nil {
defer func() {
ctx = endFunc(ctx)
}()
}

stmt, err := harness.prepare(ctx, test.sql)
defer func() {
err := stmt.Close()
Expand Down

0 comments on commit 255dabf

Please sign in to comment.