Skip to content

Commit

Permalink
fix float type parameter in stmt
Browse files Browse the repository at this point in the history
  • Loading branch information
muroon committed Oct 3, 2021
1 parent 80028c7 commit 25ae291
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 117 deletions.
17 changes: 0 additions & 17 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,3 @@ func getCatalog(ctx context.Context) (string, bool) {
val, ok := ctx.Value(CatalogContextKey).(string)
return val, ok
}

/*
* force using string type with numeric string
*/
const forceNumericContextKey string = "force_numeric_string_key"

// ForceNumericStringContextKey context key of force numeric string
var ForceNumericStringContextKey = contextPrefix + forceNumericContextKey

func SetForceNumericString(ctx context.Context, val bool) context.Context {
return context.WithValue(ctx, ForceNumericStringContextKey, val)
}

func getForNumericString(ctx context.Context) bool {
val, ok := ctx.Value(ForceNumericStringContextKey).(bool)
return val && ok
}
31 changes: 0 additions & 31 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,6 @@ import (
"testing"
)

func Test_getForNumericString(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want bool
}{
{
name: "Default",
ctx: context.Background(),
want: false,
},
{
name: "SetForceNumericString:true",
ctx: SetForceNumericString(context.Background(), true),
want: true,
},
{
name: "SetForceNumericString:false",
ctx: SetForceNumericString(context.Background(), false),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getForNumericString(tt.ctx); got != tt.want {
t.Errorf("getForNumericString() = %v, want %v", got, tt.want)
}
})
}
}

func Test_getCatalog(t *testing.T) {
tests := []struct {
name string
Expand Down
15 changes: 6 additions & 9 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -229,17 +228,15 @@ func TestPrepare(t *testing.T) {
},
{
name: "FloatType",
sql: fmt.Sprintf("select * from %s where cast(floattype as decimal(8,7)) = ?", harness.table),
params: []interface{}{strconv.FormatFloat(float64(data[0].FloatType), 'f', -1, 32)},
sql: fmt.Sprintf("select * from %s where floattype = ?", harness.table),
params: []interface{}{data[0].FloatType},
want: data[0],
},
{
name: "Numeric String",
sql: fmt.Sprintf("select * from %s where stringType = ?", harness.table),
params: []interface{}{data[2].StringType},
startFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, true) },
endFunc: func(ctx context.Context) context.Context { return SetForceNumericString(ctx, false) },
want: data[2],
name: "DoubleType",
sql: fmt.Sprintf("select * from %s where doubletype = ?", harness.table),
params: []interface{}{data[0].DoubleType},
want: data[0],
},
}

Expand Down
56 changes: 3 additions & 53 deletions doc/prepared_statements.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,60 +44,10 @@ You can use [Prepared Statements on Athena](https://docs.aws.amazon.com/athena/l

## Examples

#### int parameter

```
intParam := 1
stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE int_column = ?")
rows, _ := stmt.QueryContext(ctx, intParam)
```
execute `SELECT * FROM test_table WHERE int_column = 1`

#### string parameter

```
stringParam := "string value"
stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE string_column = ?")
rows, _ := stmt.QueryContext(ctx, stringParam)
```
execute `SELECT * FROM test_table WHERE string_column = 'string value'`

#### float parameter

Neither float32 nor float64 is supported because digit precision will easily cause large problems.
If you want to set a parameter for float type column on Athena, please use string type parameter whose characters are all numeric.

stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE int_column = ? AND string_column = ?")
rows, _ := stmt.QueryContext(ctx, intParam, stringParam)
```
// for float column
floatParam := "3.144"
stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE float_column = ?")
rows, _ := stmt.QueryContext(ctx, floatParam)
```
execute SQL `SELECT * FROM test_table WHERE float_column = 3.144`

|golang (string)|in SQL|
| --- | --- |
|"123"|123|
|"1.23"|1.23|
|"1.23a"|'1.23a'|

#### for numeric string parameter

By default, numeric string parameter isn't converted to string type in SQL, as shown in the float parameter example.
If you want to set a numeric value for a string type column on Athena, put true in SetForceNumericString function.

```
// for string column
floatParam := "3.144"
stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE string_column = ?")
ctx = SetForceNumericString(ctx, true) // set true
rows, _ := stmt.QueryContext(ctx, floatParam)
```
execute SQL `SELECT * FROM test_table WHERE string_column = '3.144'`

**under setting true in SetForceNumericString**

|golang (string)|in SQL|
| --- | --- |
|"123"|'123'|
|"1.23"|'1.23'|
execute `SELECT * FROM test_table WHERE int_column = 1 and string_column = 'string value'`
12 changes: 5 additions & 7 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,11 @@ func (s *stmtAthena) runQuery(ctx context.Context, query string) (driver.Rows, e
}

func serial(ctx context.Context, v interface{}) (string, error) {
forceNumericString := getForNumericString(ctx)
if !forceNumericString {
if x, ok := v.(string); ok {
if _, err := strconv.ParseFloat(string(x), 64); err == nil {
return presto.Serial(presto.Numeric(x))
}
}
switch x := v.(type) {
case float32:
return strconv.FormatFloat(float64(x), 'g', -1, 32), nil
case float64:
return strconv.FormatFloat(x, 'g', -1, 64), nil
}

return presto.Serial(v)
Expand Down

0 comments on commit 25ae291

Please sign in to comment.