Skip to content

Commit

Permalink
update mqtt code for new pipeline (easegress-io#620)
Browse files Browse the repository at this point in the history
* update mqtt code for new pipeline

* update mqttproxy use context
  • Loading branch information
suchen-sci committed May 12, 2022
1 parent 0f802e6 commit 9cb5c68
Show file tree
Hide file tree
Showing 17 changed files with 795 additions and 278 deletions.
33 changes: 13 additions & 20 deletions pkg/filters/connectcontrol/connectcontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
)

const (
Expand Down Expand Up @@ -71,7 +71,6 @@ type (
BannedClients []string `yaml:"bannedClients" jsonschema:"omitempty"`
BannedTopicRe string `yaml:"bannedTopicRe" jsonschema:"omitempty"`
BannedTopics []string `yaml:"bannedTopics" jsonschema:"omitempty"`
EarlyStop bool `yaml:"earlyStop" jsonschema:"omitempty"`
}

// Status is ConnectControl filter status
Expand All @@ -84,7 +83,6 @@ type (
)

var _ filters.Filter = (*ConnectControl)(nil)
var _ pipeline.MQTTFilter = (*ConnectControl)(nil)

// Name returns the name of the ConnectControl filter instance.
func (cc *ConnectControl) Name() string {
Expand All @@ -103,10 +101,6 @@ func (cc *ConnectControl) Spec() filters.Spec {

// Init init ConnectControl with pipeline filter spec
func (cc *ConnectControl) Init() {
spec := cc.spec
if spec.Protocol() != context.MQTT {
panic("filter ConnectControl only support MQTT protocol for now")
}
cc.bannedClients = make(map[string]struct{})
cc.bannedTopics = make(map[string]struct{})
cc.reload()
Expand Down Expand Up @@ -160,15 +154,15 @@ func (cc *ConnectControl) Status() interface{} {
func (cc *ConnectControl) Close() {
}

func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool {
cid := ctx.Client().ClientID()
func (cc *ConnectControl) checkBan(req *mqttprot.Request) bool {
cid := req.Client().ClientID()
if cc.bannedClientRe != nil && cc.bannedClientRe.MatchString(cid) {
return true
}
if _, ok := cc.bannedClients[cid]; ok {
return true
}
topic := ctx.PublishPacket().TopicName
topic := req.PublishPacket().TopicName
if cc.bannedTopicRe != nil && cc.bannedTopicRe.MatchString(topic) {
return true
}
Expand All @@ -179,17 +173,16 @@ func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool {
}

// HandleMQTT handle MQTT request
func (cc *ConnectControl) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
if ctx.PacketType() != context.MQTTPublish {
return &context.MQTTResult{}
func (cc *ConnectControl) Handle(ctx *context.Context) string {
req := ctx.Request().(*mqttprot.Request)
resp := ctx.Response().(*mqttprot.Response)
if req.PacketType() != mqttprot.PublishType {
return ""
}

if cc.checkBan(ctx) {
ctx.SetDisconnect()
if cc.spec.EarlyStop {
ctx.SetEarlyStop()
}
return &context.MQTTResult{ErrString: resultBannedClientOrTopic}
if cc.checkBan(req) {
resp.SetDisconnect()
return resultBannedClientOrTopic
}
return &context.MQTTResult{}
return ""
}
111 changes: 48 additions & 63 deletions pkg/filters/connectcontrol/connectcontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,36 @@
package connectcontrol

import (
stdcontext "context"
"errors"
"testing"

"github.com/eclipse/paho.mqtt.golang/packets"
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
"github.com/stretchr/testify/assert"
)

func init() {
logger.InitNop()
}

func newContext(cid string, topic string) context.MQTTContext {
client := &context.MockMQTTClient{
func newContext(cid string, topic string) *context.Context {
ctx := context.New(nil)

client := &mqttprot.MockClient{
MockClientID: cid,
}
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
packet.TopicName = topic
return context.NewMQTTContext(stdcontext.Background(), client, packet)
req := mqttprot.NewRequest(packet, client)
ctx.SetRequest("req1", req)
ctx.UseRequest("req1", "req1")

resp := mqttprot.NewResponse()
ctx.SetResponse(context.DefaultResponseID, resp)
ctx.UseResponse(context.DefaultResponseID)
return ctx
}

func defaultFilterSpec(spec *Spec) filters.Spec {
Expand All @@ -53,37 +61,19 @@ func TestConnectControl(t *testing.T) {
assert := assert.New(t)

cc := &ConnectControl{}
assert.Equal(cc.Kind(), Kind, "wrong kind")
assert.Equal(cc.DefaultSpec(), &Spec{}, "wrong spec")
assert.NotEqual(len(cc.Description()), 0, "description for ConnectControl is empty")

assert.NotNil(cc.Results(), "if update result, please update this case")
checkProtocol := func() (err error) {
defer func() {
if errMsg := recover(); errMsg != nil {
err = errors.New(errMsg.(string))
return
}
}()
/*
meta := &pipeline.FilterMetaSpec{
Protocol: context.HTTP,
}
spec := pipeline.MockFilterSpec(nil, "", meta, nil)
*/
cc.Init(spec)
return
}
err := checkProtocol()
assert.NotNil(err, "if ConnectControl supports more protocol, please update this case")
assert.Equal(cc.Kind(), kind, "wrong kind")
assert.NotEqual(len(cc.Kind().Description), 0, "description for ConnectControl is empty")

assert.NotNil(cc.Kind().Results, "if update result, please update this case")

spec := defaultFilterSpec(&Spec{
BannedClients: []string{"banClient1", "banClient2"},
BannedTopics: []string{"banTopic1", "banTopic2"},
})
cc.Init(spec)
newCc := &ConnectControl{}
newCc.Inherit(spec, cc)
}).(*Spec)
cc = kind.CreateInstance(spec).(*ConnectControl)
cc.Init()
newCc := kind.CreateInstance(spec).(*ConnectControl)
newCc.Inherit(cc)
defer newCc.Close()
status := newCc.Status().(*Status)
assert.Equal(status.BannedClientNum, len(spec.BannedClients))
Expand All @@ -95,21 +85,20 @@ type testCase struct {
topic string
errString string
disconnect bool
earlyStop bool
}

func doTest(t *testing.T, spec *Spec, testCases []testCase) {
assert := assert.New(t)
filterSpec := defaultFilterSpec(spec)
cc := &ConnectControl{}
cc.Init(filterSpec)
cc := kind.CreateInstance(filterSpec).(*ConnectControl)
cc.Init()

for _, test := range testCases {
ctx := newContext(test.cid, test.topic)
res := cc.HandleMQTT(ctx)
assert.Equal(res.ErrString, test.errString)
assert.Equal(ctx.Disconnect(), test.disconnect)
assert.Equal(ctx.EarlyStop(), test.earlyStop)
res := cc.Handle(ctx)
assert.Equal(res, test.errString)
resp := ctx.Response().(*mqttprot.Response)
assert.Equal(resp.Disconnect(), test.disconnect)
}
status := cc.Status().(*Status)
assert.Equal(status.BannedClientRe, spec.BannedClientRe)
Expand All @@ -118,57 +107,53 @@ func doTest(t *testing.T, spec *Spec, testCases []testCase) {
assert.Equal(status.BannedTopicNum, len(spec.BannedTopics))
}

func TestHandleMQTT(t *testing.T) {
func TestHandle(t *testing.T) {
// check BannedClients
spec := &Spec{
BannedClients: []string{"ban1", "ban2"},
EarlyStop: true,
}
testCases := []testCase{
{cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false, earlyStop: false},
{cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false, earlyStop: false},
{cid: "unban", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false},
{cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false},
{cid: "unban", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedTopics
spec = &Spec{
BannedTopics: []string{"ban/sport/ball", "ban/sport/run"},
EarlyStop: true,
}
testCases = []testCase{
{cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "unban3", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false},
{cid: "unban4", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "unban3", topic: "unban/sport", errString: "", disconnect: false},
{cid: "unban4", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedClientRe
spec = &Spec{
BannedClientRe: "phone",
EarlyStop: true,
}
testCases = []testCase{
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "tv", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "tv", topic: "unban/sport", errString: "", disconnect: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

// check BannedTopicRe
spec = &Spec{
BannedTopicRe: "sport",
EarlyStop: true,
}
testCases = []testCase{
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false},
{cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
{cid: "tv", topic: "unban", errString: "", disconnect: false},
}
doTest(t, spec, testCases)

Expand All @@ -178,8 +163,8 @@ func TestHandleMQTT(t *testing.T) {
BannedTopicRe: "(?P<name>",
}
filterSpec := defaultFilterSpec(spec)
cc := &ConnectControl{}
cc.Init(filterSpec)
cc := kind.CreateInstance(filterSpec).(*ConnectControl)
cc.Init()
assert.Nil(t, cc.bannedClientRe)
assert.Nil(t, cc.bannedTopicRe)
}
30 changes: 13 additions & 17 deletions pkg/filters/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import (
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/filters"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
"github.com/megaease/easegress/pkg/protocols/mqttprot"
)

const (
// Kind is the kind of Kafka
Kind = "Kafka"
Kind = "KafkaMQTT"

resultGetDataFailed = "GetDataFailed"
resultGetDataFailed = "getDataFailed"
)

var kind = &filters.Kind{
Expand All @@ -41,7 +41,7 @@ var kind = &filters.Kind{
DefaultSpec: func() filters.Spec {
return &Spec{}
},
CreateInstance: func(spec filter.Spec) filters.Filter {
CreateInstance: func(spec filters.Spec) filters.Filter {
return &Kafka{spec: spec.(*Spec)}
},
}
Expand All @@ -65,7 +65,6 @@ type (
)

var _ filters.Filter = (*Kafka)(nil)
var _ pipeline.MQTTFilter = (*Kafka)(nil)

// Name returns the name of the Kafka filter instance.
func (k *Kafka) Name() string {
Expand Down Expand Up @@ -125,10 +124,6 @@ func (k *Kafka) setProducer() {

// Init init Kafka
func (k *Kafka) Init() {
spec := k.spec
if spec.Protocol() != context.MQTT {
panic("filter Kafka only support MQTT protocol for now")
}
k.done = make(chan struct{})
k.setKV()
k.setProducer()
Expand All @@ -151,7 +146,7 @@ func (k *Kafka) Status() interface{} {
}

// HandleMQTT handle MQTT context
func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
func (k *Kafka) Handle(ctx *context.Context) string {
var topic string
var headers map[string]string
var payload []byte
Expand All @@ -161,25 +156,26 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
if k.topicKey != "" {
topic, ok = ctx.GetKV(k.topicKey).(string)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}
if k.headerKey != "" {
headers, ok = ctx.GetKV(k.headerKey).(map[string]string)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}
if k.payloadKey != "" {
payload, ok = ctx.GetKV(k.payloadKey).([]byte)
if !ok {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}
}

req := ctx.Request().(*mqttprot.Request)
// set data from PublishPacket if data is missing
if ctx.PacketType() == context.MQTTPublish {
p := ctx.PublishPacket()
if req.PacketType() == mqttprot.PublishType {
p := req.PublishPacket()
if topic == "" {
topic = p.TopicName
}
Expand All @@ -193,7 +189,7 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
}

if topic == "" || payload == nil {
return &context.MQTTResult{ErrString: resultGetDataFailed}
return resultGetDataFailed
}

kafkaHeaders := []sarama.RecordHeader{}
Expand All @@ -207,5 +203,5 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
Value: sarama.ByteEncoder(payload),
}
k.producer.Input() <- msg
return &context.MQTTResult{}
return ""
}
Loading

0 comments on commit 9cb5c68

Please sign in to comment.