diff --git a/buffer.go b/buffer.go index 191c1485..e44e94ed 100644 --- a/buffer.go +++ b/buffer.go @@ -51,35 +51,45 @@ func (b *buffer) fill(need int) (err error) { return } -// read len(p) bytes -func (b *buffer) read(p []byte) (err error) { - need := len(p) +// returns next N bytes from buffer. +// The returned slice is only guaranteed to be valid until the next read +func (b *buffer) readNext(need int) (p []byte, err error) { + // return slice from buffer if possible + if b.length >= need { + p = b.buf[b.idx : b.idx+need] + b.idx += need + b.length -= need + return - if b.length < need { + } else { + p = make([]byte, need) + has := 0 + + // copy data that is already in the buffer if b.length > 0 { copy(p[0:b.length], b.buf[b.idx:]) - need -= b.length - p = p[b.length:] - + has = b.length + need -= has b.idx = 0 b.length = 0 } - if need >= len(b.buf) { + // does the data fit into the buffer? + if need < len(b.buf) { + err = b.fill(need) // err deferred + copy(p[has:has+need], b.buf[b.idx:]) + b.idx += need + b.length -= need + return + + } else { var n int - has := 0 - for err == nil && need > has { + for err == nil && need > 0 { n, err = b.rd.Read(p[has:]) has += n + need -= n } - return } - - err = b.fill(need) // err deferred } - - copy(p, b.buf[b.idx:]) - b.idx += need - b.length -= need return } diff --git a/packets.go b/packets.go index 7c56245d..27863948 100644 --- a/packets.go +++ b/packets.go @@ -26,15 +26,14 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() (data []byte, err error) { // Read packet header - data = make([]byte, 4) - err = mc.buf.read(data) + data, err = mc.buf.readNext(4) if err != nil { errLog.Print(err.Error()) return nil, driver.ErrBadConn } // Packet Length [24 bit] - pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) if pktLen < 1 { errLog.Print(errMalformPkt.Error()) @@ -52,8 +51,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) { mc.sequence++ // Read packet body [pktLen bytes] - data = make([]byte, pktLen) - err = mc.buf.read(data) + data, err = mc.buf.readNext(pktLen) if err == nil { if pktLen < maxPacketSize { return data, nil