From 152365ae06785d464d2695ec199b1bfafc11d2cd Mon Sep 17 00:00:00 2001 From: Daniel Nelson Date: Mon, 15 Oct 2018 13:03:52 -0700 Subject: [PATCH] Rework mqtt_consumer connect/reconnect (#4846) --- plugins/inputs/mqtt_consumer/mqtt_consumer.go | 123 ++++++++---------- .../mqtt_consumer/mqtt_consumer_test.go | 106 ++------------- 2 files changed, 68 insertions(+), 161 deletions(-) diff --git a/plugins/inputs/mqtt_consumer/mqtt_consumer.go b/plugins/inputs/mqtt_consumer/mqtt_consumer.go index 5853ad939..0a253b8d8 100644 --- a/plugins/inputs/mqtt_consumer/mqtt_consumer.go +++ b/plugins/inputs/mqtt_consumer/mqtt_consumer.go @@ -1,10 +1,10 @@ package mqtt_consumer import ( + "errors" "fmt" "log" "strings" - "sync" "time" "github.com/influxdata/telegraf" @@ -19,6 +19,14 @@ import ( // 30 Seconds is the default used by paho.mqtt.golang var defaultConnectionTimeout = internal.Duration{Duration: 30 * time.Second} +type ConnectionState int + +const ( + Disconnected ConnectionState = iota + Connecting + Connected +) + type MQTTConsumer struct { Servers []string Topics []string @@ -36,16 +44,10 @@ type MQTTConsumer struct { ClientID string `toml:"client_id"` tls.ClientConfig - sync.Mutex - client mqtt.Client - // channel of all incoming raw mqtt messages - in chan mqtt.Message - done chan struct{} - - // keep the accumulator internally: - acc telegraf.Accumulator - - connected bool + client mqtt.Client + acc telegraf.Accumulator + state ConnectionState + subscribed bool } var sampleConfig = ` @@ -110,22 +112,19 @@ func (m *MQTTConsumer) SetParser(parser parsers.Parser) { } func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { - m.Lock() - defer m.Unlock() - m.connected = false + m.state = Disconnected if m.PersistentSession && m.ClientID == "" { - return fmt.Errorf("ERROR MQTT Consumer: When using persistent_session" + - " = true, you MUST also set client_id") + return errors.New("persistent_session requires client_id") } m.acc = acc if m.QoS > 2 || m.QoS < 0 { - return fmt.Errorf("MQTT Consumer, invalid QoS value: %d", m.QoS) + return fmt.Errorf("qos value must be 0, 1, or 2: %d", m.QoS) } if m.ConnectionTimeout.Duration < 1*time.Second { - return fmt.Errorf("MQTT Consumer, invalid connection_timeout value: %s", m.ConnectionTimeout.Duration) + return fmt.Errorf("connection_timeout must be greater than 1s: %s", m.ConnectionTimeout.Duration) } opts, err := m.createOpts() @@ -134,9 +133,7 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { } m.client = mqtt.NewClient(opts) - m.in = make(chan mqtt.Message, 1000) - m.done = make(chan struct{}) - + m.state = Connecting m.connect() return nil @@ -145,80 +142,68 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error { func (m *MQTTConsumer) connect() error { if token := m.client.Connect(); token.Wait() && token.Error() != nil { err := token.Error() - log.Printf("D! MQTT Consumer, connection error - %v", err) - + m.state = Disconnected return err } - go m.receiver() + log.Printf("I! [inputs.mqtt_consumer]: connected %v", m.Servers) + m.state = Connected - return nil -} - -func (m *MQTTConsumer) onConnect(c mqtt.Client) { - log.Printf("I! MQTT Client Connected") - if !m.PersistentSession || !m.connected { + // Only subscribe on first connection when using persistent sessions. On + // subsequent connections the subscriptions should be stored in the + // session, but the proper way to do this is to check the connection + // response to ensure a session was found. + if !m.PersistentSession || !m.subscribed { topics := make(map[string]byte) for _, topic := range m.Topics { topics[topic] = byte(m.QoS) } - subscribeToken := c.SubscribeMultiple(topics, m.recvMessage) + subscribeToken := m.client.SubscribeMultiple(topics, m.recvMessage) subscribeToken.Wait() if subscribeToken.Error() != nil { - m.acc.AddError(fmt.Errorf("E! MQTT Subscribe Error\ntopics: %s\nerror: %s", + m.acc.AddError(fmt.Errorf("subscription error: topics: %s: %v", strings.Join(m.Topics[:], ","), subscribeToken.Error())) } - m.connected = true + m.subscribed = true } - return + + return nil } func (m *MQTTConsumer) onConnectionLost(c mqtt.Client, err error) { - m.acc.AddError(fmt.Errorf("E! MQTT Connection lost\nerror: %s\nMQTT Client will try to reconnect", err.Error())) + m.acc.AddError(fmt.Errorf("connection lost: %v", err)) + log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers) + m.state = Disconnected return } -// receiver() reads all incoming messages from the consumer, and parses them into -// influxdb metric points. -func (m *MQTTConsumer) receiver() { - for { - select { - case <-m.done: - return - case msg := <-m.in: - topic := msg.Topic() - metrics, err := m.parser.Parse(msg.Payload()) - if err != nil { - m.acc.AddError(fmt.Errorf("E! MQTT Parse Error\nmessage: %s\nerror: %s", - string(msg.Payload()), err.Error())) - } +func (m *MQTTConsumer) recvMessage(c mqtt.Client, msg mqtt.Message) { + topic := msg.Topic() + metrics, err := m.parser.Parse(msg.Payload()) + if err != nil { + m.acc.AddError(err) + } - for _, metric := range metrics { - tags := metric.Tags() - tags["topic"] = topic - m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time()) - } - } + for _, metric := range metrics { + tags := metric.Tags() + tags["topic"] = topic + m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time()) } } -func (m *MQTTConsumer) recvMessage(_ mqtt.Client, msg mqtt.Message) { - m.in <- msg -} - func (m *MQTTConsumer) Stop() { - m.Lock() - defer m.Unlock() - - if m.connected { - close(m.done) + if m.state == Connected { + log.Printf("D! [inputs.mqtt_consumer]: disconnecting %v", m.Servers) m.client.Disconnect(200) - m.connected = false + log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers) + m.state = Disconnected } } func (m *MQTTConsumer) Gather(acc telegraf.Accumulator) error { - if !m.connected { + if m.state == Disconnected { + m.state = Connecting + log.Printf("D! [inputs.mqtt_consumer]: connecting %v", m.Servers) m.connect() } @@ -261,7 +246,7 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) { for _, server := range m.Servers { // Preserve support for host:port style servers; deprecated in Telegraf 1.4.4 if !strings.Contains(server, "://") { - log.Printf("W! mqtt_consumer server %q should be updated to use `scheme://host:port` format", server) + log.Printf("W! [inputs.mqtt_consumer] server %q should be updated to use `scheme://host:port` format", server) if tlsCfg == nil { server = "tcp://" + server } else { @@ -271,10 +256,9 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) { opts.AddBroker(server) } - opts.SetAutoReconnect(true) + opts.SetAutoReconnect(false) opts.SetKeepAlive(time.Second * 60) opts.SetCleanSession(!m.PersistentSession) - opts.SetOnConnectHandler(m.onConnect) opts.SetConnectionLostHandler(m.onConnectionLost) return opts, nil @@ -284,6 +268,7 @@ func init() { inputs.Add("mqtt_consumer", func() telegraf.Input { return &MQTTConsumer{ ConnectionTimeout: defaultConnectionTimeout, + state: Disconnected, } }) } diff --git a/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go b/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go index a2e5deaa8..c04bd18a7 100644 --- a/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go +++ b/plugins/inputs/mqtt_consumer/mqtt_consumer_test.go @@ -12,24 +12,17 @@ import ( ) const ( - testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n" - testMsgNeg = "cpu_load_short,host=server01 value=-23422.0 1422568543702900257\n" - testMsgGraphite = "cpu.load.short.graphite 23422 1454780029" - testMsgJSON = "{\"a\": 5, \"b\": {\"c\": 6}}\n" - invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n" + testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n" + invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n" ) -func newTestMQTTConsumer() (*MQTTConsumer, chan mqtt.Message) { - in := make(chan mqtt.Message, 100) +func newTestMQTTConsumer() *MQTTConsumer { n := &MQTTConsumer{ - Topics: []string{"telegraf"}, - Servers: []string{"localhost:1883"}, - in: in, - done: make(chan struct{}), - connected: true, + Topics: []string{"telegraf"}, + Servers: []string{"localhost:1883"}, } - return n, in + return n } // Test that default client has random ID @@ -79,31 +72,12 @@ func TestPersistentClientIDFail(t *testing.T) { } func TestRunParser(t *testing.T) { - n, in := newTestMQTTConsumer() + n := newTestMQTTConsumer() acc := testutil.Accumulator{} n.acc = &acc - defer close(n.done) - n.parser, _ = parsers.NewInfluxParser() - go n.receiver() - in <- mqttMsg(testMsgNeg) - acc.Wait(1) - if a := acc.NFields(); a != 1 { - t.Errorf("got %v, expected %v", a, 1) - } -} - -func TestRunParserNegativeNumber(t *testing.T) { - n, in := newTestMQTTConsumer() - acc := testutil.Accumulator{} - n.acc = &acc - defer close(n.done) - - n.parser, _ = parsers.NewInfluxParser() - go n.receiver() - in <- mqttMsg(testMsg) - acc.Wait(1) + n.recvMessage(nil, mqttMsg(testMsg)) if a := acc.NFields(); a != 1 { t.Errorf("got %v, expected %v", a, 1) @@ -112,84 +86,32 @@ func TestRunParserNegativeNumber(t *testing.T) { // Test that the parser ignores invalid messages func TestRunParserInvalidMsg(t *testing.T) { - n, in := newTestMQTTConsumer() + n := newTestMQTTConsumer() acc := testutil.Accumulator{} n.acc = &acc - defer close(n.done) - n.parser, _ = parsers.NewInfluxParser() - go n.receiver() - in <- mqttMsg(invalidMsg) - acc.WaitError(1) + + n.recvMessage(nil, mqttMsg(invalidMsg)) if a := acc.NFields(); a != 0 { t.Errorf("got %v, expected %v", a, 0) } - assert.Contains(t, acc.Errors[0].Error(), "MQTT Parse Error") + assert.Len(t, acc.Errors, 1) } // Test that the parser parses line format messages into metrics func TestRunParserAndGather(t *testing.T) { - n, in := newTestMQTTConsumer() + n := newTestMQTTConsumer() acc := testutil.Accumulator{} n.acc = &acc - - defer close(n.done) - n.parser, _ = parsers.NewInfluxParser() - go n.receiver() - in <- mqttMsg(testMsg) - acc.Wait(1) - n.Gather(&acc) + n.recvMessage(nil, mqttMsg(testMsg)) acc.AssertContainsFields(t, "cpu_load_short", map[string]interface{}{"value": float64(23422)}) } -// Test that the parser parses graphite format messages into metrics -func TestRunParserAndGatherGraphite(t *testing.T) { - n, in := newTestMQTTConsumer() - acc := testutil.Accumulator{} - n.acc = &acc - defer close(n.done) - - n.parser, _ = parsers.NewGraphiteParser("_", []string{}, nil) - go n.receiver() - in <- mqttMsg(testMsgGraphite) - - n.Gather(&acc) - acc.Wait(1) - - acc.AssertContainsFields(t, "cpu_load_short_graphite", - map[string]interface{}{"value": float64(23422)}) -} - -// Test that the parser parses json format messages into metrics -func TestRunParserAndGatherJSON(t *testing.T) { - n, in := newTestMQTTConsumer() - acc := testutil.Accumulator{} - n.acc = &acc - defer close(n.done) - - n.parser, _ = parsers.NewParser(&parsers.Config{ - DataFormat: "json", - MetricName: "nats_json_test", - }) - go n.receiver() - in <- mqttMsg(testMsgJSON) - - n.Gather(&acc) - - acc.Wait(1) - - acc.AssertContainsFields(t, "nats_json_test", - map[string]interface{}{ - "a": float64(5), - "b_c": float64(6), - }) -} - func mqttMsg(val string) mqtt.Message { return &message{ topic: "telegraf/unit_test",