Skip to content

Commit

Permalink
change httpprot Request Response from interface to struct
Browse files Browse the repository at this point in the history
  • Loading branch information
suchen-sci committed Mar 22, 2022
1 parent ec699d3 commit 7f6958a
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 91 deletions.
2 changes: 1 addition & 1 deletion pkg/filters/certextractor/certextractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (ce *CertExtractor) Handle(ctx context.Context) string {

// CertExtractor extracts given field from TLS certificates and sets it to request headers.
func (ce *CertExtractor) handle(ctx context.Context) string {
r := ctx.Request().(httpprot.Request)
r := ctx.Request().(*httpprot.Request)
connectionState := r.Std().TLS
if connectionState == nil {
return ""
Expand Down
8 changes: 4 additions & 4 deletions pkg/filters/corsadaptor/corsadaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ func (a *CORSAdaptor) Handle(ctx context.Context) string {
}

func (a *CORSAdaptor) handle(ctx context.Context) string {
r := ctx.Request().(httpprot.Request)
w := ctx.Response().(httpprot.Response)
r := ctx.Request().(*httpprot.Request)
w := ctx.Response().(*httpprot.Response)
method := r.Method()
headerAllowMethod := r.Header().Get("Access-Control-Request-Method")
if method == http.MethodOptions && headerAllowMethod != "" {
Expand All @@ -132,8 +132,8 @@ func (a *CORSAdaptor) handle(ctx context.Context) string {
}

func (a *CORSAdaptor) handleCORS(ctx context.Context) string {
r := ctx.Request().(httpprot.Request)
w := ctx.Response().(httpprot.Response)
r := ctx.Request().(*httpprot.Request)
w := ctx.Response().(*httpprot.Response)
method := r.Method()
isCorsRequest := r.Header().Get("Origin") != ""
isPreflight := method == http.MethodOptions && r.Header().Get("Access-Control-Request-Method") != ""
Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/fallback/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (f *Fallback) reload() {
// Handle fallbacks HTTPContext.
// It always returns fallback.
func (f *Fallback) Handle(ctx context.Context) string {
resp := ctx.Response().(httpprot.Response)
resp := ctx.Response().(*httpprot.Response)
f.f.Fallback(resp)
return resultFallback
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/filters/headerlookup/headerlookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (hl *HeaderLookup) handle(ctx context.Context) string {
// TODO: now headerlookup need path which make it only support for http protocol!
// this may need update later
if hl.spec.PathRegExp != "" {
httpreq, ok := ctx.Request().(httpprot.Request)
httpreq, ok := ctx.Request().(*httpprot.Request)
if ok {
path := httpreq.Path()
if match := hl.pathRegExp.FindStringSubmatch(path); match != nil && len(match) > 1 {
Expand Down
4 changes: 2 additions & 2 deletions pkg/filters/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (m *Mock) Handle(ctx context.Context) string {
}

func (m *Mock) match(ctx context.Context) *Rule {
path := ctx.Request().(httpprot.Request).Path()
path := ctx.Request().(*httpprot.Request).Path()
header := ctx.Request().Header()

matchPath := func(rule *Rule) bool {
Expand Down Expand Up @@ -198,7 +198,7 @@ func (m *Mock) match(ctx context.Context) *Rule {
}

func (m *Mock) mock(ctx context.Context, rule *Rule) {
w := ctx.Response().(httpprot.Response)
w := ctx.Response().(*httpprot.Response)
w.SetStatusCode(rule.Code)
for key, value := range rule.Headers {
w.Header().Set(key, value)
Expand Down
88 changes: 42 additions & 46 deletions pkg/protocols/httpprot/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,25 @@ import (
)

type (
Request interface {
protocols.Request

Std() *http.Request
URL() *url.URL
Path() string
SetPath(path string)
Scheme() string

RealIP() string
Proto() string
Method() string
SetMethod(method string)
Host() string
SetHost(host string)

Cookie(name string) (*http.Cookie, error)
Cookies() []*http.Cookie
AddCookie(cookie *http.Cookie)
}

request struct {
// Request provide following methods
// protocols.Request
// Std() *http.Request
// URL() *url.URL
// Path() string
// SetPath(path string)
// Scheme() string

// RealIP() string
// Proto() string
// Method() string
// SetMethod(method string)
// Host() string
// SetHost(host string)

// Cookie(name string) (*http.Cookie, error)
// Cookies() []*http.Cookie
// AddCookie(cookie *http.Cookie)
Request struct {
std *http.Request
realIP string

Expand All @@ -57,95 +54,94 @@ type (
}
)

var _ Request = (*request)(nil)
var _ protocols.Request = (*request)(nil)
var _ protocols.Request = (*Request)(nil)

func newRequest(r *http.Request) Request {
req := &request{}
func newRequest(r *http.Request) *Request {
req := &Request{}
req.std = r
req.realIP = realip.FromRequest(r)
req.payload = newPayload(r.Body)
req.header = newHeader(r.Header)
return req
}

func (r *request) Std() *http.Request {
func (r *Request) Std() *http.Request {
return r.std
}

func (r *request) URL() *url.URL {
func (r *Request) URL() *url.URL {
return r.std.URL
}

func (r *request) RealIP() string {
func (r *Request) RealIP() string {
return r.realIP
}

func (r *request) Proto() string {
func (r *Request) Proto() string {
return r.std.Proto
}

func (r *request) Method() string {
func (r *Request) Method() string {
return r.std.Method
}

func (r *request) Cookie(name string) (*http.Cookie, error) {
func (r *Request) Cookie(name string) (*http.Cookie, error) {
return r.std.Cookie(name)
}

func (r *request) Cookies() []*http.Cookie {
func (r *Request) Cookies() []*http.Cookie {
return r.std.Cookies()
}

func (r *request) AddCookie(cookie *http.Cookie) {
func (r *Request) AddCookie(cookie *http.Cookie) {
r.std.AddCookie(cookie)
}

func (r *request) Header() protocols.Header {
func (r *Request) Header() protocols.Header {
return r.header
}

func (r *request) Payload() protocols.Payload {
func (r *Request) Payload() protocols.Payload {
return r.payload
}

func (r *request) Context() context.Context {
func (r *Request) Context() context.Context {
return r.std.Context()
}

func (r *request) WithContext(ctx context.Context) {
func (r *Request) WithContext(ctx context.Context) {
r.std = r.std.WithContext(ctx)
}

func (r *request) Finish() {
func (r *Request) Finish() {
r.payload.Close()
}

func (r *request) SetMethod(method string) {
func (r *Request) SetMethod(method string) {
r.std.Method = method
}

func (r *request) Host() string {
func (r *Request) Host() string {
return r.std.Host
}

func (r *request) SetHost(host string) {
func (r *Request) SetHost(host string) {
r.std.Host = host
}

func (r *request) Clone() protocols.Request {
func (r *Request) Clone() protocols.Request {
return nil
}

func (r *request) Path() string {
func (r *Request) Path() string {
return r.std.URL.Path
}

func (r *request) SetPath(path string) {
func (r *Request) SetPath(path string) {
r.std.URL.Path = path
}

func (r *request) Scheme() string {
func (r *Request) Scheme() string {
if s := r.std.URL.Scheme; s != "" {
return s
}
Expand Down
50 changes: 23 additions & 27 deletions pkg/protocols/httpprot/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,16 @@ import (
)

type (
Response interface {
protocols.Response

Std() http.ResponseWriter

StatusCode() int
SetStatusCode(code int)
SetCookie(cookie *http.Cookie)
FlushedBodyBytes() uint64
OnFlushBody(fn BodyFlushFunc)
}

response struct {
// Response provide following methods
// protocols.Response
// Std() http.ResponseWriter
// StatusCode() int
// SetStatusCode(code int)
// SetCookie(cookie *http.Cookie)
// FlushedBodyBytes() uint64
// OnFlushBody(fn BodyFlushFunc)

Response struct {
std http.ResponseWriter
code int
bodyWritten uint64
Expand All @@ -51,46 +48,45 @@ type (
BodyFlushFunc = func(body []byte, complete bool) (newBody []byte)
)

var _ Response = (*response)(nil)
var _ protocols.Response = (*response)(nil)
var _ protocols.Response = (*Response)(nil)

func newResponse(w http.ResponseWriter) Response {
return &response{
func newResponse(w http.ResponseWriter) *Response {
return &Response{
std: w,
code: http.StatusOK,
header: newHeader(w.Header()),
}
}

func (resp *response) Std() http.ResponseWriter {
func (resp *Response) Std() http.ResponseWriter {
return resp.std
}

func (resp *response) StatusCode() int {
func (resp *Response) StatusCode() int {
return resp.code
}

func (resp *response) SetStatusCode(code int) {
func (resp *Response) SetStatusCode(code int) {
resp.code = code
}

func (resp *response) SetCookie(cookie *http.Cookie) {
func (resp *Response) SetCookie(cookie *http.Cookie) {
http.SetCookie(resp.std, cookie)
}

func (resp *response) Payload() protocols.Payload {
func (resp *Response) Payload() protocols.Payload {
return resp.payload
}

func (resp *response) Header() protocols.Header {
func (resp *Response) Header() protocols.Header {
return resp.header
}

func (resp *response) FlushedBodyBytes() uint64 {
func (resp *Response) FlushedBodyBytes() uint64 {
return resp.bodyWritten
}

func (resp *response) Finish() {
func (resp *Response) Finish() {
resp.std.WriteHeader(resp.StatusCode())
reader := resp.payload.NewReader()
if reader == nil {
Expand All @@ -109,10 +105,10 @@ func (resp *response) Finish() {
resp.bodyWritten += uint64(written)
}

func (resp *response) Clone() protocols.Response {
func (resp *Response) Clone() protocols.Response {
return nil
}

func (resp *response) OnFlushBody(fn BodyFlushFunc) {
func (resp *Response) OnFlushBody(fn BodyFlushFunc) {
resp.bodyFlushFuncs = append(resp.bodyFlushFuncs, fn)
}
2 changes: 1 addition & 1 deletion pkg/util/fallback/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func New(spec *Spec) *Fallback {
}

// Fallback fallbacks HTTPContext.
func (f *Fallback) Fallback(w httpprot.Response) {
func (f *Fallback) Fallback(w *httpprot.Response) {
w.SetStatusCode(f.spec.MockCode)
w.Header().Set(httpheader.KeyContentLength, f.bodyLength)
for key, value := range f.spec.MockHeaders {
Expand Down
8 changes: 4 additions & 4 deletions pkg/util/httpfilter/httpfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func New(spec *Spec) *HTTPFilter {
}

// Filter filters HTTPContext.
func (hf *HTTPFilter) Filter(req httpprot.Request) bool {
func (hf *HTTPFilter) Filter(req *httpprot.Request) bool {
if len(hf.spec.Headers) > 0 {
matchHeader := hf.filterHeader(req)
if matchHeader && len(hf.spec.URLs) > 0 {
Expand All @@ -112,7 +112,7 @@ func (hf *HTTPFilter) Filter(req httpprot.Request) bool {
return hf.filterProbability(req)
}

func (hf *HTTPFilter) filterHeader(req httpprot.Request) bool {
func (hf *HTTPFilter) filterHeader(req *httpprot.Request) bool {
h := req.Header()
headerMatchNum := 0
for key, matchRule := range hf.spec.Headers {
Expand Down Expand Up @@ -150,7 +150,7 @@ func (hf *HTTPFilter) filterHeader(req httpprot.Request) bool {
return false
}

func (hf *HTTPFilter) filterURL(req httpprot.Request) bool {
func (hf *HTTPFilter) filterURL(req *httpprot.Request) bool {
urlMatch := false
for _, url := range hf.spec.URLs {
if url.Match(req) {
Expand All @@ -161,7 +161,7 @@ func (hf *HTTPFilter) filterURL(req httpprot.Request) bool {
return urlMatch
}

func (hf *HTTPFilter) filterProbability(req httpprot.Request) bool {
func (hf *HTTPFilter) filterProbability(req *httpprot.Request) bool {
prob := hf.spec.Probability

var result uint32
Expand Down
Loading

0 comments on commit 7f6958a

Please sign in to comment.