diff --git a/context.go b/context.go index 2c0a8a3..284958a 100644 --- a/context.go +++ b/context.go @@ -79,8 +79,8 @@ const forceNumericContextKey string = "force_numeric_string_key" // ForceNumericStringContextKey context key of force numeric string var ForceNumericStringContextKey = contextPrefix + forceNumericContextKey -func SetForceNumericString(ctx context.Context) context.Context { - return context.WithValue(ctx, ForceNumericStringContextKey, true) +func SetForceNumericString(ctx context.Context, val bool) context.Context { + return context.WithValue(ctx, ForceNumericStringContextKey, val) } func getForNumericString(ctx context.Context) bool { diff --git a/context_test.go b/context_test.go index dada3bc..d6f2264 100644 --- a/context_test.go +++ b/context_test.go @@ -17,10 +17,15 @@ func Test_getForNumericString(t *testing.T) { want: false, }, { - name: "SetForceNumericString", - ctx: SetForceNumericString(context.Background()), + name: "SetForceNumericString:true", + ctx: SetForceNumericString(context.Background(), true), want: true, }, + { + name: "SetForceNumericString:false", + ctx: SetForceNumericString(context.Background(), false), + want: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/db_test.go b/db_test.go index 03061e0..14a3cad 100644 --- a/db_test.go +++ b/db_test.go @@ -187,7 +187,7 @@ func TestPrepare(t *testing.T) { BooleanType: false, FloatType: 3.14159, DoubleType: 3141592653589.8, - StringType: "another string", + StringType: "123.456", TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 20, 11, 12, 0, time.UTC)), DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)), DecimalType: 0.48, @@ -202,10 +202,12 @@ func TestPrepare(t *testing.T) { } tests := []struct { - name string - sql string - params []interface{} - want dummyRow + name string + sql string + params []interface{} + startFunc func(ctx context.Context) context.Context + endFunc func(ctx context.Context) context.Context + want dummyRow }{ { name: "NoInput", @@ -231,6 +233,14 @@ func TestPrepare(t *testing.T) { params: []interface{}{strconv.FormatFloat(float64(data[0].FloatType), 'f', -1, 32)}, want: data[0], }, + { + name: "Numeric String", + sql: fmt.Sprintf("select * from %s where stringType = ?", harness.table), + params: []interface{}{data[2].StringType}, + startFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, true) }, + endFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, false) }, + want: data[2], + }, } for _, resultMode := range resultModes { @@ -246,6 +256,15 @@ func TestPrepare(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("ResultMode:%v/%s", resultMode, test.name), func(t *testing.T) { + if startFunc := test.startFunc; startFunc != nil { + ctx = startFunc(ctx) + } + if endFunc := test.startFunc; endFunc != nil { + defer func() { + ctx = endFunc(ctx) + }() + } + stmt, err := harness.prepare(ctx, test.sql) defer func() { err := stmt.Close()