Skip to content

Commit

Permalink
return byte slice directly from buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
julienschmidt committed Apr 21, 2013
1 parent 0e8690a commit 96a4f13
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
44 changes: 27 additions & 17 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,45 @@ func (b *buffer) fill(need int) (err error) {
return
}

// read len(p) bytes
func (b *buffer) read(p []byte) (err error) {
need := len(p)
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) (p []byte, err error) {
// return slice from buffer if possible
if b.length >= need {
p = b.buf[b.idx : b.idx+need]
b.idx += need
b.length -= need
return

if b.length < need {
} else {
p = make([]byte, need)
has := 0

// copy data that is already in the buffer
if b.length > 0 {
copy(p[0:b.length], b.buf[b.idx:])
need -= b.length
p = p[b.length:]

has = b.length
need -= has
b.idx = 0
b.length = 0
}

if need >= len(b.buf) {
// does the data fit into the buffer?
if need < len(b.buf) {
err = b.fill(need) // err deferred
copy(p[has:has+need], b.buf[b.idx:])
b.idx += need
b.length -= need
return

} else {
var n int
has := 0
for err == nil && need > has {
for err == nil && need > 0 {
n, err = b.rd.Read(p[has:])
has += n
need -= n
}
return
}

err = b.fill(need) // err deferred
}

copy(p, b.buf[b.idx:])
b.idx += need
b.length -= need
return
}
8 changes: 3 additions & 5 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ import (
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() (data []byte, err error) {
// Read packet header
data = make([]byte, 4)
err = mc.buf.read(data)
data, err = mc.buf.readNext(4)
if err != nil {
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}

// Packet Length [24 bit]
pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)

if pktLen < 1 {
errLog.Print(errMalformPkt.Error())
Expand All @@ -52,8 +51,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
mc.sequence++

// Read packet body [pktLen bytes]
data = make([]byte, pktLen)
err = mc.buf.read(data)
data, err = mc.buf.readNext(pktLen)
if err == nil {
if pktLen < maxPacketSize {
return data, nil
Expand Down

0 comments on commit 96a4f13

Please sign in to comment.