Skip to content

Commit

Permalink
convert Context from interface to struct
Browse files Browse the repository at this point in the history
  • Loading branch information
localvar committed Apr 11, 2022
1 parent da7b154 commit b3ab105
Show file tree
Hide file tree
Showing 32 changed files with 214 additions and 539 deletions.
229 changes: 95 additions & 134 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,111 +22,52 @@ import (

"github.com/megaease/easegress/pkg/protocols"
"github.com/megaease/easegress/pkg/tracing"
"github.com/megaease/easegress/pkg/util/fasttime"
)

const (
// InitialRequestID is the ID of the initial request.
InitialRequestID = "initial"
// DefaultResponseID is the ID of the initial response.
// InitialResponseID is the ID of the initial response.
InitialResponseID = "initial"
)

type (
// Handler is the common interface for all traffic handlers,
// which handle the traffic represented by ctx.
Handler interface {
Handle(ctx Context) string
}

// MuxMapper gets the traffic handler by name.
MuxMapper interface {
GetHandler(name string) (Handler, bool)
}

// Context defines the common interface of the context.
Context interface {
Span() tracing.Span
// StatMetric() *httpstat.Metric

AddTag(tag string)
AddLazyTag(func() string)

// UseRequest set the requests to use.
//
// dflt set the default request, the next call to Request returns
// this request, if this parameter is empty, InitialRequestID will be
// used.
//
// base & target are for request producers, the request producer
// creates a new request based on the base request, update its content
// and then save it as the target request.
//
// If parameter base is empty, a new request is created with default
// content; if paramter target is empty, InitialRequestID will be used;
// if base equals to target, the request will be updated in place.
UseRequest(dflt, base, target string)
BaseRequestID() string
TargetRequestID() string

Request() protocols.Request
Requests() map[string]protocols.Request
GetRequest(id string) protocols.Request
SetRequest(id string, req protocols.Request)
DeleteRequest(id string)

// UseResponse set the reponses to use.
//
// base & target are for response producers, the response producer
// creates a new response based on the base response, update its content
// and then save it as the target response.
//
// If parameter base is empty, a new response is created with default
// content; if paramter target is empty, InitialResponseID will be used;
// if base equals to target, the response will be updated in place.
UseResponse(base, target string)
BaseResponseID() string
TargetResponseID() string

Response() protocols.Response
Responses() map[string]protocols.Response
GetResponse(id string) protocols.Response
SetResponse(id string, resp protocols.Response)
DeleteResponse(id string)

// change put setkv getkv into context not request or response maybe better
SetKV(key, val interface{})
GetKV(key interface{}) interface{}

OnFinish(fn func())
Finish()
}
// Handler is the common interface for all traffic handlers,
// which handle the traffic represented by ctx.
type Handler interface {
Handle(ctx *Context) string
}

// context manage requests and responses
// there are no HTTPContext and MQTTContext, but only HTTPRequest, HTTPResponse,
// MQTTRequest and MQTTResponse
context struct {
lazyTags []func() string
// MuxMapper gets the traffic handler by name.
type MuxMapper interface {
GetHandler(name string) (Handler, bool)
}

baseRequestID string
targetRequestID string
request protocols.Request
requests map[string]protocols.Request
// Context holds requests, responses and other data that need to be passed
// through the pipeline.
type Context struct {
span tracing.Span
lazyTags []func() string

baseResponseID string
targetResponseID string
response protocols.Response
responses map[string]protocols.Response
targetRequestID string
request protocols.Request
requests map[string]protocols.Request

kv map[interface{}]interface{}
finishFuncs []func()
}
)
targetResponseID string
response protocols.Response
responses map[string]protocols.Response

var _ Context = (*context)(nil)
kv map[interface{}]interface{}
finishFuncs []func()
}

// New creates a new Context.
func New(req protocols.Request, resp protocols.Response, tracer *tracing.Tracing, spanName string) Context {
ctx := &context{
func New(req protocols.Request, resp protocols.Response, tracer *tracing.Tracing, spanName string) *Context {
now := fasttime.Now()
span := tracing.NewSpanWithStart(tracer, spanName, now)

ctx := &Context{
span: span,
request: req,
requests: map[string]protocols.Request{
InitialRequestID: req,
Expand All @@ -140,31 +81,47 @@ func New(req protocols.Request, resp protocols.Response, tracer *tracing.Tracing
return ctx
}

func (ctx *context) AddTag(tag string) {
// Span returns the span of this Context.
func (ctx *Context) Span() tracing.Span {
return ctx.span
}

// AddTag add a tag to the Context.
func (ctx *Context) AddTag(tag string) {
ctx.lazyTags = append(ctx.lazyTags, func() string { return tag })
}

// AddLazyTag add a lazy tag to the Context.
// questions about tags
// how to access statistics data from every requests?
// add new method when finish?
// add tag to every single request or add tag to context
func (ctx *context) AddLazyTag(lazyTagFunc func() string) {
// add tag to every single request or add tag to Context
func (ctx *Context) AddLazyTag(lazyTagFunc func() string) {
ctx.lazyTags = append(ctx.lazyTags, lazyTagFunc)
}

func (ctx *context) Request() protocols.Request {
// Request returns the default request.
func (ctx *Context) Request() protocols.Request {
return ctx.request
}

func (ctx *context) UseRequest(dflt, base, target string) {
// UseRequest set the requests to use.
//
// dflt set the default request, the next call to Request returns
// this request, if this parameter is empty, InitialRequestID will be
// used.
//
// target is for request producers, if the request exists, the request
// producer update it in place, if the request does not exist, the
// request procuder creates a new request and save it as the target
// request.
//
// If target is an empty string, InitialRequestID will be used.
func (ctx *Context) UseRequest(dflt, target string) {
if dflt == "" {
dflt = InitialRequestID
}

if base != "" && ctx.requests[base] == nil {
panic(fmt.Errorf("request %s does not exist", base))
}

if target == "" {
target = InitialRequestID
}
Expand All @@ -175,104 +132,113 @@ func (ctx *context) UseRequest(dflt, base, target string) {
ctx.request = req
}

ctx.baseRequestID = base
ctx.targetRequestID = target
}

func (ctx *context) BaseRequestID() string {
return ctx.baseRequestID
}

func (ctx *context) TargetRequestID() string {
// TargetRequestID returns the ID of the target request.
func (ctx *Context) TargetRequestID() string {
return ctx.targetRequestID
}

func (ctx *context) Requests() map[string]protocols.Request {
// Requests returns all requests.
func (ctx *Context) Requests() map[string]protocols.Request {
return ctx.requests
}

func (ctx *context) GetRequest(id string) protocols.Request {
// GetRequest returns the request for id.
func (ctx *Context) GetRequest(id string) protocols.Request {
return ctx.requests[id]
}

func (ctx *context) SetRequest(id string, req protocols.Request) {
// SetRequest sets the request of id to req.
func (ctx *Context) SetRequest(id string, req protocols.Request) {
prev := ctx.requests[id]
if prev != nil && prev != req {
prev.Close()
}
ctx.requests[id] = req
}

func (ctx *context) DeleteRequest(id string) {
// DeleteRequest deletes the request of id.
func (ctx *Context) DeleteRequest(id string) {
req := ctx.requests[id]
if req != nil {
req.Close()
delete(ctx.requests, id)
}
}

func (ctx *context) UseResponse(base, target string) {
if base != "" && ctx.responses[base] == nil {
panic(fmt.Errorf("response %s does not exist", base))
}

// UseResponse set the reponses to use.
//
// target is for response producers, if the response exists, the
// response producer update it in place, if the reponse does not exist,
// response producer creates a new response and save it as the target
// response.
//
// If target is an empty string, InitialResponseID will be used;
func (ctx *Context) UseResponse(target string) {
if target == "" {
target = InitialResponseID
}

ctx.baseRequestID = base
ctx.targetResponseID = target
}

func (ctx *context) BaseResponseID() string {
return ctx.baseResponseID
}

func (ctx *context) TargetResponseID() string {
// TargetResponseID returns the ID of the target response.
func (ctx *Context) TargetResponseID() string {
return ctx.targetResponseID
}

func (ctx *context) Response() protocols.Response {
// Response returns the default response.
func (ctx *Context) Response() protocols.Response {
return ctx.response
}

func (ctx *context) Responses() map[string]protocols.Response {
// Responses returns all responses.
func (ctx *Context) Responses() map[string]protocols.Response {
return ctx.responses
}

func (ctx *context) GetResponse(id string) protocols.Response {
// GetResponse returns the response for id.
func (ctx *Context) GetResponse(id string) protocols.Response {
return ctx.responses[id]
}

func (ctx *context) SetResponse(id string, resp protocols.Response) {
// SetResponse sets the response of id to req.
func (ctx *Context) SetResponse(id string, resp protocols.Response) {
prev := ctx.responses[id]
if prev != nil && prev != resp {
prev.Close()
}
ctx.responses[id] = resp
}

func (ctx *context) DeleteResponse(id string) {
// DeleteResponse delete the response of id.
func (ctx *Context) DeleteResponse(id string) {
resp := ctx.responses[id]
if resp != nil {
resp.Close()
delete(ctx.responses, id)
}
}

func (ctx *context) SetKV(key, val interface{}) {
// SetKV sets the value of key to val.
func (ctx *Context) SetKV(key, val interface{}) {
ctx.kv[key] = val
}

func (ctx *context) GetKV(key interface{}) interface{} {
// GetKV returns the value of key.
func (ctx *Context) GetKV(key interface{}) interface{} {
return ctx.kv[key]
}

func (ctx *context) OnFinish(fn func()) {
// OnFinish registers a function to be called in Finish.
func (ctx *Context) OnFinish(fn func()) {
ctx.finishFuncs = append(ctx.finishFuncs, fn)
}

func (ctx *context) Finish() {
// Finish calls all finish functions.
func (ctx *Context) Finish() {
// TODO: add tags here
for _, req := range ctx.requests {
req.Close()
Expand All @@ -284,8 +250,3 @@ func (ctx *context) Finish() {
fn()
}
}

func (ctx *context) Span() tracing.Span {
// TODO: add span
return nil
}
10 changes: 5 additions & 5 deletions pkg/filters/apiaggregator/apiaggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ func (aa *APIAggregator) reload() {
}

// Handle limits HTTPContext.
func (aa *APIAggregator) Handle(ctx context.Context) (result string) {
func (aa *APIAggregator) Handle(ctx *context.Context) (result string) {
httpreq := ctx.Request().(*httpprot.Request)
httpresp := ctx.Response().(*httpprot.Response)
buff := bytes.NewBuffer(nil)
if aa.spec.MaxBodyBytes > 0 {
written, err := io.CopyN(buff, httpreq.Payload().NewReader(), aa.spec.MaxBodyBytes+1)
written, err := io.CopyN(buff, httpreq.GetPayload(), aa.spec.MaxBodyBytes+1)
if written > aa.spec.MaxBodyBytes {
ctx.AddTag(fmt.Sprintf("apiAggregator: request body exceed %dB", aa.spec.MaxBodyBytes))
httpresp.SetStatusCode(http.StatusRequestEntityTooLarge)
Expand Down Expand Up @@ -253,7 +253,7 @@ func (aa *APIAggregator) Handle(ctx context.Context) (result string) {
return aa.formatResponse(ctx, data)
}

func (aa *APIAggregator) newHTTPReq(ctx context.Context, p *Pipeline, buff *bytes.Buffer) (*http.Request, error) {
func (aa *APIAggregator) newHTTPReq(ctx *context.Context, p *Pipeline, buff *bytes.Buffer) (*http.Request, error) {
httpreq := ctx.Request().(*httpprot.Request)

var stdctx stdcontext.Context = httpreq.Context()
Expand All @@ -280,7 +280,7 @@ func (aa *APIAggregator) newHTTPReq(ctx context.Context, p *Pipeline, buff *byte
return http.NewRequestWithContext(stdctx, method, url.String(), body)
}

func (aa *APIAggregator) copyHTTPBody2Map(body io.Reader, ctx context.Context, data map[string][]byte, name string) string {
func (aa *APIAggregator) copyHTTPBody2Map(body io.Reader, ctx *context.Context, data map[string][]byte, name string) string {
httpresp := ctx.Response().(*httpprot.Response)
respBody := bytes.NewBuffer(nil)

Expand All @@ -301,7 +301,7 @@ func (aa *APIAggregator) copyHTTPBody2Map(body io.Reader, ctx context.Context, d
return ""
}

func (aa *APIAggregator) formatResponse(ctx context.Context, data map[string][]byte) string {
func (aa *APIAggregator) formatResponse(ctx *context.Context, data map[string][]byte) string {
httpresp := ctx.Response().(*httpprot.Response)
if aa.spec.MergeResponse {
result := map[string]interface{}{}
Expand Down
Loading

0 comments on commit b3ab105

Please sign in to comment.