Skip to content

Commit

Permalink
getting location from workgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
muroon committed Jul 4, 2021
1 parent daf154b commit 76d7de7
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 18 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ The tests support a few environment variables:
- `ATHENA_DATABASE` can be used to override the default database "go_athena_tests"
- `S3_BUCKET` can be used to override the default S3 bucket of "go-athena-tests"
- `ATHENA_REGION` or `AWS_DEFAULT_REGION` can be used to override the default region of "us-east-1"
- `ATHENA_WORK_GROUP` can be used to override the default workgroup of "primary"
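- `ATHENA_AUTO_OUTPUT_LOCATION` controls whether the tests take the output location from the workgroup instead of passing one explicitly
- 0 (default value): the tests pass an explicit output location built from `S3_BUCKET`
- 1: the tests omit the output location, so it is read from the query result location configured in the workgroup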


[database/sql]: https://golang.org/pkg/database/sql/
Expand Down
9 changes: 9 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
catalog = cat
}

// output location (with empty value)
if checkOutputLocation(resultMode, c.OutputLocation) {
var err error
c.OutputLocation, err = getOutputLocation(c.athena, c.workgroup)
if err != nil {
return nil, err
}
}

// mode ctas
var ctasTable string
var afterDownload func() error
Expand Down
40 changes: 27 additions & 13 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
)

var (
AthenaDatabase = "go_athena_tests"
S3Bucket = "go-athena-tests"
AwsRegion = "us-east-1"
AthenaDatabase = "go_athena_tests"
S3Bucket = "go-athena-tests"
AwsRegion = "us-east-1"
WorkGroup = "primary"
AutoOutputLocation = false
)

func init() {
Expand All @@ -40,6 +42,10 @@ func init() {
if v := os.Getenv("ATHENA_REGION"); v != "" {
AwsRegion = v
}
if v := os.Getenv("ATHENA_WORK_GROUP"); v != "" {
WorkGroup = v
}
AutoOutputLocation = os.Getenv("ATHENA_AUTO_OUTPUT_LOCATION") == "1"
}

func TestQuery(t *testing.T) {
Expand Down Expand Up @@ -158,16 +164,19 @@ func TestOpen(t *testing.T) {
ResultModeGzipDL,
}

config := Config{
Session: session,
Database: AthenaDatabase,
WorkGroup: WorkGroup,
Timeout: timeOutLimitDefault,
}
if !AutoOutputLocation {
config.OutputLocation = fmt.Sprintf("s3:https://%s", S3Bucket)
}

for _, resultMode := range resultModes {
db, err := Open(Config{
Session: session,
Database: AthenaDatabase,
OutputLocation: fmt.Sprintf("s3:https://%s", S3Bucket),

ResultMode: resultMode,
WorkGroup: "primary",
Timeout: timeOutLimitDefault,
})
config.ResultMode = resultMode
db, err := Open(config)
require.NoError(t, err, fmt.Sprintf("Open. resultMode:%v", resultMode))

ctx := context.Background()
Expand Down Expand Up @@ -234,7 +243,12 @@ func setup(t *testing.T) *athenaHarness {
}
harness := athenaHarness{t: t, sess: sess}

harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3:https://%s&region=%s", AthenaDatabase, S3Bucket, AwsRegion))
connStr := fmt.Sprintf("db=%s&output_location=s3:https://%s&region=%s", AthenaDatabase, S3Bucket, AwsRegion)
if AutoOutputLocation {
connStr = fmt.Sprintf("db=%s&region=%s&workgroup=%s", AthenaDatabase, AwsRegion, WorkGroup)
}

harness.db, err = sql.Open("athena", connStr)
require.NoError(t, err)

harness.setupTable()
Expand Down
38 changes: 33 additions & 5 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
"net/url"
"strconv"
"strings"
Expand Down Expand Up @@ -85,8 +86,20 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
cfg.PollFrequency = 5 * time.Second
}

// athena client
athenaClient := athena.New(cfg.Session)

// output location (with empty value)
if checkOutputLocation(cfg.ResultMode, cfg.OutputLocation) {
var err error
cfg.OutputLocation, err = getOutputLocation(athenaClient, cfg.WorkGroup)
if err != nil {
return nil, err
}
}

return &conn{
athena: athena.New(cfg.Session),
athena: athenaClient,
db: cfg.Database,
OutputLocation: cfg.OutputLocation,
pollFrequency: cfg.PollFrequency,
Expand All @@ -106,10 +119,6 @@ func Open(cfg Config) (*sql.DB, error) {
return nil, errors.New("db is required")
}

if cfg.OutputLocation == "" {
return nil, errors.New("s3_staging_url is required")
}

if cfg.Session == nil {
return nil, errors.New("session is required")
}
Expand Down Expand Up @@ -198,3 +207,22 @@ func configFromConnectionString(connStr string) (*Config, error) {

return &cfg, nil
}

// checkOutputLocation reports whether the output location still has to be
// looked up from the workgroup. That is the case only when no output location
// was supplied and the result mode is anything other than ResultModeAPI.
func checkOutputLocation(resultMode ResultMode, outputLocation string) bool {
	// An explicitly configured location always wins; nothing to look up.
	if outputLocation != "" {
		return false
	}
	return resultMode != ResultModeAPI
}

// getOutputLocation returns the query result location configured on the given
// workgroup, for use when no output location was supplied by the caller.
//
// It calls GetWorkGroup on athenaClient and returns that call's error
// unchanged if it fails. If the workgroup has no result location configured
// (any part of the response's configuration chain is nil, or the location is
// empty), it returns an error instead of dereferencing a nil pointer.
// On any error the returned string is empty.
func getOutputLocation(athenaClient athenaiface.AthenaAPI, workGroup string) (string, error) {
	output, err := athenaClient.GetWorkGroup(
		&athena.GetWorkGroupInput{
			WorkGroup: aws.String(workGroup),
		},
	)
	if err != nil {
		return "", err
	}

	// Every level of the response is an optional pointer in the SDK, so walk
	// the chain defensively rather than dereferencing it in one expression.
	if output == nil ||
		output.WorkGroup == nil ||
		output.WorkGroup.Configuration == nil ||
		output.WorkGroup.Configuration.ResultConfiguration == nil {
		return "", fmt.Errorf("workgroup %q has no query result location configured", workGroup)
	}

	// aws.StringValue maps a nil *string to "", which is treated the same as
	// an unset location.
	outputLocation := aws.StringValue(output.WorkGroup.Configuration.ResultConfiguration.OutputLocation)
	if outputLocation == "" {
		return "", fmt.Errorf("workgroup %q has no query result location configured", workGroup)
	}
	return outputLocation, nil
}
4 changes: 4 additions & 0 deletions rows_gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func (r *rowsGzipDL) downloadCompressedDataAsync(
}

func (r *rowsGzipDL) downloadCompressedData(sess *session.Session, location string) error {
if location[len(location)-1:] == "/" {
location = location[:len(location)-1]
}

// remove the first 5 characters "s3://" from location
bucketName := location[5:]

Expand Down

0 comments on commit 76d7de7

Please sign in to comment.