From 5e06e56785fa849de569accbcdf2e869ac1ae665 Mon Sep 17 00:00:00 2001 From: Daniel Nelson Date: Wed, 14 Aug 2019 17:05:34 -0700 Subject: [PATCH] Fix persistent session in mqtt_consumer (#6236) --- plugins/inputs/mqtt_consumer/README.md | 39 +-- plugins/inputs/mqtt_consumer/mqtt_consumer.go | 141 ++++++--- .../mqtt_consumer/mqtt_consumer_test.go | 287 +++++++++++++----- 3 files changed, 316 insertions(+), 151 deletions(-) diff --git a/plugins/inputs/mqtt_consumer/README.md b/plugins/inputs/mqtt_consumer/README.md index da3ce43f5..53476cb3d 100644 --- a/plugins/inputs/mqtt_consumer/README.md +++ b/plugins/inputs/mqtt_consumer/README.md @@ -3,13 +3,20 @@ The [MQTT][mqtt] consumer plugin reads from the specified MQTT topics and creates metrics using one of the supported [input data formats][]. -### Configuration: +### Configuration ```toml [[inputs.mqtt_consumer]] ## MQTT broker URLs to be used. The format should be scheme://host:port, ## schema can be tcp, ssl, or ws. - servers = ["tcp://localhost:1883"] + servers = ["tcp://127.0.0.1:1883"] + + ## Topics that will be subscribed to. + topics = [ + "telegraf/host01/cpu", + "telegraf/+/mem", + "sensors/#", + ] ## QoS policy for messages ## 0 = at most once @@ -18,10 +25,10 @@ and creates metrics using one of the supported [input data formats][]. ## ## When using a QoS of 1 or 2, you should enable persistent_session to allow ## resuming unacknowledged messages. - qos = 0 + # qos = 0 ## Connection timeout for initial connection in seconds - connection_timeout = "30s" + # connection_timeout = "30s" ## Maximum messages to read from the broker that have not been written by an ## output. For best throughput set based on the number of metrics within @@ -33,21 +40,17 @@ and creates metrics using one of the supported [input data formats][]. ## waiting until the next flush_interval. 
# max_undelivered_messages = 1000 - ## Topics to subscribe to - topics = [ - "telegraf/host01/cpu", - "telegraf/+/mem", - "sensors/#", - ] + ## Persistent session disables clearing of the client session on connection. + ## In order for this option to work you must also set client_id to identify + ## the client. To receive messages that arrived while the client is offline, + ## also set the qos option to 1 or 2 and don't forget to also set the QoS when + ## publishing. + # persistent_session = false - # if true, messages that can't be delivered while the subscriber is offline - # will be delivered when it comes back (such as on service restart). - # NOTE: if true, client_id MUST be set - persistent_session = false - # If empty, a random client ID will be generated. - client_id = "" + ## If unset, a random client ID will be generated. + # client_id = "" - ## username and password to connect MQTT server. + ## Username and password to connect MQTT server. # username = "telegraf" # password = "metricsmetricsmetricsmetrics" @@ -65,7 +68,7 @@ and creates metrics using one of the supported [input data formats][]. 
data_format = "influx" ``` -### Tags: +### Metrics - All measurements are tagged with the incoming topic, ie `topic=telegraf/host01/cpu` diff --git a/plugins/inputs/mqtt_consumer/mqtt_consumer.go b/plugins/inputs/mqtt_consumer/mqtt_consumer.go index da556159e..8a6d0d4de 100644 --- a/plugins/inputs/mqtt_consumer/mqtt_consumer.go +++ b/plugins/inputs/mqtt_consumer/mqtt_consumer.go @@ -33,6 +33,15 @@ const ( Connected ) +type Client interface { + Connect() mqtt.Token + SubscribeMultiple(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token + AddRoute(topic string, callback mqtt.MessageHandler) + Disconnect(quiesce uint) +} + +type ClientFactory func(o *mqtt.ClientOptions) Client + type MQTTConsumer struct { Servers []string Topics []string @@ -51,12 +60,13 @@ type MQTTConsumer struct { ClientID string `toml:"client_id"` tls.ClientConfig - client mqtt.Client - acc telegraf.TrackingAccumulator - state ConnectionState - subscribed bool - sem semaphore - messages map[telegraf.TrackingID]bool + clientFactory ClientFactory + client Client + opts *mqtt.ClientOptions + acc telegraf.TrackingAccumulator + state ConnectionState + sem semaphore + messages map[telegraf.TrackingID]bool ctx context.Context cancel context.CancelFunc @@ -65,7 +75,14 @@ type MQTTConsumer struct { var sampleConfig = ` ## MQTT broker URLs to be used. The format should be scheme://host:port, ## schema can be tcp, ssl, or ws. - servers = ["tcp://localhost:1883"] + servers = ["tcp://127.0.0.1:1883"] + + ## Topics that will be subscribed to. + topics = [ + "telegraf/host01/cpu", + "telegraf/+/mem", + "sensors/#", + ] ## QoS policy for messages ## 0 = at most once @@ -74,10 +91,10 @@ var sampleConfig = ` ## ## When using a QoS of 1 or 2, you should enable persistent_session to allow ## resuming unacknowledged messages. 
- qos = 0 + # qos = 0 ## Connection timeout for initial connection in seconds - connection_timeout = "30s" + # connection_timeout = "30s" ## Maximum messages to read from the broker that have not been written by an ## output. For best throughput set based on the number of metrics within @@ -89,21 +106,17 @@ var sampleConfig = ` ## waiting until the next flush_interval. # max_undelivered_messages = 1000 - ## Topics to subscribe to - topics = [ - "telegraf/host01/cpu", - "telegraf/+/mem", - "sensors/#", - ] + ## Persistent session disables clearing of the client session on connection. + ## In order for this option to work you must also set client_id to identity + ## the client. To receive messages that arrived while the client is offline, + ## also set the qos option to 1 or 2 and don't forget to also set the QoS when + ## publishing. + # persistent_session = false - # if true, messages that can't be delivered while the subscriber is offline - # will be delivered when it comes back (such as on service restart). - # NOTE: if true, client_id MUST be set - persistent_session = false - # If empty, a random client ID will be generated. - client_id = "" + ## If unset, a random client ID will be generated. + # client_id = "" - ## username and password to connect MQTT server. + ## Username and password to connect MQTT server. 
# username = "telegraf" # password = "metricsmetricsmetricsmetrics" @@ -133,7 +146,7 @@ func (m *MQTTConsumer) SetParser(parser parsers.Parser) { m.parser = parser } -func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { +func (m *MQTTConsumer) Init() error { m.state = Disconnected if m.PersistentSession && m.ClientID == "" { @@ -148,15 +161,32 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { return fmt.Errorf("connection_timeout must be greater than 1s: %s", m.ConnectionTimeout.Duration) } - m.acc = acc.WithTracking(m.MaxUndeliveredMessages) - m.ctx, m.cancel = context.WithCancel(context.Background()) - opts, err := m.createOpts() if err != nil { return err } - m.client = mqtt.NewClient(opts) + m.opts = opts + + return nil +} + +func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { + m.state = Disconnected + + m.acc = acc.WithTracking(m.MaxUndeliveredMessages) + m.ctx, m.cancel = context.WithCancel(context.Background()) + + m.client = m.clientFactory(m.opts) + + // AddRoute sets up the function for handling messages. These need to be + // added in case we find a persistent session containing subscriptions so we + // know where to dispatch persisted and new messages to. In the alternate + // case that we need to create the subscriptions these will be replaced. + for _, topic := range m.Topics { + m.client.AddRoute(topic, m.recvMessage) + } + m.state = Connecting m.connect() @@ -164,7 +194,8 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { } func (m *MQTTConsumer) connect() error { - if token := m.client.Connect(); token.Wait() && token.Error() != nil { + token := m.client.Connect() + if token.Wait() && token.Error() != nil { err := token.Error() m.state = Disconnected return err @@ -175,22 +206,26 @@ func (m *MQTTConsumer) connect() error { m.sem = make(semaphore, m.MaxUndeliveredMessages) m.messages = make(map[telegraf.TrackingID]bool) - // Only subscribe on first connection when using persistent sessions. 
On - // subsequent connections the subscriptions should be stored in the - // session, but the proper way to do this is to check the connection - // response to ensure a session was found. - if !m.PersistentSession || !m.subscribed { - topics := make(map[string]byte) - for _, topic := range m.Topics { - topics[topic] = byte(m.QoS) - } - subscribeToken := m.client.SubscribeMultiple(topics, m.recvMessage) - subscribeToken.Wait() - if subscribeToken.Error() != nil { - m.acc.AddError(fmt.Errorf("subscription error: topics: %s: %v", - strings.Join(m.Topics[:], ","), subscribeToken.Error())) - } - m.subscribed = true + // Persistent sessions should skip subscription if a session is present, as + // the subscriptions are stored by the server. + type sessionPresent interface { + SessionPresent() bool + } + if t, ok := token.(sessionPresent); ok && t.SessionPresent() { + log.Printf("D! [inputs.mqtt_consumer] Session found %v", m.Servers) + return nil + } + + topics := make(map[string]byte) + for _, topic := range m.Topics { + topics[topic] = byte(m.QoS) + } + + subscribeToken := m.client.SubscribeMultiple(topics, m.recvMessage) + subscribeToken.Wait() + if subscribeToken.Error() != nil { + m.acc.AddError(fmt.Errorf("subscription error: topics: %s: %v", + strings.Join(m.Topics[:], ","), subscribeToken.Error())) } return nil @@ -316,12 +351,20 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) { return opts, nil } +func New(factory ClientFactory) *MQTTConsumer { + return &MQTTConsumer{ + Servers: []string{"tcp://127.0.0.1:1883"}, + ConnectionTimeout: defaultConnectionTimeout, + MaxUndeliveredMessages: defaultMaxUndeliveredMessages, + clientFactory: factory, + state: Disconnected, + } +} + func init() { inputs.Add("mqtt_consumer", func() telegraf.Input { - return &MQTTConsumer{ - ConnectionTimeout: defaultConnectionTimeout, - MaxUndeliveredMessages: defaultMaxUndeliveredMessages, - state: Disconnected, - } + return New(func(o *mqtt.ClientOptions) Client { + 
return mqtt.NewClient(o) + }) }) } diff --git a/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go b/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go index 2d17c16c3..07d2015a8 100644 --- a/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go +++ b/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go @@ -2,114 +2,233 @@ package mqtt_consumer import ( "testing" + "time" "github.com/eclipse/paho.mqtt.golang" + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/plugins/parsers" "github.com/influxdata/telegraf/testutil" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -const ( - testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n" - invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n" -) +type FakeClient struct { + ConnectF func() mqtt.Token + SubscribeMultipleF func(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token + AddRouteF func(topic string, callback mqtt.MessageHandler) + DisconnectF func(quiesce uint) -func newTestMQTTConsumer() *MQTTConsumer { - n := &MQTTConsumer{ - Topics: []string{"telegraf"}, - Servers: []string{"localhost:1883"}, - } + connectCallCount int + subscribeCallCount int + addRouteCallCount int + disconnectCallCount int +} - return n +func (c *FakeClient) Connect() mqtt.Token { + c.connectCallCount++ + return c.ConnectF() +} + +func (c *FakeClient) SubscribeMultiple(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + c.subscribeCallCount++ + return c.SubscribeMultipleF(filters, callback) +} + +func (c *FakeClient) AddRoute(topic string, callback mqtt.MessageHandler) { + c.addRouteCallCount++ + c.AddRouteF(topic, callback) +} + +func (c *FakeClient) Disconnect(quiesce uint) { + c.disconnectCallCount++ + c.DisconnectF(quiesce) +} + +type FakeParser struct { +} + +// FakeParser satisfies parsers.Parser +var _ parsers.Parser = &FakeParser{} + +func (p *FakeParser) Parse(buf []byte) ([]telegraf.Metric, error) { + panic("not 
implemented") +} + +func (p *FakeParser) ParseLine(line string) (telegraf.Metric, error) { + panic("not implemented") +} + +func (p *FakeParser) SetDefaultTags(tags map[string]string) { + panic("not implemented") +} + +type FakeToken struct { + sessionPresent bool +} + +// FakeToken satisfies mqtt.Token +var _ mqtt.Token = &FakeToken{} + +func (t *FakeToken) Wait() bool { + return true +} + +func (t *FakeToken) WaitTimeout(time.Duration) bool { + return true +} + +func (t *FakeToken) Error() error { + return nil +} + +func (t *FakeToken) SessionPresent() bool { + return t.sessionPresent +} + +// Test the basic lifecycle transitions of the plugin. +func TestLifecycleSanity(t *testing.T) { + var acc testutil.Accumulator + + plugin := New(func(o *mqtt.ClientOptions) Client { + return &FakeClient{ + ConnectF: func() mqtt.Token { + return &FakeToken{} + }, + AddRouteF: func(topic string, callback mqtt.MessageHandler) { + }, + SubscribeMultipleF: func(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + return &FakeToken{} + }, + DisconnectF: func(quiesce uint) { + }, + } + }) + plugin.Servers = []string{"tcp://127.0.0.1"} + + parser := &FakeParser{} + plugin.SetParser(parser) + + err := plugin.Init() + require.NoError(t, err) + + err = plugin.Start(&acc) + require.NoError(t, err) + + err = plugin.Gather(&acc) + require.NoError(t, err) + + plugin.Stop() } // Test that default client has random ID func TestRandomClientID(t *testing.T) { - m1 := &MQTTConsumer{ - Servers: []string{"localhost:1883"}} - opts, err := m1.createOpts() - assert.NoError(t, err) + var err error - m2 := &MQTTConsumer{ - Servers: []string{"localhost:1883"}} - opts2, err2 := m2.createOpts() - assert.NoError(t, err2) + m1 := New(nil) + err = m1.Init() + require.NoError(t, err) - assert.NotEqual(t, opts.ClientID, opts2.ClientID) + m2 := New(nil) + err = m2.Init() + require.NoError(t, err) + + require.NotEqual(t, m1.opts.ClientID, m2.opts.ClientID) } -// Test that default client has 
random ID -func TestClientID(t *testing.T) { - m1 := &MQTTConsumer{ - Servers: []string{"localhost:1883"}, - ClientID: "telegraf-test", - } - opts, err := m1.createOpts() - assert.NoError(t, err) - - m2 := &MQTTConsumer{ - Servers: []string{"localhost:1883"}, - ClientID: "telegraf-test", - } - opts2, err2 := m2.createOpts() - assert.NoError(t, err2) - - assert.Equal(t, "telegraf-test", opts2.ClientID) - assert.Equal(t, "telegraf-test", opts.ClientID) -} - -// Test that Start() fails if client ID is not set but persistent is +// PersistentSession requires ClientID func TestPersistentClientIDFail(t *testing.T) { - m1 := &MQTTConsumer{ - Servers: []string{"localhost:1883"}, - PersistentSession: true, + plugin := New(nil) + plugin.PersistentSession = true + + err := plugin.Init() + require.Error(t, err) +} + +func TestAddRouteCalledForEachTopic(t *testing.T) { + client := &FakeClient{ + ConnectF: func() mqtt.Token { + return &FakeToken{} + }, + AddRouteF: func(topic string, callback mqtt.MessageHandler) { + }, + SubscribeMultipleF: func(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + return &FakeToken{} + }, + DisconnectF: func(quiesce uint) { + }, } - acc := testutil.Accumulator{} - err := m1.Start(&acc) - assert.Error(t, err) + plugin := New(func(o *mqtt.ClientOptions) Client { + return client + }) + plugin.Topics = []string{"a", "b"} + + err := plugin.Init() + require.NoError(t, err) + + var acc testutil.Accumulator + err = plugin.Start(&acc) + require.NoError(t, err) + + plugin.Stop() + + require.Equal(t, client.addRouteCallCount, 2) } -func mqttMsg(val string) mqtt.Message { - return &message{ - topic: "telegraf/unit_test", - payload: []byte(val), +func TestSubscribeCalledIfNoSession(t *testing.T) { + client := &FakeClient{ + ConnectF: func() mqtt.Token { + return &FakeToken{} + }, + AddRouteF: func(topic string, callback mqtt.MessageHandler) { + }, + SubscribeMultipleF: func(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { 
+ return &FakeToken{} + }, + DisconnectF: func(quiesce uint) { + }, } + plugin := New(func(o *mqtt.ClientOptions) Client { + return client + }) + plugin.Topics = []string{"b"} + + err := plugin.Init() + require.NoError(t, err) + + var acc testutil.Accumulator + err = plugin.Start(&acc) + require.NoError(t, err) + + plugin.Stop() + + require.Equal(t, client.subscribeCallCount, 1) } -// Take the message struct from the paho mqtt client library for returning -// a test message interface. -type message struct { - duplicate bool - qos byte - retained bool - topic string - messageID uint16 - payload []byte -} +func TestSubscribeNotCalledIfSession(t *testing.T) { + client := &FakeClient{ + ConnectF: func() mqtt.Token { + return &FakeToken{sessionPresent: true} + }, + AddRouteF: func(topic string, callback mqtt.MessageHandler) { + }, + SubscribeMultipleF: func(filters map[string]byte, callback mqtt.MessageHandler) mqtt.Token { + return &FakeToken{} + }, + DisconnectF: func(quiesce uint) { + }, + } + plugin := New(func(o *mqtt.ClientOptions) Client { + return client + }) + plugin.Topics = []string{"b"} -func (m *message) Duplicate() bool { - return m.duplicate -} + err := plugin.Init() + require.NoError(t, err) -func (m *message) Ack() { - return -} + var acc testutil.Accumulator + err = plugin.Start(&acc) + require.NoError(t, err) -func (m *message) Qos() byte { - return m.qos -} + plugin.Stop() -func (m *message) Retained() bool { - return m.retained -} - -func (m *message) Topic() string { - return m.topic -} - -func (m *message) MessageID() uint16 { - return m.messageID -} - -func (m *message) Payload() []byte { - return m.payload + require.Equal(t, client.subscribeCallCount, 0) }