Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
muroon committed Jul 6, 2021
1 parent 76d7de7 commit c8c7a7c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 31 deletions.
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ The tests support a few environment variables:
- `S3_BUCKET` can be used to override the default S3 bucket of "go-athena-tests"
- `ATHENA_REGION` or `AWS_DEFAULT_REGION` can be used to override the default region of "us-east-1"
- `ATHENA_WORK_GROUP` can be used to override the default workgroup of "primary"
- `ATHENA_AUTO_OUTPUT_LOCATION` is a parameter for test of getting output_location value from workgroup
- 0 (default value): Useless
- 1: It can make you get output_location value from query result location in workgroup


[database/sql]: https://golang.org/pkg/database/sql/
Expand Down
93 changes: 65 additions & 28 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ import (
)

var (
AthenaDatabase = "go_athena_tests"
S3Bucket = "go-athena-tests"
AwsRegion = "us-east-1"
WorkGroup = "primary"
AutoOutputLocation = false
AthenaDatabase = "go_athena_tests"
S3Bucket = "go-athena-tests"
AwsRegion = "us-east-1"
WorkGroup = "primary"
)

func init() {
Expand All @@ -45,11 +44,10 @@ func init() {
if v := os.Getenv("ATHENA_WORK_GROUP"); v != "" {
WorkGroup = v
}
AutoOutputLocation = os.Getenv("ATHENA_AUTO_OUTPUT_LOCATION") == "1"
}

func TestQuery(t *testing.T) {
harness := setup(t)
harness := setup(t, false)
defer harness.teardown()

expected := []dummyRow{
Expand Down Expand Up @@ -151,6 +149,38 @@ func TestQuery(t *testing.T) {
require.Equal(t, 3, index+1, fmt.Sprintf("row count. resultMode:%v", resultMode))
}
}
func TestQueryForUsingWorkGroup(t *testing.T) {
resultModes := []ResultMode{
ResultModeAPI,
ResultModeDL,
ResultModeGzipDL,
}

for _, resultMode := range resultModes {
t.Run(fmt.Sprintf("ResultMode:%v", resultMode), func(t *testing.T) {
harness := setup(t, true)
defer harness.teardown()

ctx := context.Background()
switch resultMode {
case ResultModeAPI:
ctx = SetAPIMode(ctx)
case ResultModeDL:
ctx = SetDLMode(ctx)
case ResultModeGzipDL:
ctx = SetGzipDLMode(ctx)
}

rows := harness.mustQuery(ctx, "select count(*) as cnt from %s", harness.table)
defer rows.Close()
var cnt int
for rows.Next() {
require.NoError(t, rows.Scan(&cnt))
assert.Equal(t, 0, cnt)
}
})
}
}

func TestOpen(t *testing.T) {
var acfg []*aws.Config
Expand All @@ -164,33 +194,40 @@ func TestOpen(t *testing.T) {
ResultModeGzipDL,
}

config := Config{
Session: session,
Database: AthenaDatabase,
WorkGroup: WorkGroup,
Timeout: timeOutLimitDefault,
}
if !AutoOutputLocation {
config.OutputLocation = fmt.Sprintf("s3:https://%s", S3Bucket)
s3Buckes := []string{
S3Bucket,
"",
}

for _, resultMode := range resultModes {
config.ResultMode = resultMode
db, err := Open(config)
require.NoError(t, err, fmt.Sprintf("Open. resultMode:%v", resultMode))
for _, s3Bucket := range s3Buckes {
config := Config{
Session: session,
Database: AthenaDatabase,
WorkGroup: WorkGroup,
Timeout: timeOutLimitDefault,
}
if s3Bucket != "" {
config.OutputLocation = fmt.Sprintf("s3:https://%s", s3Bucket)
}

ctx := context.Background()
_, err = db.QueryContext(ctx, "SELECT 1")
if resultMode == ResultModeGzipDL {
require.Error(t, err, "Query IN Gzip DL Mode")
} else {
require.NoError(t, err, fmt.Sprintf("Query IN resultMode:%v", resultMode))
for _, resultMode := range resultModes {
config.ResultMode = resultMode
db, err := Open(config)
require.NoError(t, err, fmt.Sprintf("Open. resultMode:%v", resultMode))

ctx := context.Background()
_, err = db.QueryContext(ctx, "SELECT 1")
if resultMode == ResultModeGzipDL {
require.Error(t, err, "Query IN Gzip DL Mode")
} else {
require.NoError(t, err, fmt.Sprintf("Query IN resultMode:%v", resultMode))
}
}
}
}

func TestDDLQuery(t *testing.T) {
harness := setup(t)
harness := setup(t, false)
defer harness.teardown()

rows := harness.mustQuery(context.Background(), "show tables")
Expand Down Expand Up @@ -232,7 +269,7 @@ type athenaHarness struct {
table string
}

func setup(t *testing.T) *athenaHarness {
func setup(t *testing.T, useWorkGroup bool) *athenaHarness {
var acfg []*aws.Config
acfg = append(acfg, &aws.Config{
Region: aws.String(AwsRegion),
Expand All @@ -244,7 +281,7 @@ func setup(t *testing.T) *athenaHarness {
harness := athenaHarness{t: t, sess: sess}

connStr := fmt.Sprintf("db=%s&output_location=s3:https://%s&region=%s", AthenaDatabase, S3Bucket, AwsRegion)
if AutoOutputLocation {
if useWorkGroup {
connStr = fmt.Sprintf("db=%s&region=%s&workgroup=%s", AthenaDatabase, AwsRegion, WorkGroup)
}

Expand Down

0 comments on commit c8c7a7c

Please sign in to comment.