Skip to content

Commit

Permalink
Add hive db-api in codegen (sql-machine-learning#304)
Browse files Browse the repository at this point in the history
* hive parse config

* support hive in codegen

* add more args for codegen

* remove dup \n in codegen

* fix doc

* add hive dependency in dockerfile

* remove pip env from dockerfile
  • Loading branch information
weiguoz committed Apr 28, 2019
1 parent 67ba1fb commit d467d76
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 25 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ FROM ubuntu:16.04

RUN apt-get update
RUN apt-get install -y python3-pip

RUN pip3 install --upgrade pip
RUN pip3 install tensorflow mysql-connector-python jupyter sqlflow
RUN pip3 install tensorflow mysql-connector-python pyhive jupyter sqlflow
# Fix jupyter server "connecting to kernel" problem
# https://github.com/jupyter/notebook/issues/2664#issuecomment-468954423
RUN pip3 install tornado==4.5.3
Expand Down
1 change: 1 addition & 0 deletions Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ RUN pip3 install --upgrade pip
RUN pip3 install grpcio-tools
RUN pip3 install tensorflow
RUN pip3 install mysql-connector-python
RUN pip3 install pyhive

RUN wget --quiet https://dl.google.com/go/go1.11.5.linux-amd64.tar.gz
RUN tar -C /usr/local -xzf go1.11.5.linux-amd64.tar.gz
Expand Down
2 changes: 1 addition & 1 deletion doc/build.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Canonical Development Environment

Referring to [this example](https://github.com/sql-machine-learning/canonicalize-go-python-grpc-dev-env),
Referring to [this example](https://github.com/wangkuiyi/canonicalize-go-python-grpc-dev-env),
we create a canonical development environment for Go and Python programmers using Docker.

## Editing on Host
Expand Down
2 changes: 1 addition & 1 deletion doc/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ After data is populated in MySQL, let's test the installation by running a query

```
docker run --rm -it -p 8888:8888 sqlflow/sqlflow:latest \
bash -c "sqlflowserver --datasource="mysql:https://root:root@tcp(host.docker.internal:3306)/?maxAllowedPacket=0" &
bash -c "sqlflowserver --datasource='mysql:https://root:root@tcp(host.docker.internal:3306)/?maxAllowedPacket=0' &
SQLFLOW_SERVER=localhost:50051 jupyter notebook --ip=0.0.0.0 --port=8888 --allow-root"
```

Expand Down
60 changes: 40 additions & 20 deletions sql/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"text/template"

"github.com/go-sql-driver/mysql"
"sqlflow.org/gohive"
)

// TODO(tonyyang): This is currently a quick hack to map from SQL
Expand All @@ -25,6 +26,7 @@ type connectionConfig struct {
Host string
Port string
Database string
Auth string
}

type modelConfig struct {
Expand Down Expand Up @@ -74,20 +76,27 @@ func newFiller(pr *extendedSelect, fts fieldTypes, db *DB) (*filler, error) {
r.TableName = strings.Join(strings.Split(pr.into, ".")[:2], ".")
}

r.Driver = db.driverName
switch db.driverName {
case "mysql":
cfg, err := mysql.ParseDSN(db.dataSourceName)
if err != nil {
return nil, err
}
r.Driver = "mysql"
r.User = cfg.User
r.Password = cfg.Passwd
r.Host = strings.Split(cfg.Addr, ":")[0]
r.Port = strings.Split(cfg.Addr, ":")[1]
sa := strings.Split(cfg.Addr, ":")
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
r.User, r.Password = cfg.User, cfg.Passwd
case "sqlite3":
r.Driver = "sqlite3"
r.Database = db.dataSourceName
case "hive":
cfg, err := gohive.ParseDSN(db.dataSourceName)
if err != nil {
return nil, err
}
sa := strings.Split(cfg.Addr, ":")
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
// FIXME(weiguo): make gohive support Auth
r.User, r.Password, r.Auth = cfg.User, cfg.Passwd, "NOSALS"
default:
return nil, fmt.Errorf("sqlfow currently doesn't support DB %v", db.driverName)
}
Expand All @@ -113,7 +122,14 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys, json
import tensorflow as tf
{{if eq .Driver "mysql"}}
import mysql.connector
{{else if eq .Driver "sqlite3"}}
import sqlite3
{{else if eq .Driver "hive"}}
from pyhive import hive
{{end}}
# Disable Tensorflow INFO and WARNING logs
tf.logging.set_verbosity(tf.logging.ERROR)
Expand All @@ -126,25 +142,31 @@ STEP = 1000
{{if eq .Driver "mysql"}}
db = mysql.connector.connect(user="{{.User}}",
passwd="{{.Password}}",
{{if ne .Database ""}}database="{{.Database}}",{{end}}
host="{{.Host}}",
port={{.Port}}{{if eq .Database ""}}{{- else}}, database="{{.DATABASE}}"{{end}})
{{else}}
{{if eq .Driver "sqlite3"}}
port={{.Port}})
{{else if eq .Driver "sqlite3"}}
db = sqlite3.connect({{.Database}})
{{else if eq .Driver "hive"}}
hive.connect(username="{{.User}}",
password="{{.Password}}",
{{if ne .Database ""}}database="{{.Database}}",{{end}}
auth="{{.Auth}}",
host="{{.Host}}",
port={{.Port}})
{{else}}
raise ValueError("unrecognized database driver: {{.Driver}}")
{{end}}
{{end}}
cursor = db.cursor()
cursor.execute("""{{.StandardSelect}}""")
field_names = [i[0] for i in cursor.description]
columns = list(map(list, zip(*cursor.fetchall())))
feature_columns = [{{range .X}}tf.feature_column.{{.Type}}(key="{{.Name}}"),
{{end}}]
{{end}}]
feature_column_names = [{{range .X}}"{{.Name}}",
{{end}}]
{{end}}]
X = {name: columns[field_names.index(name)] for name in feature_column_names}
{{if .Train}}
Expand All @@ -154,7 +176,7 @@ Y = columns[field_names.index("{{.Y.Name}}")]
classifier = tf.estimator.{{.Estimator}}(
feature_columns=feature_columns,{{range $key, $value := .Attrs}}
{{$key}} = {{$value}},{{end}}
model_dir= "{{.Save}}")
model_dir = "{{.Save}}")
{{if .Train}}
def train_input_fn(features, labels, batch_size):
Expand All @@ -172,10 +194,8 @@ def eval_input_fn(features, labels, batch_size):
return dataset
eval_result = classifier.evaluate(
input_fn=lambda:eval_input_fn(X, Y, BATCHSIZE),
steps=STEP)
print("\nTraining set accuracy: {accuracy:0.5f}\n".format(**eval_result))
input_fn=lambda:eval_input_fn(X, Y, BATCHSIZE), steps=STEP)
print("Training set accuracy: {accuracy:0.5f}".format(**eval_result))
print("Done training")
{{- else}}
def eval_input_fn(features, batch_size):
Expand All @@ -194,8 +214,8 @@ def insert(table_name, X, db):
field_names = [key for key in X]
# FIXME(tony): HIVE and ODPS use INSERT INTO TABLE ...
sql = "INSERT INTO {} ({}) VALUES ({})".format(
table_name, ",".join(field_names), ",".join(["%s" for _ in field_names]))
sql = "INSERT INTO {} ({}) VALUES ({})".format(table_name,
",".join(field_names), ",".join(["%s" for _ in field_names]))
val = []
for i in range(length[0]):
val.append(tuple([str(X[f][i]) for f in field_names]))
Expand All @@ -206,7 +226,7 @@ def insert(table_name, X, db):
insert("{{.TableName}}", X, db)
print("Done predicting. Predict Table : {{.TableName}}")
print("Done predicting. Predict table : {{.TableName}}")
{{- end}}
`

Expand Down
4 changes: 2 additions & 2 deletions sql/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func (ft fieldTypes) get(ident string) (string, bool) {
// identifier: t.f=>(t,f), db.t.f=>(db.t,f), f=>("",f).
func decomp(ident string) (tbl string, fld string) {
s := strings.Split(ident, ".")
return strings.Join(s[:len(s)-1], "."), s[len(s)-1]
n := len(s)
return strings.Join(s[:n-1], "."), s[n-1]
}

// Retrieve the type of fields mentioned in SELECT.
Expand Down Expand Up @@ -112,7 +113,6 @@ func describeTables(slct *extendedSelect, db *DB) (ft fieldTypes, e error) {
}
}
}

}
return ft, nil
}
Expand Down

0 comments on commit d467d76

Please sign in to comment.