Skip to content

Commit

Permalink
encoding/base64: reduce the overflow risk when computing encode/decod…
Browse files Browse the repository at this point in the history
…e length

Change-Id: I0a55cdc38ae496e2070f0b9ef317a41f82352afd
GitHub-Last-Rev: c19527a26b0778cbb4548f49e1e365102709f068
GitHub-Pull-Request: golang/go#61407
Reviewed-on: https://go-review.googlesource.com/c/go/+/510635
Reviewed-by: Ian Lance Taylor <[email protected]>
Run-TryBot: Ian Lance Taylor <[email protected]>
Auto-Submit: Ian Lance Taylor <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Ian Lance Taylor <[email protected]>
Reviewed-by: Heschi Kreinick <[email protected]>
  • Loading branch information
chanxuehong authored and gopherbot committed Jul 21, 2023
1 parent 050d4d3 commit 14adf4f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/encoding/base64/base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
}
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
}
Expand Down Expand Up @@ -623,7 +623,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding {
// Unpadded data may end with partial block of 2-3 characters.
return n * 6 / 8
return n/4*3 + n%4*6/8
}
// Padded base64 should always be a multiple of 4 characters in length.
return n / 4 * 3
Expand Down
44 changes: 34 additions & 10 deletions src/encoding/base64/base64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"errors"
"fmt"
"io"
"math"
"reflect"
"runtime/debug"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -262,11 +264,12 @@ func TestDecodeBounds(t *testing.T) {
}

func TestEncodedLen(t *testing.T) {
for _, tt := range []struct {
type test struct {
enc *Encoding
n int
want int
}{
want int64
}
tests := []test{
{RawStdEncoding, 0, 0},
{RawStdEncoding, 1, 2},
{RawStdEncoding, 2, 3},
Expand All @@ -278,19 +281,30 @@ func TestEncodedLen(t *testing.T) {
{StdEncoding, 3, 4},
{StdEncoding, 4, 8},
{StdEncoding, 7, 12},
} {
if got := tt.enc.EncodedLen(tt.n); got != tt.want {
}
// check overflow
switch strconv.IntSize {
case 32:
tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 357913942})
tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
case 64:
tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 1537228672809129302})
tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
}
for _, tt := range tests {
if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
}
}
}

func TestDecodedLen(t *testing.T) {
for _, tt := range []struct {
type test struct {
enc *Encoding
n int
want int
}{
want int64
}
tests := []test{
{RawStdEncoding, 0, 0},
{RawStdEncoding, 2, 1},
{RawStdEncoding, 3, 2},
Expand All @@ -299,8 +313,18 @@ func TestDecodedLen(t *testing.T) {
{StdEncoding, 0, 0},
{StdEncoding, 4, 3},
{StdEncoding, 8, 6},
} {
if got := tt.enc.DecodedLen(tt.n); got != tt.want {
}
// check overflow
switch strconv.IntSize {
case 32:
tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 268435456})
tests = append(tests, test{RawStdEncoding, math.MaxInt, 1610612735})
case 64:
tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 1152921504606846976})
tests = append(tests, test{RawStdEncoding, math.MaxInt, 6917529027641081855})
}
for _, tt := range tests {
if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
}
}
Expand Down

0 comments on commit 14adf4f

Please sign in to comment.