Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for sql databases #4

Merged
merged 9 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
finish mssql adapter
  • Loading branch information
anbraten committed May 2, 2022
commit 2608a9a8318b6b94607b908b1cbf431d1a6f8d76
41 changes: 29 additions & 12 deletions adapters/mssql_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,55 +14,68 @@ type mssqlAdapter struct {

func (adapter mssqlAdapter) HasDatabase(ctx context.Context, database string) (bool, error) {
var count int
query := fmt.Sprintf("SELECT COUNT(*) FROM master.sys.databases WHERE name=@p1")
err := adapter.db.QueryRowContext(ctx, query, database).Scan(&count)
query := fmt.Sprintf("SELECT COUNT(*) FROM master.sys.databases WHERE name='%s';", database)
err := adapter.db.QueryRowContext(ctx, query).Scan(&count)
if err != nil {
return false, err
}
return count == 1, nil
}

func (adapter mssqlAdapter) CreateDatabase(ctx context.Context, database string) error {
_, err := adapter.db.ExecContext(ctx, "CREATE DATABASE @p1;", database)
query := fmt.Sprintf("EXEC ('sp_configure ''contained database authentication'', 1; reconfigure;');")
_, err := adapter.db.ExecContext(ctx, query)
if err != nil {
return err
}

query = fmt.Sprintf("CREATE DATABASE [%s] CONTAINMENT=PARTIAL;", database)
_, err = adapter.db.ExecContext(ctx, query)
return err
}

func (adapter mssqlAdapter) DeleteDatabase(ctx context.Context, database string) error {
_, err := adapter.db.ExecContext(ctx, "DROP DATABASE @p1;", database)
query := fmt.Sprintf("DROP DATABASE [%s];", database)
_, err := adapter.db.ExecContext(ctx, query)
return err
}

func (adapter mssqlAdapter) HasDatabaseUserWithAccess(ctx context.Context, database string, username string) (bool, error) {
// TODO implement
return false, nil
var count int
query := fmt.Sprintf("USE [%s]; SELECT COUNT(*) FROM sys.database_principals WHERE authentication_type=2 AND name='%s';", database, username)
err := adapter.db.QueryRowContext(ctx, query).Scan(&count)
if err != nil {
return false, err
}
return count == 1, nil
}

func (adapter mssqlAdapter) CreateDatabaseUser(ctx context.Context, database string, username string, password string) error {
// make password sql safe
quotedPassword := QuoteLiteral(password)
query := fmt.Sprintf("CREATE USER %s WITH PASSWORD = %s", username, quotedPassword)
query := fmt.Sprintf("USE [%s]; CREATE USER [%s] WITH PASSWORD=%s", database, username, quotedPassword)
_, err := adapter.db.ExecContext(ctx, query)
if err != nil {
return err
}

query = fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", database, username)
query = fmt.Sprintf("USE [%s]; ALTER ROLE db_owner ADD MEMBER [%s];", database, username)
_, err = adapter.db.ExecContext(ctx, query)

return err
}

func (adapter mssqlAdapter) DeleteDatabaseUser(ctx context.Context, database string, username string) error {
// TODO implement
return nil
query := fmt.Sprintf("USE [%s]; DROP USER %s;", database, username)
_, err := adapter.db.ExecContext(ctx, query)
return err
}

func (adapter mssqlAdapter) Close(ctx context.Context) error {
return adapter.db.Close()
}

func GetMssqlConnection(ctx context.Context, url string) (*mssqlAdapter, error) {
db, err := sql.Open("mssql", url)
db, err := sql.Open("sqlserver", url)
if err != nil {
return nil, err
}
Expand All @@ -71,5 +84,9 @@ func GetMssqlConnection(ctx context.Context, url string) (*mssqlAdapter, error)
db: db,
}

if err := adapter.db.PingContext(ctx); err != nil {
return nil, err
}

return &adapter, nil
}
16 changes: 12 additions & 4 deletions adapters/mssql_adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,29 @@ func TestMsSqlDB(t *testing.T) {
databasePort := "1433"

ctx := context.Background()
url := fmt.Sprintf("mssql:https://%s:%s@%s:%s", "sa", "pA_sw0rd", databaseHost, databasePort)
adapter, err := adapters.GetCouchdbConnection(ctx, url)
url := fmt.Sprintf("sqlserver:https://%s:%s@%s:%s", "sa", "pA_sw0rd", databaseHost, databasePort)
adapter, err := adapters.GetMssqlConnection(ctx, url)
if err != nil {
t.Fatalf("Error opening database connection: %s", err)
}

clientConnectTest := func(ctx context.Context, databaseName string, databaseUsername string, databasePassword string) error {
url := fmt.Sprintf("sqlserver:https://%s:%s@%s:%s", databaseUsername, databasePassword, databaseHost, databasePort)
url := fmt.Sprintf("sqlserver:https://%s:%s@%s:%s?database=%s", databaseUsername, databasePassword, databaseHost, databasePort, databaseName)
client, err := sql.Open("sqlserver", url)
if err != nil {
return err
}

if err = client.PingContext(ctx); err != nil {
return err
}

_, err = client.ExecContext(ctx, "CREATE TABLE test (id int);")
return err
if err != nil {
return err
}

return client.Close()
}

testHelper(t, ctx, adapter, clientConnectTest)
Expand Down