Skip to content

Commit

Permalink
Limit req/resp size and support stream bodies (#628)
Browse files Browse the repository at this point in the history
* limit req/resp size and allow body to be stream

* make some filters to aware of stream bodies

* fix panic

* update test

* update test

* fix random test fail

* make more filters aware stream

* fix random test failure

Co-authored-by: chen <[email protected]>
  • Loading branch information
localvar and suchen-sci committed May 19, 2022
1 parent f9da7a3 commit 1c329e2
Show file tree
Hide file tree
Showing 34 changed files with 774 additions and 394 deletions.
42 changes: 31 additions & 11 deletions pkg/filters/fallback/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package fallback

import (
"io"
"strconv"

"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/protocols/httpprot"
"github.com/megaease/easegress/pkg/protocols/httpprot/fallback"
)

const (
Expand Down Expand Up @@ -51,15 +53,18 @@ func init() {
type (
// Fallback is filter Fallback.
Fallback struct {
spec *Spec
f *fallback.Fallback
spec *Spec
mockBody []byte
bodyLength string
}

// Spec describes the Fallback.
Spec struct {
filters.BaseSpec `yaml:",inline"`

fallback.Spec `yaml:",inline"`
MockCode int `yaml:"mockCode" jsonschema:"required,format=httpcode"`
MockHeaders map[string]string `yaml:"mockHeaders" jsonschema:"omitempty"`
MockBody string `yaml:"mockBody" jsonschema:"omitempty"`
}
)

Expand Down Expand Up @@ -90,18 +95,32 @@ func (f *Fallback) Inherit(previousGeneration filters.Filter) {
}

func (f *Fallback) reload() {
f.f = fallback.New(&f.spec.Spec)
f.mockBody = []byte(f.spec.MockBody)
f.bodyLength = strconv.Itoa(len(f.mockBody))
}

// Handle fallbacks HTTPContext.
// It always returns fallback.
func (f *Fallback) Handle(ctx *context.Context) string {
resp := ctx.GetResponse(ctx.TargetResponseID())
if resp != nil {
f.f.Fallback(resp.(*httpprot.Response))
return resultFallback
resp := ctx.GetResponse(ctx.TargetResponseID()).(*httpprot.Response)
if resp == nil {
return resultResponseNotFound
}

resp.SetStatusCode(f.spec.MockCode)
resp.HTTPHeader().Set("Content-Length", f.bodyLength)
for key, value := range f.spec.MockHeaders {
resp.HTTPHeader().Set(key, value)
}
return resultResponseNotFound

if resp.IsStream() {
if c, ok := resp.GetPayload().(io.Closer); ok {
c.Close()
}
}

resp.SetPayload(f.mockBody)
return resultFallback
}

// Status returns Status.
Expand All @@ -110,4 +129,5 @@ func (f *Fallback) Status() interface{} {
}

// Close closes Fallback.
func (f *Fallback) Close() {}
func (f *Fallback) Close() {
}
13 changes: 8 additions & 5 deletions pkg/filters/headertojson/headertojson.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,19 @@ func (h *HeaderToJSON) Handle(ctx *context.Context) string {
return ""
}

reqBody, err := io.ReadAll(ctx.Request().GetPayload())
if err != nil {
req := ctx.Request()
if req.IsStream() {
return resultBodyReadErr
}

reqBody := req.RawPayload()

var body interface{}
if len(reqBody) == 0 {
body = headerMap
} else {
body, err = getNewBody(reqBody, headerMap)
if err != nil {
var err error
if body, err = getNewBody(reqBody, headerMap); err != nil {
return resultJSONEncodeDecodeErr
}
}
Expand All @@ -142,7 +144,8 @@ func (h *HeaderToJSON) Handle(ctx *context.Context) string {
if err != nil {
return resultJSONEncodeDecodeErr
}
ctx.Request().SetPayload(bodyBytes)

req.SetPayload(bodyBytes)
return ""
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/headertojson/headertojson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func init() {
func setRequest(t *testing.T, ctx *context.Context, id string, req *http.Request) {
httpreq, err := httpprot.NewRequest(req)
assert.Nil(t, err)
_, err = httpreq.FetchPayload()
err = httpreq.FetchPayload(1024 * 1024)
assert.Nil(t, err)
ctx.SetRequest(id, httpreq)
ctx.UseRequest(id, id)
Expand Down
17 changes: 15 additions & 2 deletions pkg/filters/httpbuilder/httpbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package httpbuilder
import (
"bytes"
"encoding/json"
"fmt"
"html/template"
"net/http"

Expand Down Expand Up @@ -170,22 +171,34 @@ func (r *response) YAMLBody() (interface{}, error) {
}

func prepareBuilderData(ctx *context.Context) (map[string]interface{}, error) {
var rawBody []byte

requests := make(map[string]*request)
responses := make(map[string]*response)

for k, v := range ctx.Requests() {
req := v.(*httpprot.Request)
if req.IsStream() {
rawBody = []byte(fmt.Sprintf("the body of request %s is a stream", k))
} else {
rawBody = req.RawPayload()
}
requests[k] = &request{
Request: req.Std(),
rawBody: req.RawPayload(),
rawBody: rawBody,
}
}

for k, v := range ctx.Responses() {
resp := v.(*httpprot.Response)
if resp.IsStream() {
rawBody = []byte(fmt.Sprintf("the body of response %s is a stream", k))
} else {
rawBody = resp.RawPayload()
}
responses[k] = &response{
Response: resp.Std(),
rawBody: resp.RawPayload(),
rawBody: rawBody,
}
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/filters/httpbuilder/httprequestbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,21 @@ func (rb *HTTPRequestBuilder) Handle(ctx *context.Context) (result string) {
return resultBuildErr
}

req, err := http.NewRequest(ri.Method, ri.URL, nil)
stdReq, err := http.NewRequest(ri.Method, ri.URL, http.NoBody)
if err != nil {
logger.Warnf("failed to create new request: %v", err)
return resultBuildErr
}

for k, vs := range ri.Headers {
for _, v := range vs {
req.Header.Add(k, v)
stdReq.Header.Add(k, v)
}
}

egreq, _ := httpprot.NewRequest(req)
egreq.SetPayload([]byte(ri.Body))
req, _ := httpprot.NewRequest(stdReq)
req.SetPayload([]byte(ri.Body))

ctx.SetRequest(ctx.TargetRequestID(), egreq)
ctx.SetRequest(ctx.TargetRequestID(), req)
return ""
}
16 changes: 8 additions & 8 deletions pkg/filters/httpbuilder/httprequestbuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ func TestRequestBody(t *testing.T) {

res := rb.Handle(ctx)
assert.Empty(res)
testReq := ctx.GetRequest("test").(*httpprot.Request).Std()
data, err := io.ReadAll(testReq.Body)
testReq := ctx.GetRequest("test").(*httpprot.Request)
data, err := io.ReadAll(testReq.GetPayload())
assert.Nil(err)
assert.Equal("body", string(data))
}
Expand All @@ -256,8 +256,8 @@ func TestRequestBody(t *testing.T) {

res := rb.Handle(ctx)
assert.Empty(res)
testReq := ctx.GetRequest("test").(*httpprot.Request).Std()
data, err := io.ReadAll(testReq.Body)
testReq := ctx.GetRequest("test").(*httpprot.Request)
data, err := io.ReadAll(testReq.GetPayload())
assert.Nil(err)
assert.Equal("body 123", string(data))
}
Expand All @@ -283,8 +283,8 @@ func TestRequestBody(t *testing.T) {

res := rb.Handle(ctx)
assert.Empty(res)
testReq := ctx.GetRequest("test").(*httpprot.Request).Std()
data, err := io.ReadAll(testReq.Body)
testReq := ctx.GetRequest("test").(*httpprot.Request)
data, err := io.ReadAll(testReq.GetPayload())
assert.Nil(err)
assert.Equal("body value1 value2", string(data))
}
Expand Down Expand Up @@ -313,8 +313,8 @@ field2: value2

res := rb.Handle(ctx)
assert.Empty(res)
testReq := ctx.GetRequest("test").(*httpprot.Request).Std()
data, err := io.ReadAll(testReq.Body)
testReq := ctx.GetRequest("test").(*httpprot.Request)
data, err := io.ReadAll(testReq.GetPayload())
assert.Nil(err)
assert.Equal("body value1 value2", string(data))
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/filters/httpbuilder/httpresponsebuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,19 @@ func (rb *HTTPResponseBuilder) Handle(ctx *context.Context) (result string) {
return resultBuildErr
}

resp := &http.Response{Header: http.Header{}}
resp.StatusCode = ri.StatusCode
stdResp := &http.Response{Header: http.Header{}, Body: http.NoBody}
stdResp.StatusCode = ri.StatusCode

for k, vs := range ri.Headers {
for _, v := range vs {
resp.Header.Add(k, v)
stdResp.Header.Add(k, v)
}
}

// build body
egresp, _ := httpprot.NewResponse(resp)
egresp.SetPayload([]byte(ri.Body))
resp, _ := httpprot.NewResponse(stdResp)
resp.SetPayload([]byte(ri.Body))

ctx.SetResponse(ctx.TargetResponseID(), egresp)
ctx.SetResponse(ctx.TargetResponseID(), resp)
return ""
}
18 changes: 9 additions & 9 deletions pkg/filters/httpbuilder/httpresponsebuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func getResponseBuilder(spec *HTTPResponseBuilderSpec) *HTTPResponseBuilder {

func setRequest(t *testing.T, ctx *context.Context, id string, req *http.Request) {
r, err := httpprot.NewRequest(req)
r.FetchPayload()
r.FetchPayload(1024 * 1024)
assert.Nil(t, err)
ctx.SetRequest(id, r)
}
Expand Down Expand Up @@ -150,8 +150,8 @@ func TestResponseBody(t *testing.T) {
ctx.UseResponse("test")
res := rb.Handle(ctx)
assert.Empty(res)
testReq := ctx.GetResponse("test").(*httpprot.Response).Std()
data, err := io.ReadAll(testReq.Body)
testReq := ctx.GetResponse("test").(*httpprot.Response)
data, err := io.ReadAll(testReq.GetPayload())
assert.Nil(err)
assert.Equal("body", string(data))
}
Expand All @@ -175,8 +175,8 @@ func TestResponseBody(t *testing.T) {
ctx.UseResponse("test")
res := rb.Handle(ctx)
assert.Empty(res)
testResp := ctx.GetResponse("test").(*httpprot.Response).Std()
data, err := io.ReadAll(testResp.Body)
testResp := ctx.GetResponse("test").(*httpprot.Response)
data, err := io.ReadAll(testResp.GetPayload())
assert.Nil(err)
assert.Equal("body 123", string(data))
}
Expand All @@ -200,8 +200,8 @@ func TestResponseBody(t *testing.T) {
ctx.UseResponse("test")
res := rb.Handle(ctx)
assert.Empty(res)
testResp := ctx.GetResponse("test").(*httpprot.Response).Std()
data, err := io.ReadAll(testResp.Body)
testResp := ctx.GetResponse("test").(*httpprot.Response)
data, err := io.ReadAll(testResp.GetPayload())
assert.Nil(err)
assert.Equal("body value1 value2", string(data))
}
Expand All @@ -228,8 +228,8 @@ field2: value2
ctx.UseResponse("test")
res := rb.Handle(ctx)
assert.Empty(res)
testResp := ctx.GetResponse("test").(*httpprot.Response).Std()
data, err := io.ReadAll(testResp.Body)
testResp := ctx.GetResponse("test").(*httpprot.Response)
data, err := io.ReadAll(testResp.GetPayload())
assert.Nil(err)
assert.Equal("body value1 value2", string(data))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/kafkabackend/kafka_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func defaultFilterSpec(t *testing.T, spec *Spec) filters.Spec {
func setRequest(t *testing.T, ctx *context.Context, id string, req *http.Request) {
httpreq, err := httpprot.NewRequest(req)
assert.Nil(t, err)
_, err = httpreq.FetchPayload()
err = httpreq.FetchPayload(1024 * 1024)
assert.Nil(t, err)
ctx.SetRequest(id, httpreq)
ctx.UseRequest(id, id)
Expand Down
14 changes: 7 additions & 7 deletions pkg/filters/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,23 +198,23 @@ func (m *Mock) match(ctx *context.Context) *Rule {
}

func (m *Mock) mock(ctx *context.Context, rule *Rule) {
httpreq := ctx.Request().(*httpprot.Request)
httpresp, _ := httpprot.NewResponse(nil)
resp, _ := httpprot.NewResponse(nil)

httpresp.SetStatusCode(rule.Code)
resp.SetStatusCode(rule.Code)
for key, value := range rule.Headers {
httpresp.Std().Header.Set(key, value)
resp.Std().Header.Set(key, value)
}
httpresp.SetPayload([]byte(rule.Body))
ctx.SetResponse(ctx.TargetResponseID(), httpresp)
resp.SetPayload([]byte(rule.Body))
ctx.SetResponse(ctx.TargetResponseID(), resp)

if rule.delay <= 0 {
return
}

req := ctx.Request().(*httpprot.Request)
logger.Debugf("delay for %v ...", rule.delay)
select {
case <-httpreq.Context().Done():
case <-req.Context().Done():
logger.Debugf("request cancelled in the middle of delay mocking")
case <-time.After(rule.delay):
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/filters/proxy/loadbalance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package proxy

import (
"fmt"
"math/rand"
"net/http"
"testing"

Expand Down Expand Up @@ -57,6 +58,7 @@ func TestRoundRobinLoadBalancer(t *testing.T) {

func TestRandomLoadBalancer(t *testing.T) {
assert := assert.New(t)
rand.Seed(0)

var svrs []*Server
lb := NewLoadBalancer(&LoadBalanceSpec{Policy: "random"}, svrs)
Expand All @@ -74,13 +76,14 @@ func TestRandomLoadBalancer(t *testing.T) {

for i := 0; i < 10; i++ {
if v := counter[i]; v < 900 || v > 1100 {
t.Error("possibility is not even")
t.Errorf("possibility is not even with value %v", v)
}
}
}

func TestWeightedRandomLoadBalancer(t *testing.T) {
assert := assert.New(t)
rand.Seed(0)

var svrs []*Server
lb := NewLoadBalancer(&LoadBalanceSpec{Policy: "weightedRandom"}, svrs)
Expand Down
Loading

0 comments on commit 1c329e2

Please sign in to comment.