Skip to content

Commit

Permalink
Refactor flow (#631)
Browse files Browse the repository at this point in the history
* alias for flow node

* flow node using namespace

* refactor filter input/output validation
  • Loading branch information
localvar committed May 26, 2022
1 parent 96e546c commit 21f959c
Show file tree
Hide file tree
Showing 51 changed files with 334 additions and 384 deletions.
164 changes: 64 additions & 100 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,15 @@ package context

import (
"bytes"
"fmt"
"runtime/debug"

"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/protocols"
"github.com/megaease/easegress/pkg/tracing"
)

const (
// InitialRequestID is the ID of the initial request.
InitialRequestID = "initial"
// DefaultResponseID is the ID of the default response.
DefaultResponseID = "default"
)
// DefaultNamespace is the name of the default namespace.
const DefaultNamespace = "PIPELINE"

// Handler is the common interface for all traffic handlers,
// which handle the traffic represented by ctx.
Expand All @@ -51,12 +46,11 @@ type Context struct {
span tracing.Span
lazyTags []func() string

targetRequestID string
request protocols.Request
requests map[string]protocols.Request
inputNs string
outputNs string

targetResponseID string
responses map[string]protocols.Response
requests map[string]protocols.Request
responses map[string]protocols.Response

kv map[interface{}]interface{}
finishFuncs []func()
Expand All @@ -66,6 +60,8 @@ type Context struct {
func New(span tracing.Span) *Context {
ctx := &Context{
span: span,
inputNs: DefaultNamespace,
outputNs: DefaultNamespace,
requests: map[string]protocols.Request{},
responses: map[string]protocols.Response{},
kv: map[interface{}]interface{}{},
Expand All @@ -88,129 +84,97 @@ func (ctx *Context) LazyAddTag(lazyTagFunc func() string) {
ctx.lazyTags = append(ctx.lazyTags, lazyTagFunc)
}

// Request returns the default request.
func (ctx *Context) Request() protocols.Request {
return ctx.request
}

// UseRequest set the requests to use.
//
// dflt set the default request, the next call to Request returns
// this request, if this parameter is empty, InitialRequestID will be
// used.
//
// target is for request producers, if the request exists, the request
// producer update it in place, if the request does not exist, the
// request procuder creates a new request and save it as the target
// request.
//
// If target is an empty string, InitialRequestID will be used.
func (ctx *Context) UseRequest(dflt, target string) {
if dflt == "" {
dflt = InitialRequestID
// UseNamespace sets the input and output namespace to input and output.
func (ctx *Context) UseNamespace(input, output string) {
if input == "" {
input = DefaultNamespace
}

if target == "" {
target = InitialRequestID
if output == "" {
output = input
}

if req := ctx.requests[dflt]; req == nil {
panic(fmt.Errorf("request %s does not exist", dflt))
} else {
ctx.request = req
}

ctx.targetRequestID = target
}

// TargetRequestID returns the ID of the target request.
func (ctx *Context) TargetRequestID() string {
return ctx.targetRequestID
ctx.inputNs = input
ctx.outputNs = output
}

// Requests returns all requests.
// Requests returns all requests, the caller should NOT modify the
// return value.
func (ctx *Context) Requests() map[string]protocols.Request {
return ctx.requests
}

// GetRequest returns the request for id, the function returns nil if
// there's no such request.
func (ctx *Context) GetRequest(id string) protocols.Request {
return ctx.requests[id]
// GetOutputRequest returns the request of the output namespace.
func (ctx *Context) GetOutputRequest() protocols.Request {
return ctx.requests[ctx.outputNs]
}

// SetRequest sets the request of id to req.
func (ctx *Context) SetRequest(id string, req protocols.Request) {
prev := ctx.requests[id]
if prev != nil && prev != req {
prev.Close()
}
ctx.requests[id] = req
// SetOutputRequest sets the request of the output namespace to req.
func (ctx *Context) SetOutputRequest(req protocols.Request) {
ctx.SetRequest(ctx.outputNs, req)
}

// DeleteRequest deletes the request of id.
func (ctx *Context) DeleteRequest(id string) {
req := ctx.requests[id]
if req != nil {
req.Close()
delete(ctx.requests, id)
}
// GetRequest set the request of namespace ns to req.
func (ctx *Context) GetRequest(ns string) protocols.Request {
return ctx.requests[ns]
}

// UseResponse set the reponses to use.
//
// target is for response producers, if the response exists, the
// response producer update it in place, if the reponse does not exist,
// response producer creates a new response and save it as the target
// response.
//
// If target is an empty string, DefaultResponseID will be used;
func (ctx *Context) UseResponse(target string) {
if target == "" {
target = DefaultResponseID
// SetRequest set the request of namespace ns to req.
func (ctx *Context) SetRequest(ns string, req protocols.Request) {
prev := ctx.requests[ns]
if prev != nil && prev != req {
prev.Close()
}

ctx.targetResponseID = target
ctx.requests[ns] = req
}

// TargetResponseID returns the ID of the target response.
func (ctx *Context) TargetResponseID() string {
return ctx.targetResponseID
// GetInputRequest returns the request of the input namespace.
func (ctx *Context) GetInputRequest() protocols.Request {
return ctx.requests[ctx.inputNs]
}

// Response returns the default response, and the return value could
// be nil.
func (ctx *Context) Response() protocols.Response {
return ctx.GetResponse(DefaultResponseID)
// SetInputRequest sets the request of the input namespace to req.
func (ctx *Context) SetInputRequest(req protocols.Request) {
ctx.SetRequest(ctx.inputNs, req)
}

// Responses returns all responses.
func (ctx *Context) Responses() map[string]protocols.Response {
return ctx.responses
}

// GetResponse returns the response for id, the function returns nil if
// there's no such response.
func (ctx *Context) GetResponse(id string) protocols.Response {
return ctx.responses[id]
// GetOutputResponse returns the response of the output namespace.
func (ctx *Context) GetOutputResponse() protocols.Response {
return ctx.responses[ctx.outputNs]
}

// SetResponse sets the response of id to req.
func (ctx *Context) SetResponse(id string, resp protocols.Response) {
prev := ctx.responses[id]
// SetOutputResponse sets the response of the output namespace to resp.
func (ctx *Context) SetOutputResponse(resp protocols.Response) {
ctx.SetResponse(ctx.outputNs, resp)
}

// GetResponse returns the response of namespace ns.
func (ctx *Context) GetResponse(ns string) protocols.Response {
return ctx.responses[ns]
}

// SetResponse set the response of namespace ns to resp.
func (ctx *Context) SetResponse(ns string, resp protocols.Response) {
prev := ctx.responses[ns]
if prev != nil && prev != resp {
prev.Close()
}
ctx.responses[id] = resp
ctx.responses[ns] = resp
}

// DeleteResponse delete the response of id.
func (ctx *Context) DeleteResponse(id string) {
resp := ctx.responses[id]
if resp != nil {
resp.Close()
delete(ctx.responses, id)
}
// GetInputResponse returns the response of the input namespace.
func (ctx *Context) GetInputResponse() protocols.Response {
return ctx.GetResponse(ctx.inputNs)
}

// SetInputResponse sets the response of the input namespace to resp.
func (ctx *Context) SetInputResponse(resp protocols.Response) {
ctx.SetResponse(ctx.inputNs, resp)
}

// SetKV sets the value of key to val.
Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/certextractor/certextractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (ce *CertExtractor) Close() {}

// Handle retrieves header values and sets request headers.
func (ce *CertExtractor) Handle(ctx *context.Context) string {
r := ctx.Request().(*httpprot.Request)
r := ctx.GetInputRequest().(*httpprot.Request)
connectionState := r.Std().TLS
if connectionState == nil {
return ""
Expand Down
8 changes: 4 additions & 4 deletions pkg/filters/certextractor/certextractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/protocols/httpprot"
"github.com/megaease/easegress/pkg/supervisor"
"github.com/megaease/easegress/pkg/tracing"
"github.com/megaease/easegress/pkg/util/yamltool"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -90,15 +91,14 @@ field: "CommonName"
}

func prepareCtxAndHeader(t *testing.T, connState *tls.ConnectionState) (*context.Context, http.Header) {
ctx := context.New(nil)
ctx := context.New(tracing.NoopSpan)
stdr := &http.Request{}
stdr.Header = http.Header{}
stdr.TLS = connState

httpreq, err := httpprot.NewRequest(stdr)
req, err := httpprot.NewRequest(stdr)
assert.Nil(t, err)
ctx.SetRequest(context.InitialRequestID, httpreq)
ctx.UseRequest(context.InitialRequestID, context.InitialRequestID)
ctx.SetRequest(context.DefaultNamespace, req)
return ctx, stdr.Header
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/filters/connectcontrol/connectcontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ func (cc *ConnectControl) checkBan(req *mqttprot.Request) bool {
return false
}

// HandleMQTT handle MQTT request
// Handle handles context.
func (cc *ConnectControl) Handle(ctx *context.Context) string {
req := ctx.Request().(*mqttprot.Request)
resp := ctx.Response().(*mqttprot.Response)
req := ctx.GetInputRequest().(*mqttprot.Request)
resp := ctx.GetOutputResponse().(*mqttprot.Response)
if req.PacketType() != mqttprot.PublishType {
return ""
}
Expand Down
8 changes: 3 additions & 5 deletions pkg/filters/connectcontrol/connectcontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ func newContext(cid string, topic string) *context.Context {
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
packet.TopicName = topic
req := mqttprot.NewRequest(packet, client)
ctx.SetRequest("req1", req)
ctx.UseRequest("req1", "req1")
ctx.SetRequest(context.DefaultNamespace, req)

resp := mqttprot.NewResponse()
ctx.SetResponse(context.DefaultResponseID, resp)
ctx.UseResponse(context.DefaultResponseID)
ctx.SetOutputResponse(resp)
return ctx
}

Expand Down Expand Up @@ -97,7 +95,7 @@ func doTest(t *testing.T, spec *Spec, testCases []testCase) {
ctx := newContext(test.cid, test.topic)
res := cc.Handle(ctx)
assert.Equal(res, test.errString)
resp := ctx.Response().(*mqttprot.Response)
resp := ctx.GetOutputResponse().(*mqttprot.Response)
assert.Equal(resp.Disconnect(), test.disconnect)
}
status := cc.Status().(*Status)
Expand Down
8 changes: 4 additions & 4 deletions pkg/filters/corsadaptor/corsadaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,29 @@ func (a *CORSAdaptor) Handle(ctx *context.Context) string {
}

func (a *CORSAdaptor) handle(ctx *context.Context) string {
r := ctx.Request().(*httpprot.Request)
r := ctx.GetInputRequest().(*httpprot.Request)
method := r.Method()
headerAllowMethod := r.Header().Get("Access-Control-Request-Method")
if method == http.MethodOptions && headerAllowMethod != "" {
rw := httptest.NewRecorder()
a.cors.HandlerFunc(rw, r.Std())
resp, _ := httpprot.NewResponse(rw.Result())
ctx.SetResponse(ctx.TargetResponseID(), resp)
ctx.SetOutputResponse(resp)
return resultPreflighted
}
return ""
}

func (a *CORSAdaptor) handleCORS(ctx *context.Context) string {
r := ctx.Request().(*httpprot.Request)
r := ctx.GetInputRequest().(*httpprot.Request)
method := r.Method()
isCorsRequest := r.Header().Get("Origin") != ""
isPreflight := method == http.MethodOptions && r.Header().Get("Access-Control-Request-Method") != ""
// set CORS headers to response
rw := httptest.NewRecorder()
a.cors.HandlerFunc(rw, r.Std())
resp, _ := httpprot.NewResponse(rw.Result())
ctx.SetResponse(ctx.TargetResponseID(), resp)
ctx.SetOutputResponse(resp)
if !isCorsRequest {
return "" // next filter
}
Expand Down
11 changes: 5 additions & 6 deletions pkg/filters/corsadaptor/corsadaptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ import (
"github.com/stretchr/testify/assert"
)

func setRequest(t *testing.T, ctx *context.Context, req *http.Request, id string) {
httpreq, err := httpprot.NewRequest(req)
func setRequest(t *testing.T, ctx *context.Context, stdReq *http.Request) {
req, err := httpprot.NewRequest(stdReq)
assert.Nil(t, err)
ctx.SetRequest(id, httpreq)
ctx.UseRequest(id, id)
ctx.SetInputRequest(req)
}

func TestCORSAdaptor(t *testing.T) {
Expand All @@ -57,7 +56,7 @@ name: cors
ctx := context.New(nil)
req, err := http.NewRequest(http.MethodOptions, "http:https://example.com/", nil)
assert.Nil(err)
setRequest(t, ctx, req, "req1")
setRequest(t, ctx, req)

result := cors.Handle(ctx)
if result == resultPreflighted {
Expand Down Expand Up @@ -104,7 +103,7 @@ allowedOrigins:
ctx := context.New(nil)
req, err := http.NewRequest(http.MethodOptions, "http:https://example.com", nil)
assert.Nil(err)
setRequest(t, ctx, req, "req1")
setRequest(t, ctx, req)

result := cors.Handle(ctx)
if result == resultPreflighted {
Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/fallback/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (f *Fallback) reload() {
// Handle fallbacks HTTPContext.
// It always returns fallback.
func (f *Fallback) Handle(ctx *context.Context) string {
resp := ctx.GetResponse(ctx.TargetResponseID()).(*httpprot.Response)
resp := ctx.GetInputResponse().(*httpprot.Response)
if resp == nil {
return resultResponseNotFound
}
Expand Down
Loading

0 comments on commit 21f959c

Please sign in to comment.