Skip to content

Commit

Permalink
Assign burst to limit, add ReadCloser, fix fujiwara#4
Browse files Browse the repository at this point in the history
  • Loading branch information
viciious committed Jan 24, 2023
1 parent b708d43 commit f1d5c7e
Showing 1 changed file with 51 additions and 16 deletions.
67 changes: 51 additions & 16 deletions shapeio.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import (
"golang.org/x/time/rate"
)

const burstLimit = 1000 * 1000 * 1000

type Reader struct {
r io.Reader
r io.ReadCloser
limiter *rate.Limiter
ctx context.Context
}
Expand All @@ -25,13 +23,29 @@ type Writer struct {
// NewReader returns a reader that implements io.Reader with rate limiting.
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
r: io.NopCloser(r),
ctx: context.Background(),
}
}

// NewReaderWithContext returns a reader that implements io.Reader with rate limiting.
func NewReaderWithContext(r io.Reader, ctx context.Context) *Reader {
return &Reader{
r: io.NopCloser(r),
ctx: ctx,
}
}

// NewReadCloser returns a reader that implements io.ReadCloser with rate limiting.
func NewReadCloser(r io.ReadCloser) *Reader {
return &Reader{
r: r,
ctx: context.Background(),
}
}

// NewReadCloserWithContext returns a reader that implements io.ReadCloser with rate limiting.
func NewReadCloserWithContext(r io.ReadCloser, ctx context.Context) *Reader {
return &Reader{
r: r,
ctx: ctx,
Expand All @@ -56,16 +70,22 @@ func NewWriterWithContext(w io.Writer, ctx context.Context) *Writer {

// SetRateLimit sets rate limit (bytes/sec) to the reader.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), int(bytesPerSec))
s.limiter.AllowN(time.Now(), int(bytesPerSec)) // spend initial burst
}

// Read reads bytes into p.
func (s *Reader) Read(p []byte) (int, error) {
if s.limiter == nil {
return s.r.Read(p)
}
n, err := s.r.Read(p)

limit := int(s.limiter.Limit())
if limit > len(p) {
limit = len(p)
}

n, err := s.r.Read(p[:limit])
if err != nil {
return n, err
}
Expand All @@ -75,23 +95,38 @@ func (s *Reader) Read(p []byte) (int, error) {
return n, nil
}

// Read closes the reader
func (s *Reader) Close() error {
return s.r.Close()
}

// SetRateLimit sets rate limit (bytes/sec) to the writer.
func (s *Writer) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), int(bytesPerSec))
s.limiter.AllowN(time.Now(), int(bytesPerSec)) // spend initial burst
}

// Write writes bytes from p.
func (s *Writer) Write(p []byte) (int, error) {
if s.limiter == nil {
return s.w.Write(p)
}
n, err := s.w.Write(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err

for i := 0; i < len(p); {
rem := len(p) - i
limit := int(s.limiter.Limit())
if limit > rem {
limit = rem
}

n, err := s.w.Write(p[i : i+limit])
if err != nil {
return i + n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return i + n, err
}
i += limit
}
return n, err
return len(p), nil
}

0 comments on commit f1d5c7e

Please sign in to comment.