Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update mqtt code for new pipeline #620

Merged
merged 2 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions pkg/filters/connectcontrol/connectcontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
)

const (
Expand Down Expand Up @@ -71,7 +71,6 @@ type (
BannedClients []string `yaml:"bannedClients" jsonschema:"omitempty"`
BannedTopicRe string `yaml:"bannedTopicRe" jsonschema:"omitempty"`
BannedTopics []string `yaml:"bannedTopics" jsonschema:"omitempty"`
EarlyStop bool `yaml:"earlyStop" jsonschema:"omitempty"`
}

// Status is ConnectControl filter status
Expand All @@ -84,7 +83,6 @@ type (
)

var _ filters.Filter = (*ConnectControl)(nil)
var _ pipeline.MQTTFilter = (*ConnectControl)(nil)

// Name returns the name of the ConnectControl filter instance.
func (cc *ConnectControl) Name() string {
Expand All @@ -103,10 +101,6 @@ func (cc *ConnectControl) Spec() filters.Spec {

// Init init ConnectControl with pipeline filter spec
func (cc *ConnectControl) Init() {
spec := cc.spec
if spec.Protocol() != context.MQTT {
panic("filter ConnectControl only support MQTT protocol for now")
}
cc.bannedClients = make(map[string]struct{})
cc.bannedTopics = make(map[string]struct{})
cc.reload()
Expand Down Expand Up @@ -160,15 +154,15 @@ func (cc *ConnectControl) Status() interface{} {
func (cc *ConnectControl) Close() {
}

func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool {
cid := ctx.Client().ClientID()
func (cc *ConnectControl) checkBan(req *mqttprot.Request) bool {
cid := req.Client().ClientID()
if cc.bannedClientRe != nil && cc.bannedClientRe.MatchString(cid) {
return true
}
if _, ok := cc.bannedClients[cid]; ok {
return true
}
topic := ctx.PublishPacket().TopicName
topic := req.PublishPacket().TopicName
if cc.bannedTopicRe != nil && cc.bannedTopicRe.MatchString(topic) {
return true
}
Expand All @@ -179,17 +173,16 @@ func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool {
}

// HandleMQTT handle MQTT request
func (cc *ConnectControl) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
if ctx.PacketType() != context.MQTTPublish {
return &context.MQTTResult{}
func (cc *ConnectControl) Handle(ctx *context.Context) string {
req := ctx.Request().(*mqttprot.Request)
resp := ctx.Response().(*mqttprot.Response)
if req.PacketType() != mqttprot.PublishType {
return ""
}

if cc.checkBan(ctx) {
ctx.SetDisconnect()
if cc.spec.EarlyStop {
ctx.SetEarlyStop()
}
return &context.MQTTResult{ErrString: resultBannedClientOrTopic}
if cc.checkBan(req) {
resp.SetDisconnect()
return resultBannedClientOrTopic
}
return &context.MQTTResult{}
return ""
}
111 changes: 48 additions & 63 deletions pkg/filters/connectcontrol/connectcontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,36 @@
package connectcontrol

import (
stdcontext "context"
"errors"
"testing"

"github.com/eclipse/paho.mqtt.golang/packets"
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
"github.com/stretchr/testify/assert"
)

func init() {
logger.InitNop()
}

func newContext(cid string, topic string) context.MQTTContext {
client := &context.MockMQTTClient{
func newContext(cid string, topic string) *context.Context {
ctx := context.New(nil)

client := &mqttprot.MockClient{
MockClientID: cid,
}
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
packet.TopicName = topic
return context.NewMQTTContext(stdcontext.Background(), client, packet)
req := mqttprot.NewRequest(packet, client)
ctx.SetRequest("req1", req)
ctx.UseRequest("req1", "req1")

resp := mqttprot.NewResponse()
ctx.SetResponse(context.DefaultResponseID, resp)
ctx.UseResponse(context.DefaultResponseID)
return ctx
}

func defaultFilterSpec(spec *Spec) filters.Spec {
Expand All @@ -53,37 +61,19 @@ func TestConnectControl(t *testing.T) {
assert := assert.New(t)

cc := &ConnectControl{}
assert.Equal(cc.Kind(), Kind, "wrong kind")
assert.Equal(cc.DefaultSpec(), &Spec{}, "wrong spec")
assert.NotEqual(len(cc.Description()), 0, "description for ConnectControl is empty")

assert.NotNil(cc.Results(), "if update result, please update this case")
checkProtocol := func() (err error) {
defer func() {
if errMsg := recover(); errMsg != nil {
err = errors.New(errMsg.(string))
return
}
}()
/*
meta := &pipeline.FilterMetaSpec{
Protocol: context.HTTP,
}
spec := pipeline.MockFilterSpec(nil, "", meta, nil)
*/
cc.Init(spec)
return
}
err := checkProtocol()
assert.NotNil(err, "if ConnectControl supports more protocol, please update this case")
assert.Equal(cc.Kind(), kind, "wrong kind")
assert.NotEqual(len(cc.Kind().Description), 0, "description for ConnectControl is empty")

assert.NotNil(cc.Kind().Results, "if update result, please update this case")

spec := defaultFilterSpec(&Spec{
BannedClients: []string{"banClient1", "banClient2"},
BannedTopics: []string{"banTopic1", "banTopic2"},
})
cc.Init(spec)
newCc := &ConnectControl{}
newCc.Inherit(spec, cc)
}).(*Spec)
cc = kind.CreateInstance(spec).(*ConnectControl)
cc.Init()
newCc := kind.CreateInstance(spec).(*ConnectControl)
newCc.Inherit(cc)
defer newCc.Close()
status := newCc.Status().(*Status)
assert.Equal(status.BannedClientNum, len(spec.BannedClients))
Expand All @@ -95,21 +85,20 @@ type testCase struct {
topic string
errString string
disconnect bool
earlyStop bool
}

func doTest(t *testing.T, spec *Spec, testCases []testCase) {
assert := assert.New(t)
filterSpec := defaultFilterSpec(spec)
cc := &ConnectControl{}
cc.Init(filterSpec)
cc := kind.CreateInstance(filterSpec).(*ConnectControl)
cc.Init()

for _, test := range testCases {
ctx := newContext(test.cid, test.topic)
res := cc.HandleMQTT(ctx)
assert.Equal(res.ErrString, test.errString)
assert.Equal(ctx.Disconnect(), test.disconnect)
assert.Equal(ctx.EarlyStop(), test.earlyStop)
res := cc.Handle(ctx)
assert.Equal(res, test.errString)
resp := ctx.Response().(*mqttprot.Response)
assert.Equal(resp.Disconnect(), test.disconnect)
}
status := cc.Status().(*Status)
assert.Equal(status.BannedClientRe, spec.BannedClientRe)
Expand All @@ -118,57 +107,53 @@ func doTest(t *testing.T, spec *Spec, testCases []testCase) {
assert.Equal(status.BannedTopicNum, len(spec.BannedTopics))
}

func TestHandleMQTT(t *testing.T) {
func TestHandle(t *testing.T) {
// check BannedClients
spec := &Spec{
BannedClients: []string{"ban1", "ban2"},
EarlyStop: true,
}
testCases := []testCase{
{cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false, earlyStop: false},
{cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false, earlyStop: false},
{cid: "unban", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false},
{cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false},
{cid: "unban", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedTopics
spec = &Spec{
BannedTopics: []string{"ban/sport/ball", "ban/sport/run"},
EarlyStop: true,
}
testCases = []testCase{
{cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban3", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false},
{cid: "unban4", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban3", topic: "unban/sport", errString: "", disconnect: false},
{cid: "unban4", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedClientRe
spec = &Spec{
BannedClientRe: "phone",
EarlyStop: true,
}
testCases = []testCase{
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "tv", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "tv", topic: "unban/sport", errString: "", disconnect: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedTopicRe
spec = &Spec{
BannedTopicRe: "sport",
EarlyStop: true,
}
testCases = []testCase{
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

Expand All @@ -178,8 +163,8 @@ func TestHandleMQTT(t *testing.T) {
BannedTopicRe: "(?P<name>",
}
filterSpec := defaultFilterSpec(spec)
cc := &ConnectControl{}
cc.Init(filterSpec)
cc := kind.CreateInstance(filterSpec).(*ConnectControl)
cc.Init()
assert.Nil(t, cc.bannedClientRe)
assert.Nil(t, cc.bannedTopicRe)
}
30 changes: 13 additions & 17 deletions pkg/filters/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import (
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
)

const (
// Kind is the kind of Kafka
Kind = "Kafka"
Kind = "KafkaMQTT"

resultGetDataFailed = "GetDataFailed"
resultGetDataFailed = "getDataFailed"
)

var kind = &filters.Kind{
Expand All @@ -41,7 +41,7 @@ var kind = &filters.Kind{
DefaultSpec: func() filters.Spec {
return &Spec{}
},
CreateInstance: func(spec filter.Spec) filters.Filter {
CreateInstance: func(spec filters.Spec) filters.Filter {
return &Kafka{spec: spec.(*Spec)}
},
}
Expand All @@ -65,7 +65,6 @@ type (
)

var _ filters.Filter = (*Kafka)(nil)
var _ pipeline.MQTTFilter = (*Kafka)(nil)

// Name returns the name of the Kafka filter instance.
func (k *Kafka) Name() string {
Expand Down Expand Up @@ -125,10 +124,6 @@ func (k *Kafka) setProducer() {

// Init init Kafka
func (k *Kafka) Init() {
spec := k.spec
if spec.Protocol() != context.MQTT {
panic("filter Kafka only support MQTT protocol for now")
}
k.done = make(chan struct{})
k.setKV()
k.setProducer()
Expand All @@ -151,7 +146,7 @@ func (k *Kafka) Status() interface{} {
}

// HandleMQTT handle MQTT context
func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
func (k *Kafka) Handle(ctx *context.Context) string {
var topic string
var headers map[string]string
var payload []byte
Expand All @@ -161,25 +156,26 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
if k.topicKey != "" {
topic, ok = ctx.GetKV(k.topicKey).(string)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}
if k.headerKey != "" {
headers, ok = ctx.GetKV(k.headerKey).(map[string]string)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}
if k.payloadKey != "" {
payload, ok = ctx.GetKV(k.payloadKey).([]byte)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}

req := ctx.Request().(*mqttprot.Request)
// set data from PublishPacket if data is missing
if ctx.PacketType() == context.MQTTPublish {
p := ctx.PublishPacket()
if req.PacketType() == mqttprot.PublishType {
p := req.PublishPacket()
if topic == "" {
topic = p.TopicName
}
Expand All @@ -193,7 +189,7 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
}

if topic == "" || payload == nil {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}

kafkaHeaders := []sarama.RecordHeader{}
Expand All @@ -207,5 +203,5 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
Value: sarama.ByteEncoder(payload),
}
k.producer.Input() <- msg
return &context.MQTTResult{}
return ""
}
Loading