diff --git a/AUTHORS b/AUTHORS index 2caa7d706..954e7ac7a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -39,6 +39,7 @@ Evan Elias Evan Shaw Frederick Mayle Gustavo Kristic +Gusted Hajime Nakagami Hanno Braun Henri Yandell diff --git a/auth.go b/auth.go index bab282bd2..658259b24 100644 --- a/auth.go +++ b/auth.go @@ -13,10 +13,13 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/pem" "fmt" "sync" + + "filippo.io/edwards25519" ) // server pub keys registry @@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } +// authEd25519 does ed25519 authentication used by MariaDB. +func authEd25519(scramble []byte, password string) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 + h := sha512.Sum512([]byte(password)) + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + A := (&edwards25519.Point{}).ScalarBaseMult(s) + + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(scramble) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + return nil, err + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(scramble) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + return nil, err + } + + S := k.MultiplyAdd(k, s, r) + + return append(R.Bytes(), S.Bytes()...), nil +} + func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) if err != nil { @@ -290,6 +331,12 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) return enc, err + case "client_ed25519": + if len(authData) != 32 { + return nil, ErrMalformPkt + } + return authEd25519(authData, mc.cfg.Passwd) + default: mc.cfg.Logger.Print("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin diff --git a/auth_test.go b/auth_test.go index 3ce0ea6e0..8caed1fff 100644 --- a/auth_test.go +++ b/auth_test.go @@ -1328,3 +1328,54 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { t.Errorf("got unexpected data: %v", conn.written) } } + +// Derived from https://github.com/MariaDB/server/blob/6b2287fff23fbdc362499501c562f01d0d2db52e/plugin/auth_ed25519/ed25519-t.c +func TestEd25519Auth(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "foobar" + + authData := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + plugin := "client_ed25519" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{ + 232, 61, 201, 63, 67, 63, 51, 53, 86, 73, 238, 35, 170, 117, 146, + 214, 26, 17, 35, 9, 8, 132, 245, 141, 48, 99, 66, 58, 36, 228, 48, + 84, 115, 254, 187, 168, 88, 162, 249, 57, 35, 85, 79, 238, 167, 106, + 68, 117, 56, 135, 171, 47, 20, 14, 133, 79, 15, 229, 124, 160, 176, + 100, 138, 14, + } + if writtenAuthRespLen != 64 { + t.Fatalf("expected 64 bytes from client, got %d", writtenAuthRespLen) + } + if !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("auth response did not match expected value:\n%v\n%v", writtenAuthResp, expectedAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} diff --git a/driver_test.go b/driver_test.go index 87892a09a..97fd5a17a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -165,14 +165,14 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { for _, test := range tests { t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} + defer dbt.db.Exec("DROP TABLE IF EXISTS test") test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} + defer dbt2.db.Exec("DROP TABLE IF EXISTS test") test(dbt2) - dbt2.db.Exec("DROP TABLE IF EXISTS test") }) } } @@ -3181,14 +3181,14 @@ func TestRawBytesAreNotModified(t *testing.T) { rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) if err != nil { - t.Fatal(err) + dbt.Fatal(err) } var b int var raw sql.RawBytes for rows.Next() { if err := rows.Scan(&b, &raw); err != nil { - t.Fatal(err) + dbt.Fatal(err) } before := string(raw) @@ -3198,7 +3198,7 @@ func TestRawBytesAreNotModified(t *testing.T) { after := string(raw) if before != after { - t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } } rows.Close() diff --git a/go.mod b/go.mod index 77bbb8dbf..4629714c0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/go-sql-driver/mysql go 1.18 + +require filippo.io/edwards25519 v1.1.0 diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..359ca94b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=