Skip to content

Commit

Permalink
small code optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
julienschmidt committed Mar 3, 2013
1 parent d571cda commit 74a6452
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 73 deletions.
7 changes: 6 additions & 1 deletion buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ func (b *buffer) fill(need int) (err error) {
b.length = 0

n := 0
for err == nil && b.length < need {
for b.length < need {
n, err = b.rd.Read(b.buf[b.length:])
b.length += n

if err == nil {
continue
}
return // err
}

return
Expand Down
6 changes: 3 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}

if columnCount > 0 {
_, err = stmt.mc.readUntilEOF()
err = stmt.mc.readUntilEOF()
}
}

Expand Down Expand Up @@ -159,12 +159,12 @@ func (mc *mysqlConn) exec(query string) (err error) {
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
_, err = mc.readUntilEOF()
err = mc.readUntilEOF()
if err != nil {
return
}

_, err = mc.readUntilEOF()
err = mc.readUntilEOF()
}

return
Expand Down
2 changes: 1 addition & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
package mysql

const (
MIN_PROTOCOL_VERSION = 10
MIN_PROTOCOL_VERSION byte = 10
//MAX_PACKET_SIZE = 1<<24 - 1
TIME_FORMAT = "2006-01-02 15:04:05"
)
Expand Down
125 changes: 60 additions & 65 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
data = make([]byte, 4)
err = mc.buf.read(data)
if err != nil {
errLog.Print(err)
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}

Expand All @@ -40,7 +40,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
pktLen |= uint32(data[2]) << 16

if pktLen == 0 {
return nil, err
errLog.Print(errMalformPkt.Error())
return nil, driver.ErrBadConn
}

// Check Packet Sync
Expand All @@ -59,7 +60,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
if err == nil {
return data, nil
}
errLog.Print(err)
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}

Expand All @@ -74,9 +75,9 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}

if err == nil { // n != len(data)
errLog.Print(errMalformPkt)
errLog.Print(errMalformPkt.Error())
} else {
errLog.Print(err)
errLog.Print(err.Error())
}
return driver.ErrBadConn
}
Expand All @@ -103,7 +104,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {

// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + (bytes.IndexByte(data[1:], 0x00) + 1) + 4
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4

// first part of scramble buffer [8 bytes]
mc.scrambleBuff = data[pos : pos+8]
Expand Down Expand Up @@ -287,45 +288,43 @@ func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) e
// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
data, err := mc.readPacket()
if err != nil {
return err
}

switch data[0] {
// OK
case 0:
mc.handleOkPacket(data)
return nil
// EOF, someone is using old_passwords
case 254:
return errOldPassword
if err == nil {
switch data[0] {
// OK
case 0:
mc.handleOkPacket(data)
return nil
// EOF, someone is using old_passwords
case 254:
return errOldPassword
}
// ERROR
return mc.handleErrorPacket(data)
}
// ERROR
return mc.handleErrorPacket(data)
return err
}

// Result Set Header Packet
// http:https://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
if err != nil {
return 0, err
}
if err == nil {
if data[0] == 0 {
mc.handleOkPacket(data)
return 0, nil
} else if data[0] == 255 {
return 0, mc.handleErrorPacket(data)
}

if data[0] == 0 {
mc.handleOkPacket(data)
return 0, nil
} else if data[0] == 255 {
return 0, mc.handleErrorPacket(data)
}
// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
}

// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
return 0, errMalformPkt
}

return 0, errMalformPkt
return 0, err
}

// Error Packet
Expand Down Expand Up @@ -487,18 +486,17 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
}

// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
func (mc *mysqlConn) readUntilEOF() (err error) {
var data []byte

for {
data, err = mc.readPacket()

// Err or EOF Packet
if err != nil || (data[0] == 254 && len(data) == 5) {
return
// No Err and no EOF Packet
if err == nil && (data[0] != 254 || len(data) != 5) {
continue
}

count++
return
}
return
}
Expand All @@ -511,35 +509,32 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
// http:https://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
data, err := stmt.mc.readPacket()
if err != nil {
return
}

// Position
pos := 0

// packet marker [1 byte]
if data[pos] != 0 { // not OK (0) ?
err = stmt.mc.handleErrorPacket(data)
return
}
pos++
if err == nil {
// Position
pos := 0

// statement id [4 bytes]
stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4
// packet marker [1 byte]
if data[pos] != 0 { // not OK (0) ?
err = stmt.mc.handleErrorPacket(data)
return
}
pos++

// Column count [16 bit uint]
columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2
// statement id [4 bytes]
stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4

// Param count [16 bit uint]
stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
pos += 2
// Column count [16 bit uint]
columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2

// Warning count [16 bit uint]
// bytesToUint16(data[pos : pos+2])
// Param count [16 bit uint]
stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
pos += 2

// Warning count [16 bit uint]
// bytesToUint16(data[pos : pos+2])
}
return
}

Expand Down
2 changes: 1 addition & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (rows *mysqlRows) Close() (err error) {
return errors.New("Invalid Connection")
}

_, err = rows.mc.readUntilEOF()
err = rows.mc.readUntilEOF()
}

return
Expand Down
4 changes: 2 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if err == nil {
if resLen > 0 {
// Columns
_, err = stmt.mc.readUntilEOF()
err = stmt.mc.readUntilEOF()
if err != nil {
return nil, err
}

// Rows
_, err = stmt.mc.readUntilEOF()
err = stmt.mc.readUntilEOF()
}
if err == nil {
return &mysqlResult{
Expand Down

0 comments on commit 74a6452

Please sign in to comment.