From 9cb5c68566e18d73c26294f176249a7f4dc805c0 Mon Sep 17 00:00:00 2001 From: SU Chen Date: Thu, 12 May 2022 17:36:38 +0800 Subject: [PATCH] update mqtt code for new pipeline (#620) * update mqtt code for new pipeline * update mqttproxy use context --- pkg/filters/connectcontrol/connectcontrol.go | 33 ++-- .../connectcontrol/connectcontrol_test.go | 111 +++++------ pkg/filters/kafka/kafka.go | 30 ++- pkg/filters/kafka/kafka_test.go | 28 +-- pkg/filters/mqttclientauth/mqttauth.go | 25 ++- pkg/filters/mqttclientauth/mqttauth_test.go | 48 ++--- pkg/filters/topicmapper/topicmapper.go | 22 +-- pkg/filters/topicmapper/topicmapper_test.go | 26 ++- pkg/object/mqttproxy/broker.go | 30 ++- pkg/object/mqttproxy/client.go | 31 ++- pkg/object/mqttproxy/mock_test.go | 186 ++++++++++++++++-- pkg/object/mqttproxy/mqtt_test.go | 153 ++++++++------ pkg/object/mqttproxy/mqttproxy.go | 13 +- pkg/object/pipeline/mock.go | 25 +++ pkg/protocols/mqttprot/mock.go | 54 +++++ pkg/protocols/mqttprot/request.go | 164 +++++++++++++++ pkg/protocols/mqttprot/response.go | 94 +++++++++ 17 files changed, 795 insertions(+), 278 deletions(-) create mode 100644 pkg/object/pipeline/mock.go create mode 100644 pkg/protocols/mqttprot/mock.go create mode 100644 pkg/protocols/mqttprot/request.go create mode 100644 pkg/protocols/mqttprot/response.go diff --git a/pkg/filters/connectcontrol/connectcontrol.go b/pkg/filters/connectcontrol/connectcontrol.go index 08bd9221b6..286fb7acd5 100644 --- a/pkg/filters/connectcontrol/connectcontrol.go +++ b/pkg/filters/connectcontrol/connectcontrol.go @@ -23,7 +23,7 @@ import ( "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" ) const ( @@ -71,7 +71,6 @@ type ( BannedClients []string `yaml:"bannedClients" jsonschema:"omitempty"` BannedTopicRe string `yaml:"bannedTopicRe" jsonschema:"omitempty"` BannedTopics []string `yaml:"bannedTopics" jsonschema:"omitempty"` - EarlyStop bool `yaml:"earlyStop" jsonschema:"omitempty"` } // Status is ConnectControl filter status @@ -84,7 +83,6 @@ type ( ) var _ filters.Filter = (*ConnectControl)(nil) -var _ pipeline.MQTTFilter = (*ConnectControl)(nil) // Name returns the name of the ConnectControl filter instance. func (cc *ConnectControl) Name() string { @@ -103,10 +101,6 @@ func (cc *ConnectControl) Spec() filters.Spec { // Init init ConnectControl with pipeline filter spec func (cc *ConnectControl) Init() { - spec := cc.spec - if spec.Protocol() != context.MQTT { - panic("filter ConnectControl only support MQTT protocol for now") - } cc.bannedClients = make(map[string]struct{}) cc.bannedTopics = make(map[string]struct{}) cc.reload() @@ -160,15 +154,15 @@ func (cc *ConnectControl) Status() interface{} { func (cc *ConnectControl) Close() { } -func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool { - cid := ctx.Client().ClientID() +func (cc *ConnectControl) checkBan(req *mqttprot.Request) bool { + cid := req.Client().ClientID() if cc.bannedClientRe != nil && cc.bannedClientRe.MatchString(cid) { return true } if _, ok := cc.bannedClients[cid]; ok { return true } - topic := ctx.PublishPacket().TopicName + topic := req.PublishPacket().TopicName if cc.bannedTopicRe != nil && cc.bannedTopicRe.MatchString(topic) { return true } @@ -179,17 +173,16 @@ func (cc *ConnectControl) checkBan(ctx context.MQTTContext) bool { } // HandleMQTT handle MQTT request -func (cc *ConnectControl) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { - if ctx.PacketType() != context.MQTTPublish { - return &context.MQTTResult{} +func (cc *ConnectControl) Handle(ctx *context.Context) string { + req := ctx.Request().(*mqttprot.Request) + resp := ctx.Response().(*mqttprot.Response) + if req.PacketType() != mqttprot.PublishType { + return "" } - if cc.checkBan(ctx) { - ctx.SetDisconnect() - if cc.spec.EarlyStop { - ctx.SetEarlyStop() - } - return &context.MQTTResult{ErrString: resultBannedClientOrTopic} + if cc.checkBan(req) { + resp.SetDisconnect() + return resultBannedClientOrTopic } - return &context.MQTTResult{} + return "" } diff --git a/pkg/filters/connectcontrol/connectcontrol_test.go b/pkg/filters/connectcontrol/connectcontrol_test.go index ba5e0a8914..0dbd479bea 100644 --- a/pkg/filters/connectcontrol/connectcontrol_test.go +++ b/pkg/filters/connectcontrol/connectcontrol_test.go @@ -18,14 +18,13 @@ package connectcontrol import ( - stdcontext "context" - "errors" "testing" "github.com/eclipse/paho.mqtt.golang/packets" "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/protocols/mqttprot" "github.com/stretchr/testify/assert" ) @@ -33,13 +32,22 @@ func init() { logger.InitNop() } -func newContext(cid string, topic string) context.MQTTContext { - client := &context.MockMQTTClient{ +func newContext(cid string, topic string) *context.Context { + ctx := context.New(nil) + + client := &mqttprot.MockClient{ MockClientID: cid, } packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) packet.TopicName = topic - return context.NewMQTTContext(stdcontext.Background(), client, packet) + req := mqttprot.NewRequest(packet, client) + ctx.SetRequest("req1", req) + ctx.UseRequest("req1", "req1") + + resp := mqttprot.NewResponse() + ctx.SetResponse(context.DefaultResponseID, resp) + ctx.UseResponse(context.DefaultResponseID) + return ctx } func defaultFilterSpec(spec *Spec) filters.Spec { @@ -53,37 +61,19 @@ func TestConnectControl(t *testing.T) { assert := assert.New(t) cc := &ConnectControl{} - assert.Equal(cc.Kind(), Kind, "wrong kind") - assert.Equal(cc.DefaultSpec(), &Spec{}, "wrong spec") - assert.NotEqual(len(cc.Description()), 0, "description for ConnectControl is empty") - - assert.NotNil(cc.Results(), "if update result, please update this case") - checkProtocol := func() (err error) { - defer func() { - if errMsg := recover(); errMsg != nil { - err = errors.New(errMsg.(string)) - return - } - }() - /* - meta := &pipeline.FilterMetaSpec{ - Protocol: context.HTTP, - } - spec := pipeline.MockFilterSpec(nil, "", meta, nil) - */ - cc.Init(spec) - return - } - err := checkProtocol() - assert.NotNil(err, "if ConnectControl supports more protocol, please update this case") + assert.Equal(cc.Kind(), kind, "wrong kind") + assert.NotEqual(len(cc.Kind().Description), 0, "description for ConnectControl is empty") + + assert.NotNil(cc.Kind().Results, "if update result, please update this case") spec := defaultFilterSpec(&Spec{ BannedClients: []string{"banClient1", "banClient2"}, BannedTopics: []string{"banTopic1", "banTopic2"}, - }) - cc.Init(spec) - newCc := &ConnectControl{} - newCc.Inherit(spec, cc) + }).(*Spec) + cc = kind.CreateInstance(spec).(*ConnectControl) + cc.Init() + newCc := kind.CreateInstance(spec).(*ConnectControl) + newCc.Inherit(cc) defer newCc.Close() status := newCc.Status().(*Status) assert.Equal(status.BannedClientNum, len(spec.BannedClients)) @@ -95,21 +85,20 @@ type testCase struct { topic string errString string disconnect bool - earlyStop bool } func doTest(t *testing.T, spec *Spec, testCases []testCase) { assert := assert.New(t) filterSpec := defaultFilterSpec(spec) - cc := &ConnectControl{} - cc.Init(filterSpec) + cc := kind.CreateInstance(filterSpec).(*ConnectControl) + cc.Init() for _, test := range testCases { ctx := newContext(test.cid, test.topic) - res := cc.HandleMQTT(ctx) - assert.Equal(res.ErrString, test.errString) - assert.Equal(ctx.Disconnect(), test.disconnect) - assert.Equal(ctx.EarlyStop(), test.earlyStop) + res := cc.Handle(ctx) + assert.Equal(res, test.errString) + resp := ctx.Response().(*mqttprot.Response) + assert.Equal(resp.Disconnect(), test.disconnect) } status := cc.Status().(*Status) assert.Equal(status.BannedClientRe, spec.BannedClientRe) @@ -118,57 +107,53 @@ func doTest(t *testing.T, spec *Spec, testCases []testCase) { assert.Equal(status.BannedTopicNum, len(spec.BannedTopics)) } -func TestHandleMQTT(t *testing.T) { +func TestHandle(t *testing.T) { // check BannedClients spec := &Spec{ BannedClients: []string{"ban1", "ban2"}, - EarlyStop: true, } testCases := []testCase{ - {cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false, earlyStop: false}, - {cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false, earlyStop: false}, - {cid: "unban", topic: "unban", errString: "", disconnect: false, earlyStop: false}, + {cid: "ban1", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "ban2", topic: "unban", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "unban1", topic: "ban/sport/ball", errString: "", disconnect: false}, + {cid: "unban2", topic: "ban/sport/run", errString: "", disconnect: false}, + {cid: "unban", topic: "unban", errString: "", disconnect: false}, } doTest(t, spec, testCases) // check BannedTopics spec = &Spec{ BannedTopics: []string{"ban/sport/ball", "ban/sport/run"}, - EarlyStop: true, } testCases = []testCase{ - {cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "unban3", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false}, - {cid: "unban4", topic: "unban", errString: "", disconnect: false, earlyStop: false}, + {cid: "unban1", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "unban2", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "unban3", topic: "unban/sport", errString: "", disconnect: false}, + {cid: "unban4", topic: "unban", errString: "", disconnect: false}, } doTest(t, spec, testCases) // check BannedClientRe spec = &Spec{ BannedClientRe: "phone", - EarlyStop: true, } testCases = []testCase{ - {cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "tv", topic: "unban/sport", errString: "", disconnect: false, earlyStop: false}, - {cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false}, + {cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "tv", topic: "unban/sport", errString: "", disconnect: false}, + {cid: "tv", topic: "unban", errString: "", disconnect: false}, } doTest(t, spec, testCases) // check BannedTopicRe spec = &Spec{ BannedTopicRe: "sport", - EarlyStop: true, } testCases = []testCase{ - {cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true, earlyStop: true}, - {cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false}, - {cid: "tv", topic: "unban", errString: "", disconnect: false, earlyStop: false}, + {cid: "phone123", topic: "ban/sport/ball", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "phone256", topic: "ban/sport/run", errString: resultBannedClientOrTopic, disconnect: true}, + {cid: "tv", topic: "unban", errString: "", disconnect: false}, + {cid: "tv", topic: "unban", errString: "", disconnect: false}, } doTest(t, spec, testCases) @@ -178,8 +163,8 @@ func TestHandleMQTT(t *testing.T) { BannedTopicRe: "(?P", } filterSpec := defaultFilterSpec(spec) - cc := &ConnectControl{} - cc.Init(filterSpec) + cc := kind.CreateInstance(filterSpec).(*ConnectControl) + cc.Init() assert.Nil(t, cc.bannedClientRe) assert.Nil(t, cc.bannedTopicRe) } diff --git a/pkg/filters/kafka/kafka.go b/pkg/filters/kafka/kafka.go index 8953596c94..7ec0879d3d 100644 --- a/pkg/filters/kafka/kafka.go +++ b/pkg/filters/kafka/kafka.go @@ -24,14 +24,14 @@ import ( "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" ) const ( // Kind is the kind of Kafka - Kind = "Kafka" + Kind = "KafkaMQTT" - resultGetDataFailed = "GetDataFailed" + resultGetDataFailed = "getDataFailed" ) var kind = &filters.Kind{ @@ -41,7 +41,7 @@ var kind = &filters.Kind{ DefaultSpec: func() filters.Spec { return &Spec{} }, - CreateInstance: func(spec filter.Spec) filters.Filter { + CreateInstance: func(spec filters.Spec) filters.Filter { return &Kafka{spec: spec.(*Spec)} }, } @@ -65,7 +65,6 @@ type ( ) var _ filters.Filter = (*Kafka)(nil) -var _ pipeline.MQTTFilter = (*Kafka)(nil) // Name returns the name of the Kafka filter instance. func (k *Kafka) Name() string { @@ -125,10 +124,6 @@ func (k *Kafka) setProducer() { // Init init Kafka func (k *Kafka) Init() { - spec := k.spec - if spec.Protocol() != context.MQTT { - panic("filter Kafka only support MQTT protocol for now") - } k.done = make(chan struct{}) k.setKV() k.setProducer() @@ -151,7 +146,7 @@ func (k *Kafka) Status() interface{} { } // HandleMQTT handle MQTT context -func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { +func (k *Kafka) Handle(ctx *context.Context) string { var topic string var headers map[string]string var payload []byte @@ -161,25 +156,26 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { if k.topicKey != "" { topic, ok = ctx.GetKV(k.topicKey).(string) if !ok { - return &context.MQTTResult{ErrString: resultGetDataFailed} + return resultGetDataFailed } } if k.headerKey != "" { headers, ok = ctx.GetKV(k.headerKey).(map[string]string) if !ok { - return &context.MQTTResult{ErrString: resultGetDataFailed} + return resultGetDataFailed } } if k.payloadKey != "" { payload, ok = ctx.GetKV(k.payloadKey).([]byte) if !ok { - return &context.MQTTResult{ErrString: resultGetDataFailed} + return resultGetDataFailed } } + req := ctx.Request().(*mqttprot.Request) // set data from PublishPacket if data is missing - if ctx.PacketType() == context.MQTTPublish { - p := ctx.PublishPacket() + if req.PacketType() == mqttprot.PublishType { + p := req.PublishPacket() if topic == "" { topic = p.TopicName } @@ -193,7 +189,7 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { } if topic == "" || payload == nil { - return &context.MQTTResult{ErrString: resultGetDataFailed} + return resultGetDataFailed } kafkaHeaders := []sarama.RecordHeader{} @@ -207,5 +203,5 @@ func (k *Kafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { Value: sarama.ByteEncoder(payload), } k.producer.Input() <- msg - return &context.MQTTResult{} + return "" } diff --git a/pkg/filters/kafka/kafka_test.go b/pkg/filters/kafka/kafka_test.go index 4bb5328a7d..370dbea138 100644 --- a/pkg/filters/kafka/kafka_test.go +++ b/pkg/filters/kafka/kafka_test.go @@ -18,13 +18,13 @@ package kafka import ( - stdcontext "context" "fmt" "testing" "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/protocols/mqttprot" "github.com/Shopify/sarama" "github.com/eclipse/paho.mqtt.golang/packets" @@ -61,18 +61,22 @@ func newMockAsyncProducer() sarama.AsyncProducer { func defaultFilterSpec(spec *Spec) filters.Spec { spec.BaseSpec.MetaSpec.Kind = Kind spec.BaseSpec.MetaSpec.Name = "kafka-demo" - result, _ := filters.NewSpec(nil, "pipeline-demo", spec) - return result + return spec } -func newContext(cid string, topic string, payload []byte) context.MQTTContext { - client := &context.MockMQTTClient{ +func newContext(cid string, topic string, payload []byte) *context.Context { + ctx := context.New(nil) + + client := &mqttprot.MockClient{ MockClientID: cid, } packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) packet.TopicName = topic packet.Payload = payload - ctx := context.NewMQTTContext(stdcontext.Background(), client, packet) + req := mqttprot.NewRequest(packet, client) + + ctx.SetRequest("req1", req) + ctx.UseRequest("req1", "req1") return ctx } @@ -82,8 +86,8 @@ func TestKafka(t *testing.T) { Backend: []string{"localhost:1234"}, } filterSpec := defaultFilterSpec(spec) - k := &Kafka{} - assert.Panics(func() { k.Init(filterSpec) }, "kafka should panic for invalid backend") + k := kind.CreateInstance(filterSpec) + assert.Panics(func() { k.Init() }, "kafka should panic for invalid backend") kafka := Kafka{ producer: newMockAsyncProducer(), @@ -91,9 +95,11 @@ func TestKafka(t *testing.T) { } mqttCtx := newContext("test", "a/b/c", []byte("text")) - kafka.HandleMQTT(mqttCtx) + kafka.Handle(mqttCtx) msg := <-kafka.producer.(*mockAsyncProducer).ch - assert.Equal(msg.Topic, mqttCtx.PublishPacket().TopicName) + + req := mqttCtx.Request().(*mqttprot.Request) + assert.Equal(msg.Topic, req.PublishPacket().TopicName) assert.Equal(0, len(msg.Headers)) value, err := msg.Value.Encode() assert.Nil(err) @@ -122,7 +128,7 @@ func TestKafkaWithKVMap(t *testing.T) { mqttCtx.SetKV("topic", "123") mqttCtx.SetKV("headers", map[string]string{"1": "a"}) - kafka.HandleMQTT(mqttCtx) + kafka.Handle(mqttCtx) msg := <-kafka.producer.(*mockAsyncProducer).ch assert.Equal("123", msg.Topic) assert.Equal(1, len(msg.Headers)) diff --git a/pkg/filters/mqttclientauth/mqttauth.go b/pkg/filters/mqttclientauth/mqttauth.go index 52601cf341..8f7c839a1a 100644 --- a/pkg/filters/mqttclientauth/mqttauth.go +++ b/pkg/filters/mqttclientauth/mqttauth.go @@ -25,7 +25,7 @@ import ( "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" ) const ( @@ -77,7 +77,6 @@ type ( ) var _ filters.Filter = (*MQTTClientAuth)(nil) -var _ pipeline.MQTTFilter = (*MQTTClientAuth)(nil) // Name returns the name of the MQTTClientAuth filter instance. func (a *MQTTClientAuth) Name() string { @@ -96,10 +95,6 @@ func (a *MQTTClientAuth) Spec() filters.Spec { // Init init MQTTClientAuth func (a *MQTTClientAuth) Init() { - spec := a.spec - if spec.Protocol() != context.MQTT { - panic("filter ConnectControl only support MQTT protocol for now") - } a.salt = a.spec.Salt a.authMap = make(map[string]string) @@ -108,7 +103,7 @@ func (a *MQTTClientAuth) Init() { } if len(a.authMap) == 0 { - logger.Errorf("empty valid authentication for MQTT filter %v", spec.Name()) + logger.Errorf("empty valid authentication for MQTT filter %v", a.spec.Name()) } } @@ -147,14 +142,16 @@ func (a *MQTTClientAuth) checkAuth(connect *packets.ConnectPacket) string { } // HandleMQTT handle MQTT context -func (a *MQTTClientAuth) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { - if ctx.PacketType() != context.MQTTConnect { - return &context.MQTTResult{} +func (a *MQTTClientAuth) Handle(ctx *context.Context) string { + req := ctx.Request().(*mqttprot.Request) + resp := ctx.Response().(*mqttprot.Response) + if req.PacketType() != mqttprot.ConnectType { + return "" } - result := a.checkAuth(ctx.ConnectPacket()) + result := a.checkAuth(req.ConnectPacket()) if result != "" { - ctx.SetDisconnect() - return &context.MQTTResult{ErrString: resultAuthFail} + resp.SetDisconnect() + return resultAuthFail } - return &context.MQTTResult{ErrString: ""} + return "" } diff --git a/pkg/filters/mqttclientauth/mqttauth_test.go b/pkg/filters/mqttclientauth/mqttauth_test.go index 2eed6b8e56..739b602fbf 100644 --- a/pkg/filters/mqttclientauth/mqttauth_test.go +++ b/pkg/filters/mqttclientauth/mqttauth_test.go @@ -18,15 +18,14 @@ package mqttclientauth import ( - stdcontext "context" "fmt" "sync" "testing" "github.com/eclipse/paho.mqtt.golang/packets" "github.com/megaease/easegress/pkg/context" - "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" + "github.com/megaease/easegress/pkg/protocols/mqttprot" "github.com/stretchr/testify/assert" ) @@ -34,37 +33,40 @@ func init() { logger.InitNop() } -func newContext(cid, username, password string) context.MQTTContext { - client := &context.MockMQTTClient{ +func newContext(cid, username, password string) *context.Context { + ctx := context.New(nil) + + client := &mqttprot.MockClient{ MockClientID: cid, } packet := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket) packet.ClientIdentifier = cid packet.Username = username packet.Password = []byte(password) - return context.NewMQTTContext(stdcontext.Background(), client, packet) -} -func defaultFilterSpec(spec *Spec) filters.Spec { - spec.BaseSpec.MetaSpec.Kind = Kind - spec.BaseSpec.MetaSpec.Name = "connect-demo" - result, _ := filters.NewSpec(nil, "pipeline-demo", spec) - return result + req := mqttprot.NewRequest(packet, client) + ctx.SetRequest("req1", req) + ctx.UseRequest("req1", "req1") + + resp := mqttprot.NewResponse() + ctx.SetResponse(context.DefaultResponseID, resp) + ctx.UseResponse(context.DefaultResponseID) + + return ctx } func TestAuth(t *testing.T) { assert := assert.New(t) spec := &Spec{} - filterSpec := defaultFilterSpec(spec) - auth := &MQTTClientAuth{} - auth.Init(filterSpec) + auth := kind.CreateInstance(spec) + auth.Init() - assert.Equal(Kind, auth.Kind()) - assert.Equal(1, len(auth.Results()), "please update this case if add more results") + assert.Equal(Kind, auth.Kind().Name) + assert.Equal(1, len(kind.Results), "please update this case if add more results") assert.Nil(auth.Status(), "please update this case if return status") - newAuth := &MQTTClientAuth{} - newAuth.Inherit(filterSpec, auth) + newAuth := kind.CreateInstance(spec) + newAuth.Inherit(auth) newAuth.Close() } @@ -79,9 +81,8 @@ func TestAuthFile(t *testing.T) { }, } - filterSpec := defaultFilterSpec(spec) - auth := &MQTTClientAuth{} - auth.Init(filterSpec) + auth := kind.CreateInstance(spec) + auth.Init() type testCase struct { cid string @@ -102,8 +103,9 @@ func TestAuthFile(t *testing.T) { wg.Add(1) go func(test testCase) { ctx := newContext(test.cid, test.name, test.pass) - auth.HandleMQTT(ctx) - assert.Equal(test.disconnect, ctx.Disconnect(), fmt.Errorf("test case %+v got wrong result", test)) + auth.Handle(ctx) + resp := ctx.Response().(*mqttprot.Response) + assert.Equal(test.disconnect, resp.Disconnect(), fmt.Errorf("test case %+v got wrong result", test)) wg.Done() }(test) } diff --git a/pkg/filters/topicmapper/topicmapper.go b/pkg/filters/topicmapper/topicmapper.go index 7ef0a521ba..94f4ccec3a 100644 --- a/pkg/filters/topicmapper/topicmapper.go +++ b/pkg/filters/topicmapper/topicmapper.go @@ -21,7 +21,7 @@ import ( "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" ) const ( @@ -56,7 +56,6 @@ type ( ) var _ filters.Filter = (*TopicMapper)(nil) -var _ pipeline.MQTTFilter = (*TopicMapper)(nil) // Name returns the name of the TopicMapper filter instance. func (k *TopicMapper) Name() string { @@ -75,10 +74,6 @@ func (k *TopicMapper) Spec() filters.Spec { // Init init TopicMapper func (k *TopicMapper) Init() { - spec := k.spec - if spec.Protocol() != context.MQTT { - panic("filter TopicMapper only support MQTT protocol for now") - } k.mapFn = getTopicMapFunc(k.spec) if k.mapFn == nil { panic("invalid spec for TopicMapper") @@ -101,19 +96,20 @@ func (k *TopicMapper) Status() interface{} { } // HandleMQTT handle MQTT context -func (k *TopicMapper) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { - if ctx.PacketType() != context.MQTTPublish { - return &context.MQTTResult{} +func (k *TopicMapper) Handle(ctx *context.Context) string { + req := ctx.Request().(*mqttprot.Request) + if req.PacketType() != mqttprot.PublishType { + return "" } - publish := ctx.PublishPacket() + publish := req.PublishPacket() topic, headers, err := k.mapFn(publish.TopicName) if err != nil { logger.Errorf("map topic %v failed, %v", publish.TopicName, err) - ctx.SetEarlyStop() - return &context.MQTTResult{ErrString: resultMQTTTopicMapFailed} + + return resultMQTTTopicMapFailed } ctx.SetKV(k.spec.SetKV.Topic, topic) ctx.SetKV(k.spec.SetKV.Headers, headers) - return &context.MQTTResult{ErrString: ""} + return "" } diff --git a/pkg/filters/topicmapper/topicmapper_test.go b/pkg/filters/topicmapper/topicmapper_test.go index 0590514337..995d694770 100644 --- a/pkg/filters/topicmapper/topicmapper_test.go +++ b/pkg/filters/topicmapper/topicmapper_test.go @@ -18,37 +18,33 @@ package topicmapper import ( - stdcontext "context" "testing" "github.com/eclipse/paho.mqtt.golang/packets" "github.com/megaease/easegress/pkg/context" - "github.com/megaease/easegress/pkg/filters" + "github.com/megaease/easegress/pkg/protocols/mqttprot" "github.com/stretchr/testify/assert" ) -func defaultFilterSpec(spec *Spec) filters.Spec { - spec.BaseSpec.MetaSpec.Kind = Kind - spec.BaseSpec.MetaSpec.Name = "topic-mapper-demo" - result, _ := filters.NewSpec(nil, "pipeline-demo", spec) - return result -} +func newContext(cid string, topic string) *context.Context { + ctx := context.New(nil) -func newContext(cid string, topic string) context.MQTTContext { - client := &context.MockMQTTClient{ + client := &mqttprot.MockClient{ MockClientID: cid, } packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) packet.TopicName = topic - ctx := context.NewMQTTContext(stdcontext.Background(), client, packet) + + req := mqttprot.NewRequest(packet, client) + ctx.SetRequest("req1", req) + ctx.UseRequest("req1", "req1") return ctx } func TestTopicMapper(t *testing.T) { spec := getDefaultSpec() - filterSpec := defaultFilterSpec(spec) - topicMapper := &TopicMapper{} - topicMapper.Init(filterSpec) + topicMapper := kind.CreateInstance(spec) + topicMapper.Init() defer topicMapper.Close() tests := []struct { @@ -82,7 +78,7 @@ func TestTopicMapper(t *testing.T) { for _, tt := range tests { ctx := newContext("client", tt.mqttTopic) - topicMapper.HandleMQTT(ctx) + topicMapper.Handle(ctx) assert.Equal(t, tt.topic, ctx.GetKV("topic").(string)) assert.Equal(t, tt.headers, ctx.GetKV("headers").(map[string]string)) } diff --git a/pkg/object/mqttproxy/broker.go b/pkg/object/mqttproxy/broker.go index 01af7721c1..51a164a6e2 100644 --- a/pkg/object/mqttproxy/broker.go +++ b/pkg/object/mqttproxy/broker.go @@ -19,7 +19,6 @@ package mqttproxy import ( "bytes" - stdcontext "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -36,7 +35,8 @@ import ( "github.com/megaease/easegress/pkg/api" "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" + "github.com/megaease/easegress/pkg/tracing" "github.com/openzipkin/zipkin-go/model" "github.com/openzipkin/zipkin-go/propagation/b3" ) @@ -53,6 +53,7 @@ type ( clients map[string]*Client tlsCfg *tls.Config pipelines map[PacketType]string + muxMapper context.MuxMapper sessMgr *SessionManager topicMgr *TopicManager @@ -108,7 +109,7 @@ func getPipelineMap(spec *Spec) (map[PacketType]string, error) { return ans, nil } -func newBroker(spec *Spec, store storage, memberURL func(string, string) ([]string, error)) *Broker { +func newBroker(spec *Spec, store storage, muxMapper context.MuxMapper, memberURL func(string, string) ([]string, error)) *Broker { broker := &Broker{ egName: spec.EGName, name: spec.Name, @@ -116,6 +117,7 @@ func newBroker(spec *Spec, store storage, memberURL func(string, string) ([]stri clients: make(map[string]*Client), memberURL: memberURL, done: make(chan struct{}), + muxMapper: muxMapper, } pipelines, err := getPipelineMap(spec) if err != nil { @@ -300,14 +302,15 @@ func (b *Broker) connectionValidation(connect *packets.ConnectPacket, conn net.C authPipeline, ok := b.pipelines[Connect] if ok { - pipe, err := pipeline.GetPipeline(authPipeline, context.MQTT) - if err != nil { - logger.SpanErrorf(nil, "get pipeline %v failed, %v", authPipeline, err) + pipe, ok := b.muxMapper.GetHandler(authPipeline) + if !ok { + logger.SpanErrorf(nil, "get pipeline %v failed", authPipeline) authFail = true } else { - ctx := context.NewMQTTContext(stdcontext.Background(), client, connect) - pipe.HandleMQTT(ctx) - if ctx.Disconnect() { + ctx := newContext(connect, client) + pipe.Handle(ctx) + res := ctx.Response().(*mqttprot.Response) + if res.Disconnect() { logger.SpanErrorf(nil, "client %v not get connect permission from pipeline", connect.ClientIdentifier) authFail = true } @@ -668,3 +671,12 @@ func (b *Broker) close() { } b.clients = nil } + +func newContext(packet packets.ControlPacket, client mqttprot.Client) *context.Context { + ctx := context.New(tracing.NoopSpan) + req := mqttprot.NewRequest(packet, client) + ctx.SetRequest(context.InitialRequestID, req) + resp := mqttprot.NewResponse() + ctx.SetResponse(context.DefaultResponseID, resp) + return ctx +} diff --git a/pkg/object/mqttproxy/client.go b/pkg/object/mqttproxy/client.go index 1e94aaf703..142f41333e 100644 --- a/pkg/object/mqttproxy/client.go +++ b/pkg/object/mqttproxy/client.go @@ -18,7 +18,6 @@ package mqttproxy import ( - stdcontext "context" "errors" "net" "reflect" @@ -27,9 +26,8 @@ import ( "time" "github.com/eclipse/paho.mqtt.golang/packets" - "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/logger" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" ) const ( @@ -102,7 +100,7 @@ type ( } ) -var _ context.MQTTClient = (*Client)(nil) +var _ mqttprot.Client = (*Client)(nil) // ClientID return client id of Client func (c *Client) ClientID() string { @@ -222,19 +220,20 @@ func (c *Client) runPipeline(packet packets.ControlPacket, packetType PacketType return nil } - pipe, err := pipeline.GetPipeline(pipelineName, context.MQTT) - if err != nil { - logger.SpanErrorf(nil, "get pipeline %v failed, %v", pipelineName, err) + pipe, ok := c.broker.muxMapper.GetHandler(pipelineName) + if !ok { + logger.SpanErrorf(nil, "get pipeline %v failed", pipelineName) return nil } - ctx := context.NewMQTTContext(stdcontext.Background(), c, packet) - pipe.HandleMQTT(ctx) - if ctx.Disconnect() { + ctx := newContext(packet, c) + pipe.Handle(ctx) + resp := ctx.Response().(*mqttprot.Response) + if resp.Disconnect() { c.close() return errors.New("pipeline set disconnect") } - if ctx.Drop() { + if resp.Drop() { return errors.New("pipeline set drop") } return nil @@ -275,13 +274,13 @@ func (c *Client) close() { if !ok { return } - pipe, err := pipeline.GetPipeline(pipelineName, context.MQTT) - if err != nil { - logger.SpanErrorf(nil, "get pipeline %v failed, %v", pipelineName, err) + pipe, ok := c.broker.muxMapper.GetHandler(pipelineName) + if !ok { + logger.SpanErrorf(nil, "get pipeline %v failed", pipelineName) } else { disconnect := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) - ctx := context.NewMQTTContext(stdcontext.Background(), c, disconnect) - pipe.HandleMQTT(ctx) + ctx := newContext(disconnect, c) + pipe.Handle(ctx) } } diff --git a/pkg/object/mqttproxy/mock_test.go b/pkg/object/mqttproxy/mock_test.go index 6baeba50f1..9545236782 100644 --- a/pkg/object/mqttproxy/mock_test.go +++ b/pkg/object/mqttproxy/mock_test.go @@ -29,7 +29,7 @@ import ( "github.com/megaease/easegress/pkg/cluster" "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" - "github.com/megaease/easegress/pkg/object/pipeline" + "github.com/megaease/easegress/pkg/protocols/mqttprot" "go.etcd.io/etcd/api/v3/mvccpb" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/concurrency" @@ -197,34 +197,186 @@ func TestMockStorage(t *testing.T) { } type MockKafka struct { - ch chan *packets.PublishPacket + ch chan *packets.PublishPacket + spec *MockKafkaSpec } -type MockKafkaSpec struct{} +type MockKafkaSpec struct { + filters.BaseSpec `yaml:",inline"` +} + +var mockKafkaKind = &filters.Kind{ + Name: "MockKafka", + Description: "MockKafka filter is used for testing MQTTProxy", + Results: []string{}, + DefaultSpec: func() filters.Spec { + return &MockKafkaSpec{} + }, + CreateInstance: func(spec filters.Spec) filters.Filter { + return &MockKafka{ + spec: spec.(*MockKafkaSpec), + } + }, +} -var _ pipeline.MQTTFilter = (*MockKafka)(nil) +var _ filters.Filter = (*MockKafka)(nil) -func (k *MockKafka) Kind() string { return "MockKafka" } -func (k *MockKafka) DefaultSpec() interface{} { return &MockKafkaSpec{} } -func (k *MockKafka) Status() interface{} { return nil } -func (k *MockKafka) Description() string { return "mock kafka" } -func (k *MockKafka) Inherit(filterSpec *filters.Spec, previous filters.Filter) {} -func (k *MockKafka) Close() {} -func (k *MockKafka) Results() []string { return nil } +func (k *MockKafka) Name() string { return k.spec.Name() } +func (k *MockKafka) Kind() *filters.Kind { return mockKafkaKind } +func (k *MockKafka) Spec() filters.Spec { return nil } +func (k *MockKafka) Status() interface{} { return nil } +func (k *MockKafka) Inherit(previous filters.Filter) {} +func (k *MockKafka) Close() {} -func (k *MockKafka) Init(filterSpec *filters.Spec) { +func (k *MockKafka) Init() { k.ch = make(chan *packets.PublishPacket, 100) } -func (k *MockKafka) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult { - if ctx.PacketType() != context.MQTTPublish { - panic(fmt.Errorf("mock kafka for test should only receive publish packet, but received %v", ctx.PacketType())) +func (k *MockKafka) Handle(ctx *context.Context) string { + req := ctx.Request().(*mqttprot.Request) + if req.PacketType() != mqttprot.PublishType { + panic(fmt.Errorf("mock kafka for test should only receive publish packet, but received %v", req.PacketType())) } - k.ch <- ctx.PublishPacket() - return &context.MQTTResult{} + k.ch <- req.PublishPacket() + return "" } func (k *MockKafka) get() *packets.PublishPacket { p := <-k.ch return p } + +var mockMQTTFilterKind = &filters.Kind{ + Name: "MockMQTTFilter", + Description: "MockFilter is used for testing MQTTProxy", + Results: []string{}, + DefaultSpec: func() filters.Spec { + return &MockMQTTSpec{} + }, + CreateInstance: func(spec filters.Spec) filters.Filter { + return &MockMQTTFilter{spec: spec.(*MockMQTTSpec)} + }, +} + +var _ filters.Filter = (*MockMQTTFilter)(nil) + +// MockMQTTFilter is used for test pipeline, which will count the client number of MQTTContext +type MockMQTTFilter struct { + mu sync.Mutex + + spec *MockMQTTSpec + clients map[string]int + disconnect map[string]struct{} + subscribe map[string][]string + unsubscribe map[string][]string +} + +// MockMQTTSpec is spec of MockMQTTFilter +type MockMQTTSpec struct { + filters.BaseSpec `yaml:",inline"` + UserName string `yaml:"userName" jsonschema:"required"` + Password string `yaml:"password" jsonschema:"required"` + Port uint16 `yaml:"port" jsonschema:"required"` + BackendType string `yaml:"backendType" jsonschema:"required"` + EarlyStop bool `yaml:"earlyStop" jsonschema:"omitempty"` + KeysToStore []string `yaml:"keysToStore" jsonschema:"omitempty"` + ConnectKey string `yaml:"connectKey" jsonschema:"omitempty"` +} + +// MockMQTTStatus is status of MockMQTTFilter +type MockMQTTStatus struct { + ClientCount map[string]int + ClientDisconnect map[string]struct{} + Subscribe map[string][]string + Unsubscribe map[string][]string +} + +var _ filters.Filter = (*MockMQTTFilter)(nil) + +// Kind retrun kind of MockMQTTFilter +func (m *MockMQTTFilter) Kind() *filters.Kind { + return mockMQTTFilterKind +} + +// Init init MockMQTTFilter +func (m *MockMQTTFilter) Init() { + m.clients = make(map[string]int) + m.disconnect = make(map[string]struct{}) + m.subscribe = make(map[string][]string) + m.unsubscribe = make(map[string][]string) +} + +func (m *MockMQTTFilter) Name() string { + return m.spec.Name() +} + +func (m *MockMQTTFilter) Inherit(previous filters.Filter) { + m.Init() +} + +// HandleMQTT handle MQTTContext +func (m *MockMQTTFilter) Handle(ctx *context.Context) string { + req := ctx.Request().(*mqttprot.Request) + resp := ctx.Response().(*mqttprot.Response) + m.mu.Lock() + defer m.mu.Unlock() + m.clients[req.Client().ClientID()]++ + + switch req.PacketType() { + case mqttprot.ConnectType: + req.Client().Store(m.spec.ConnectKey, struct{}{}) + if req.ConnectPacket().Username != m.spec.UserName || string(req.ConnectPacket().Password) != m.spec.Password { + resp.SetDisconnect() + } + case mqttprot.DisconnectType: + m.disconnect[req.Client().ClientID()] = struct{}{} + case mqttprot.SubscribeType: + m.subscribe[req.Client().ClientID()] = req.SubscribePacket().Topics + case mqttprot.UnsubscribeType: + m.unsubscribe[req.Client().ClientID()] = req.UnsubscribePacket().Topics + } + + for _, k := range m.spec.KeysToStore { + req.Client().Store(k, struct{}{}) + } + return "" +} + +// Status return status of MockMQTTFilter +func (m *MockMQTTFilter) Status() interface{} { + m.mu.Lock() + defer m.mu.Unlock() + + clientCount := make(map[string]int) + for k, v := range m.clients { + clientCount[k] = v + } + disconnect := make(map[string]struct{}) + for k := range m.disconnect { + disconnect[k] = struct{}{} + } + subscribe := make(map[string][]string) + for k, v := range m.subscribe { + vv := make([]string, len(v)) + copy(vv, v) + subscribe[k] = v + } + unsubscribe := make(map[string][]string) + for k, v := range m.unsubscribe { + vv := make([]string, len(v)) + copy(vv, v) + unsubscribe[k] = vv + } + return MockMQTTStatus{ + ClientCount: clientCount, + ClientDisconnect: disconnect, + Subscribe: subscribe, + Unsubscribe: unsubscribe, + } +} + +func (m *MockMQTTFilter) Spec() filters.Spec { + return &MockMQTTSpec{} +} + +func (m *MockMQTTFilter) Close() {} diff --git a/pkg/object/mqttproxy/mqtt_test.go b/pkg/object/mqttproxy/mqtt_test.go index f50de8d159..198e4b7550 100644 --- a/pkg/object/mqttproxy/mqtt_test.go +++ b/pkg/object/mqttproxy/mqtt_test.go @@ -19,7 +19,7 @@ package mqttproxy import ( "bytes" - "context" + stdcontext "context" "encoding/base64" "encoding/json" "errors" @@ -35,6 +35,7 @@ import ( paho "github.com/eclipse/paho.mqtt.golang" "github.com/eclipse/paho.mqtt.golang/packets" + "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/filters" _ "github.com/megaease/easegress/pkg/filters/mqttclientauth" "github.com/megaease/easegress/pkg/logger" @@ -48,9 +49,19 @@ import ( func init() { logger.InitNop() - filters.Register(&pipeline.MockMQTTFilter{}) - filters.Register(&MockKafka{}) - // filters.Register(&authentication.Authentication{}) + filters.Register(mockKafkaKind) + filters.Register(mockMQTTFilterKind) +} + +type mockMuxMapper struct { + MockFunc func(name string) (context.Handler, bool) +} + +func (m *mockMuxMapper) GetHandler(name string) (context.Handler, bool) { + if m.MockFunc != nil { + return m.MockFunc(name) + } + return nil, false } var ( @@ -106,9 +117,9 @@ func getDefaultSpec() *Spec { return spec } -func getBrokerFromSpec(spec *Spec) *Broker { +func getBrokerFromSpec(spec *Spec, mapper context.MuxMapper) *Broker { store := newStorage(nil) - broker := newBroker(spec, store, func(s, ss string) ([]string, error) { + broker := newBroker(spec, store, mapper, func(s, ss string) ([]string, error) { m := map[string]string{ "test": "http://localhost:8888/mqtt", "test1": "http://localhost:8889/mqtt", @@ -124,16 +135,15 @@ func getBrokerFromSpec(spec *Spec) *Broker { return broker } -func getDefaultBroker() *Broker { +func getDefaultBroker(mapper context.MuxMapper) *Broker { spec := getDefaultSpec() - return getBrokerFromSpec(spec) + return getBrokerFromSpec(spec, mapper) } func getConnectPipeline(t *testing.T) *pipeline.Pipeline { yamlStr := ` name: %s kind: Pipeline -protocol: MQTT filters: - name: connect kind: MockMQTTFilter @@ -146,7 +156,7 @@ filters: superSpec, err := super.NewSpec(yamlStr) require.Nil(t, err) pipe := &pipeline.Pipeline{} - pipe.Init(superSpec) + pipe.Init(superSpec, nil) return pipe } @@ -165,7 +175,7 @@ filters: superSpec, err := super.NewSpec(yamlStr) require.Nil(t, err) pipe := &pipeline.Pipeline{} - pipe.Init(superSpec) + pipe.Init(superSpec, nil) filter := pipeline.MockGetFilter(pipe, "publish") require.NotNil(t, filter) @@ -205,6 +215,9 @@ func checkSessionStore(broker *Broker, cid, topic string) error { } func TestConnection(t *testing.T) { + pipe := getConnectPipeline(t) + defer pipe.Close() + spec := getDefaultSpec() spec.Rules = append(spec.Rules, &Rule{ When: &When{ @@ -212,12 +225,14 @@ func TestConnection(t *testing.T) { }, Pipeline: connectPipeline, }) - broker := getBrokerFromSpec(spec) + mapper := &mockMuxMapper{ + MockFunc: func(name string) (context.Handler, bool) { + return pipe, true + }, + } + broker := getBrokerFromSpec(spec, mapper) defer broker.close() - pipe := getConnectPipeline(t) - defer pipe.Close() - c1 := getDefaultMQTTClient(t, "test", true) c1.Disconnect(200) @@ -229,12 +244,17 @@ func TestConnection(t *testing.T) { } func TestPublish(t *testing.T) { - broker := getDefaultBroker() - defer broker.close() - pipe, backend := getPublishPipeline(t) defer pipe.Close() + mapper := &mockMuxMapper{ + MockFunc: func(name string) (context.Handler, bool) { + return pipe, true + }, + } + broker := getDefaultBroker(mapper) + defer broker.close() + client := getMQTTClient(t, "test", "test", "test", true) for i := 0; i < 5; i++ { @@ -265,7 +285,7 @@ func TestPublish(t *testing.T) { } func TestSubUnsub(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() client := getMQTTClient(t, "test", "test", "test", true) @@ -282,7 +302,8 @@ func TestSubUnsub(t *testing.T) { } func TestCleanSession(t *testing.T) { - broker := getDefaultBroker() + mapper := &mockMuxMapper{} + broker := getDefaultBroker(mapper) defer broker.close() // client that set cleanSession @@ -379,12 +400,17 @@ func TestCleanSession(t *testing.T) { } func TestMultiClientPublish(t *testing.T) { - broker := getDefaultBroker() - defer broker.close() - pipe, backend := getPublishPipeline(t) defer pipe.Close() + mapper := &mockMuxMapper{ + MockFunc: func(name string) (context.Handler, bool) { + return pipe, true + }, + } + broker := getDefaultBroker(mapper) + defer broker.close() + var wg sync.WaitGroup clientNum := 30 @@ -433,7 +459,7 @@ func TestMultiClientPublish(t *testing.T) { } func TestSession(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() client := getDefaultMQTTClient(t, "test", true) @@ -534,13 +560,18 @@ func TestSendMsgBack(t *testing.T) { subscribeCh := make(chan CheckMsg, 100) var wg sync.WaitGroup - // make broker - broker := getDefaultBroker() - defer broker.close() - pipe, backend := getPublishPipeline(t) defer pipe.Close() + // make broker + mapper := &mockMuxMapper{ + MockFunc: func(name string) (context.Handler, bool) { + return pipe, true + }, + } + broker := getDefaultBroker(mapper) + defer broker.close() + // subscribe handler handler := getMQTTSubscribeHandler(subscribeCh) // make client and subscribe different topic @@ -674,7 +705,7 @@ func TestSendMsgBack(t *testing.T) { } func TestYamlEncodeDecode(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() // old session @@ -740,7 +771,7 @@ func (ts *testServer) start() error { } func (ts *testServer) shutdown() { - ts.srv.Shutdown(context.Background()) + ts.srv.Shutdown(stdcontext.Background()) } func topicsPublish(t *testing.T, data HTTPJsonData) int { @@ -761,7 +792,7 @@ func topicsPublish(t *testing.T, data HTTPJsonData) int { } func TestHTTPRequest(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() srv := newServer(":8888") @@ -845,7 +876,7 @@ func TestHTTPRequest(t *testing.T) { } func TestHTTPPublish(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() srv := newServer(":8888") @@ -904,7 +935,7 @@ func TestHTTPPublish(t *testing.T) { } func TestHTTPTransfer(t *testing.T) { - broker0 := getDefaultBroker() + broker0 := getDefaultBroker(nil) srv0 := newServer(":8888") srv0.addHandlerFunc("/mqtt", broker0.httpTopicsPublishHandler) @@ -914,7 +945,7 @@ func TestHTTPTransfer(t *testing.T) { spec.EGName = "test1" spec.Name = "test1" spec.Port = 1884 - broker1 := getBrokerFromSpec(spec) + broker1 := getBrokerFromSpec(spec, nil) srv1 := newServer(":8889") srv1.addHandlerFunc("/mqtt", broker1.httpTopicsPublishHandler) @@ -1154,7 +1185,7 @@ func TestTLSConfig(t *testing.T) { } func TestSessMgr(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() sessMgr := broker.sessMgr @@ -1259,7 +1290,7 @@ func TestBrokerListen(t *testing.T) { {"fake", "abc", "abc"}, }, } - broker := getBrokerFromSpec(spec) + broker := getBrokerFromSpec(spec, nil) if broker != nil { t.Errorf("invalid tls config should return nil broker") } @@ -1274,12 +1305,12 @@ func TestBrokerListen(t *testing.T) { {"demo", certPem, keyPem}, }, } - broker = getBrokerFromSpec(spec) + broker = getBrokerFromSpec(spec, nil) if broker == nil { t.Errorf("valid tls config should not return nil broker") } - broker1 := getBrokerFromSpec(spec) + broker1 := getBrokerFromSpec(spec, nil) if broker1 != nil { t.Errorf("not valid port should return nil broker") } @@ -1290,7 +1321,7 @@ func TestBrokerListen(t *testing.T) { EGName: "test-1", Port: 1883, } - broker2 := getBrokerFromSpec(spec) + broker2 := getBrokerFromSpec(spec, nil) if broker2 != nil { t.Errorf("not valid port should return nil broker") } @@ -1300,7 +1331,7 @@ func TestBrokerListen(t *testing.T) { } func TestBrokerHandleConn(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) // broker handleConn return if error happen svcConn, clientConn := net.Pipe() @@ -1329,7 +1360,7 @@ func TestMQTTProxy(t *testing.T) { mp := MQTTProxy{} mp.Status() - broker := getDefaultBroker() + broker := getDefaultBroker(nil) mp.broker = broker broker.reconnectWatcher() @@ -1365,15 +1396,23 @@ filters: t.Errorf("supervisor unmarshal yaml failed, %s", err) t.Skip() } - pipe := pipeline.Pipeline{} - pipe.Init(superSpec) + pipe := &pipeline.Pipeline{} + pipe.Init(superSpec, nil) defer pipe.Close() publishPipe, backend := getPublishPipeline(t) defer publishPipe.Close() // get broker - broker := getDefaultBroker() + mapper := &mockMuxMapper{ + MockFunc: func(name string) (context.Handler, bool) { + if name == publishPipeline { + return publishPipe, true + } + return pipe, true + }, + } + broker := getDefaultBroker(mapper) broker.pipelines[Subscribe] = "mqtt-test-pipeline" broker.pipelines[Unsubscribe] = "mqtt-test-pipeline" broker.pipelines[Disconnect] = "mqtt-test-pipeline" @@ -1413,7 +1452,7 @@ filters: time.Sleep(100 * time.Millisecond) } - filterStatus := pipe.Status().ObjectStatus.(*pipeline.Status).Filters["mqtt-filter"].(pipeline.MockMQTTStatus) + filterStatus := pipe.Status().ObjectStatus.(*pipeline.Status).Filters["mqtt-filter"].(MockMQTTStatus) if len(filterStatus.ClientCount) != clientNum { t.Errorf("filter get wrong result %v for client num", len(filterStatus.ClientCount)) } @@ -1429,7 +1468,8 @@ filters: } func TestAuthByPipeline(t *testing.T) { - broker := getDefaultBroker() + mapper := &mockMuxMapper{} + broker := getDefaultBroker(mapper) broker.pipelines[Connect] = "mqtt-test-pipeline" defer broker.close() @@ -1462,10 +1502,13 @@ filters: t.Errorf("supervisor unmarshal yaml failed, %s", err) t.Skip() } - pipe := pipeline.Pipeline{} - pipe.Init(superSpec) + pipe := &pipeline.Pipeline{} + pipe.Init(superSpec, nil) defer pipe.Close() + mapper.MockFunc = func(name string) (context.Handler, bool) { + return pipe, true + } // same username and passwd with filter, success option = paho.NewClientOptions().AddBroker("tcp://0.0.0.0:1883").SetClientID("test").SetUsername("filter-auth-name").SetPassword("filter-auth-passwd") client = paho.NewClient(option) @@ -1491,7 +1534,7 @@ filters: func TestMaxAllowedConnection(t *testing.T) { spec := getDefaultSpec() spec.MaxAllowedConnection = 10 - broker := getBrokerFromSpec(spec) + broker := getBrokerFromSpec(spec, nil) defer broker.close() clients := []paho.Client{} @@ -1528,7 +1571,7 @@ func TestConnectionLimit(t *testing.T) { RequestRate: 10, TimePeriod: 1000, } - broker := getBrokerFromSpec(spec) + broker := getBrokerFromSpec(spec, nil) defer broker.close() // use all rate @@ -1547,7 +1590,7 @@ func TestClientPublishLimit(t *testing.T) { RequestRate: 10, TimePeriod: 1000, } - broker := getBrokerFromSpec(spec) + broker := getBrokerFromSpec(spec, nil) defer broker.close() client := getMQTTClient(t, "test", "test", "test", true) @@ -1566,7 +1609,7 @@ func TestClientPublishLimit(t *testing.T) { } func TestHTTPGetAllSession(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() // connect 10 clients @@ -1655,7 +1698,7 @@ func TestHTTPGetAllSession(t *testing.T) { } func TestHTTPDeleteSession(t *testing.T) { - broker := getDefaultBroker() + broker := getDefaultBroker(nil) defer broker.close() // connect 10 clients @@ -1709,7 +1752,7 @@ func TestHTTPDeleteSession(t *testing.T) { func TestHTTPTransferHeaderCopy(t *testing.T) { done := make(chan bool, 2) - broker0 := getDefaultBroker() + broker0 := getDefaultBroker(nil) srv0 := newServer(":8888") srv0.addHandlerFunc("/mqtt", func(w http.ResponseWriter, r *http.Request) { broker0.httpTopicsPublishHandler(w, r) @@ -1721,7 +1764,7 @@ func TestHTTPTransferHeaderCopy(t *testing.T) { spec.Name = "test1" spec.EGName = "test1" spec.Port = 1884 - broker1 := getBrokerFromSpec(spec) + broker1 := getBrokerFromSpec(spec, nil) srv1 := newServer(":8889") srv1.addHandlerFunc("/mqtt", func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/object/mqttproxy/mqttproxy.go b/pkg/object/mqttproxy/mqttproxy.go index 5c26ca790f..36ccbe3b23 100644 --- a/pkg/object/mqttproxy/mqttproxy.go +++ b/pkg/object/mqttproxy/mqttproxy.go @@ -23,6 +23,7 @@ import ( "net/url" "github.com/megaease/easegress/pkg/cluster" + "github.com/megaease/easegress/pkg/context" "github.com/megaease/easegress/pkg/logger" "github.com/megaease/easegress/pkg/supervisor" "gopkg.in/yaml.v2" @@ -30,12 +31,14 @@ import ( const ( // Category is the category of MQTTProxy. - Category = supervisor.CategoryBusinessController + Category = supervisor.CategoryTrafficGate // Kind is the kind of MQTTProxy. Kind = "MQTTProxy" ) +var _ supervisor.TrafficObject = (*MQTTProxy)(nil) + func init() { supervisor.Register(&MQTTProxy{}) } @@ -130,14 +133,14 @@ func memberURLFunc(superSpec *supervisor.Spec) func(string, string) ([]string, e } // Init initializes Function. -func (mp *MQTTProxy) Init(superSpec *supervisor.Spec) { +func (mp *MQTTProxy) Init(superSpec *supervisor.Spec, muxMapper context.MuxMapper) { spec := superSpec.ObjectSpec().(*Spec) spec.Name = superSpec.Name() spec.EGName = superSpec.Super().Options().Name mp.superSpec, mp.spec = superSpec, spec store := newStorage(superSpec.Super().Cluster()) - mp.broker = newBroker(spec, store, memberURLFunc(superSpec)) + mp.broker = newBroker(spec, store, muxMapper, memberURLFunc(superSpec)) if mp.broker == nil { panic(fmt.Sprintf("broker %v start failed", spec.Name)) } @@ -145,9 +148,9 @@ func (mp *MQTTProxy) Init(superSpec *supervisor.Spec) { } // Inherit inherits previous generation of WebSocketServer. -func (mp *MQTTProxy) Inherit(superSpec *supervisor.Spec, previousGeneration supervisor.Object) { +func (mp *MQTTProxy) Inherit(superSpec *supervisor.Spec, previousGeneration supervisor.Object, muxMapper context.MuxMapper) { previousGeneration.Close() - mp.Init(superSpec) + mp.Init(superSpec, muxMapper) } // Close closes MQTTProxy. diff --git a/pkg/object/pipeline/mock.go b/pkg/object/pipeline/mock.go new file mode 100644 index 0000000000..70357df9fc --- /dev/null +++ b/pkg/object/pipeline/mock.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package pipeline + +import "github.com/megaease/easegress/pkg/filters" + +// MockGetFilter is used to get filter from pipeline for testing. +func MockGetFilter(p *Pipeline, name string) filters.Filter { + return p.filters[name] +} diff --git a/pkg/protocols/mqttprot/mock.go b/pkg/protocols/mqttprot/mock.go new file mode 100644 index 0000000000..1b226f3807 --- /dev/null +++ b/pkg/protocols/mqttprot/mock.go @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mqttprot + +import "sync" + +// MockClient is mock client for MQTT protocol +type MockClient struct { + MockClientID string + MockUserName string + MockKVMap sync.Map +} + +var _ Client = (*MockClient)(nil) + +// ClientID return client id of MockClient +func (m *MockClient) ClientID() string { + return m.MockClientID +} + +// UserName return username if MockClient +func (m *MockClient) UserName() string { + return m.MockUserName +} + +// Load load value keep in MockClient kv map +func (m *MockClient) Load(key interface{}) (value interface{}, ok bool) { + return m.MockKVMap.Load(key) +} + +// Store store kv pair into MockClient kv map +func (m *MockClient) Store(key interface{}, value interface{}) { + m.MockKVMap.Store(key, value) +} + +// Delete delete key-value pair in MockClient kv map +func (m *MockClient) Delete(key interface{}) { + m.MockKVMap.Delete(key) +} diff --git a/pkg/protocols/mqttprot/request.go b/pkg/protocols/mqttprot/request.go new file mode 100644 index 0000000000..ebe14f5d7e --- /dev/null +++ b/pkg/protocols/mqttprot/request.go @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mqttprot + +import ( + "bytes" + "io" + + "github.com/megaease/easegress/pkg/protocols" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +type ( + // Request contains MQTT packet. + Request struct { + client Client + packet packets.ControlPacket + packetType PacketType + payload []byte + } + + // Client contains MQTT client info that send this packet + Client interface { + ClientID() string + UserName() string + Load(key interface{}) (value interface{}, ok bool) + Store(key interface{}, value interface{}) + Delete(key interface{}) + } + + // PacketType contains supported MQTT packet type + PacketType int +) + +const ( + // ConnectType is MQTT packet type of connect + ConnectType PacketType = 1 + + // PublishType is MQTT packet type of publish + PublishType PacketType = 2 + + // DisconnectType is MQTT packet type of disconnect + DisconnectType PacketType = 3 + + // SubscribeType is MQTT packet type of subscribe + SubscribeType PacketType = 4 + + // UnsubscribeType is MQTT packet type of unsubscribe + UnsubscribeType PacketType = 5 + + // OtherType is all other MQTT packet type + OtherType PacketType = 99 +) + +var _ protocols.Request = (*Request)(nil) + +// NewRequest create new MQTT Request +func NewRequest(packet packets.ControlPacket, client Client) *Request { + req := &Request{ + client: client, + packet: packet, + } + switch p := packet.(type) { + case *packets.ConnectPacket: + req.packetType = ConnectType + case *packets.PublishPacket: + req.packetType = PublishType + req.payload = p.Payload + case *packets.DisconnectPacket: + req.packetType = DisconnectType + case *packets.SubscribePacket: + req.packetType = SubscribeType + case *packets.UnsubscribePacket: + req.packetType = UnsubscribeType + default: + req.packetType = OtherType + } + return req +} + +// Client return MQTT request client +func (r *Request) Client() Client { + return r.client +} + +// PacketType return MQTT request packet type +func (r *Request) PacketType() PacketType { + return r.packetType +} + +// ConnectPacket return MQTT connect packet if PacketType is ConnectType +func (r *Request) ConnectPacket() *packets.ConnectPacket { + return r.packet.(*packets.ConnectPacket) +} + +// PublishPacket return MQTT publish packet if PacketType is PublishType +func (r *Request) PublishPacket() *packets.PublishPacket { + return r.packet.(*packets.PublishPacket) +} + +// DisconnectPacket return MQTT disconnect packet if PacketType is DisconnectType +func (r *Request) DisconnectPacket() *packets.DisconnectPacket { + return r.packet.(*packets.DisconnectPacket) +} + +// SubscribePacket return MQTT subscribe packet if PacketType is SubscribeType +func (r *Request) SubscribePacket() *packets.SubscribePacket { + return r.packet.(*packets.SubscribePacket) +} + +// UnsubscribePacket return MQTT unsubscribe packet if PacketType is UnsubscribeType +func (r *Request) UnsubscribePacket() *packets.UnsubscribePacket { + return r.packet.(*packets.UnsubscribePacket) +} + +// Header return MQTT request header +func (r *Request) Header() protocols.Header { + // TODO: what header to return? + return nil +} + +// SetPayload set the payload of the request to payload. +func (r *Request) SetPayload(payload []byte) { + r.payload = payload + if r.packetType == PublishType { + r.packet.(*packets.PublishPacket).Payload = payload + } +} + +// GetPayload returns a new payload reader. +func (r *Request) GetPayload() io.Reader { + return bytes.NewReader(r.payload) +} + +// RawPayload returns the payload in []byte, the caller should +// not modify its content. +func (r *Request) RawPayload() []byte { + return r.payload +} + +// PayloadLength returns the length of the payload. +func (r *Request) PayloadLength() int { + return len(r.payload) +} + +// Close closes the request. +func (r *Request) Close() { +} diff --git a/pkg/protocols/mqttprot/response.go b/pkg/protocols/mqttprot/response.go new file mode 100644 index 0000000000..6e601b9771 --- /dev/null +++ b/pkg/protocols/mqttprot/response.go @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2017, MegaEase + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mqttprot + +import ( + "bytes" + "io" + + "github.com/megaease/easegress/pkg/protocols" +) + +type ( + // Response contains MQTT response. + Response struct { + drop bool + disconnect bool + payload []byte + } +) + +var _ protocols.Response = (*Response)(nil) + +// NewResponse returns a new MQTT response. +func NewResponse() *Response { + return &Response{} +} + +// SetDrop means the packet in context will be drop. +func (r *Response) SetDrop() { + r.drop = true +} + +// Drop return true if the packet in context will be drop. +// For example, if SetDrop and the packet in Request is subscribe packet, MQTTProxy will +// not subscribe the topics in the packet. +func (r *Response) Drop() bool { + return r.drop +} + +// SetDisconnect means the MQTT client will be disconnect. +func (r *Response) SetDisconnect() { + r.disconnect = true +} + +// Disconnect return true if the MQTT client will be disconnect. +func (r *Response) Disconnect() bool { + return r.disconnect +} + +// Header return MQTT response header +func (r *Response) Header() protocols.Header { + // TODO: what header to return? + return nil +} + +// SetPayload set the payload of the response to payload. +func (r *Response) SetPayload(payload []byte) { + r.payload = payload +} + +// GetPayload returns a new payload reader. +func (r *Response) GetPayload() io.Reader { + return bytes.NewReader(r.payload) +} + +// RawPayload returns the payload in []byte, the caller should +// not modify its content. +func (r *Response) RawPayload() []byte { + return r.payload +} + +// PayloadLength returns the length of the payload. +func (r *Response) PayloadLength() int { + return len(r.payload) +} + +// Close closes the response. +func (r *Response) Close() { +}