Skip to content

Commit

Permalink
Support Migrator ColumnType interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 19, 2022
1 parent a3d0aa6 commit f666f06
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/ClickHouse/clickhouse-go v1.5.4 h1:cKjXeYLNWVJIx2J1K6H2CqyRmfwVJVY1OV1coaaFcI0=
github.com/ClickHouse/clickhouse-go v1.5.4/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI=
github.com/bkaradzic/go-lz4 v1.0.0 h1:RXc4wYsyz985CkXXeX04y4VnZFGG8Rd43pRaHsOXAKk=
github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg=
Expand Down
65 changes: 65 additions & 0 deletions migrator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clickhouse

import (
"database/sql"
"errors"
"fmt"
"strconv"
Expand Down Expand Up @@ -281,6 +282,70 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0
}

// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
}

defer func() {
err = rows.Close()
}()

var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()

columnTypeSQL := "SELECT name, type, default_expression, comment, is_in_primary_key, character_octet_length, numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision FROM system.columns WHERE database = ? AND table = ?"
columns, rowErr := m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows()
if rowErr != nil {
return rowErr
}

defer columns.Close()

for columns.Next() {
var (
column migrator.ColumnType
datetimePrecision sql.NullInt64
radixValue sql.NullInt64
values = []interface{}{
&column.NameValue, &column.DataTypeValue, &column.DefaultValueValue, &column.CommentValue, &column.PrimayKeyValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision,
}
)

if scanErr := columns.Scan(values...); scanErr != nil {
return scanErr
}

column.ColumnTypeValue = column.DataTypeValue

if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}

if column.DefaultValueValue.Valid {
column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'")
}

for _, c := range rawColumnTypes {
if c.Name() == column.NameValue.String {
column.SQLColumnType = c
break
}
}

columnTypes = append(columnTypes, column)
}

return
})

return columnTypes, execErr
}

// Indexes

func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
Expand Down
47 changes: 42 additions & 5 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type User struct {
ID uint
ID uint `gorm:"primaryKey"`
Name string
FirstName string
LastName string
Expand All @@ -19,10 +19,13 @@ type User struct {

func TestAutoMigrate(t *testing.T) {
type UserMigrateColumn struct {
ID uint
Name string
IsAdmin bool
Birthday time.Time `gorm:"precision:4"`
ID uint
Name string
IsAdmin bool
Birthday time.Time `gorm:"precision:4"`
Debit float64 `gorm:"precision:4"`
Note string `gorm:"size:10;comment:my note"`
DefaultValue string `gorm:"default:hello world"`
}

if DB.Migrator().HasColumn("users", "is_admin") {
Expand All @@ -40,4 +43,38 @@ func TestAutoMigrate(t *testing.T) {
if !DB.Migrator().HasColumn("users", "is_admin") {
t.Fatalf("users's is_admin column should exists after auto migrate")
}

columnTypes, err := DB.Migrator().ColumnTypes("users")
if err != nil {
t.Fatalf("failed to get column types, got error %v", err)
}

for _, columnType := range columnTypes {
switch columnType.Name() {
case "id":
if columnType.DatabaseTypeName() != "UInt64" {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "note":
if length, ok := columnType.Length(); !ok || length != 10 {
t.Fatalf("column name length should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}

if comment, ok := columnType.Comment(); !ok || comment != "my note" {
t.Fatalf("column name length should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "default_value":
if defaultValue, ok := columnType.DefaultValue(); !ok || defaultValue != "hello world" {
t.Fatalf("column name default_value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "debit":
if decimal, scale, ok := columnType.DecimalSize(); !ok || scale != 0 || decimal != 4 {
t.Fatalf("column name debit should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "birthday":
if decimal, scale, ok := columnType.DecimalSize(); !ok || scale != 0 || decimal != 4 {
t.Fatalf("column name birthday should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
}
}
}

0 comments on commit f666f06

Please sign in to comment.