Use mysql.ParseDSN func instead of url.Parse

The MySQL DB driver has it's own DSN parsing function. Previously we
were using the url.Parse function, but this causes problems because a
valid MySQL DSN can be an invalid http URL, namely when using some
special characters in the password.

This change uses the MySQL DB driver's builtin ParseDSN function and
applies a timeout parameter natively via that.

Another benefit of this change is that we fail earlier if given an
invalid MySQL DSN.

closes #870
closes #1842
This commit is contained in:
Cameron Sparr 2016-10-12 12:43:51 +01:00
parent b00ad65b08
commit a65447d22e
4 changed files with 55 additions and 171 deletions

View File

@ -67,6 +67,7 @@ continue sending logs to /var/log/telegraf/telegraf.log.
- [#1886](https://github.com/influxdata/telegraf/issues/1886): Fix phpfpm fcgi client panic when URL does not exist. - [#1886](https://github.com/influxdata/telegraf/issues/1886): Fix phpfpm fcgi client panic when URL does not exist.
- [#1344](https://github.com/influxdata/telegraf/issues/1344): Fix config file parse error logging. - [#1344](https://github.com/influxdata/telegraf/issues/1344): Fix config file parse error logging.
- [#1771](https://github.com/influxdata/telegraf/issues/1771): Delete nil fields in the metric maker. - [#1771](https://github.com/influxdata/telegraf/issues/1771): Delete nil fields in the metric maker.
- [#870](https://github.com/influxdata/telegraf/issues/870): Fix MySQL special characters in DSN parsing.
## v1.0.1 [2016-09-26] ## v1.0.1 [2016-09-26]

View File

@ -4,16 +4,16 @@ import (
"bytes" "bytes"
"database/sql" "database/sql"
"fmt" "fmt"
"net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
_ "github.com/go-sql-driver/mysql"
"github.com/influxdata/telegraf" "github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/internal/errchan" "github.com/influxdata/telegraf/internal/errchan"
"github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs"
"github.com/go-sql-driver/mysql"
) )
type Mysql struct { type Mysql struct {
@ -398,27 +398,6 @@ var (
} }
) )
func dsnAddTimeout(dsn string) (string, error) {
// DSN "?timeout=5s" is not valid, but "/?timeout=5s" is valid ("" and "/"
// are the same DSN)
if dsn == "" {
dsn = "/"
}
u, err := url.Parse(dsn)
if err != nil {
return "", err
}
v := u.Query()
// Only override timeout if not already defined
if _, ok := v["timeout"]; ok == false {
v.Add("timeout", defaultTimeout.String())
u.RawQuery = v.Encode()
}
return u.String(), nil
}
// Math constants // Math constants
const ( const (
picoSeconds = 1e12 picoSeconds = 1e12
@ -682,10 +661,7 @@ func (m *Mysql) gatherGlobalVariables(db *sql.DB, serv string, acc telegraf.Accu
var val sql.RawBytes var val sql.RawBytes
// parse DSN and save server tag // parse DSN and save server tag
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{"server": servtag} tags := map[string]string{"server": servtag}
fields := make(map[string]interface{}) fields := make(map[string]interface{})
for rows.Next() { for rows.Next() {
@ -722,10 +698,7 @@ func (m *Mysql) gatherSlaveStatuses(db *sql.DB, serv string, acc telegraf.Accumu
} }
defer rows.Close() defer rows.Close()
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{"server": servtag} tags := map[string]string{"server": servtag}
fields := make(map[string]interface{}) fields := make(map[string]interface{})
@ -770,11 +743,7 @@ func (m *Mysql) gatherBinaryLogs(db *sql.DB, serv string, acc telegraf.Accumulat
defer rows.Close() defer rows.Close()
// parse DSN and save host as a tag // parse DSN and save host as a tag
var servtag string servtag := getDSNTag(serv)
servtag, err = parseDSN(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{"server": servtag} tags := map[string]string{"server": servtag}
var ( var (
size uint64 = 0 size uint64 = 0
@ -817,11 +786,7 @@ func (m *Mysql) gatherGlobalStatuses(db *sql.DB, serv string, acc telegraf.Accum
} }
// parse the DSN and save host name as a tag // parse the DSN and save host name as a tag
var servtag string servtag := getDSNTag(serv)
servtag, err = parseDSN(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{"server": servtag} tags := map[string]string{"server": servtag}
fields := make(map[string]interface{}) fields := make(map[string]interface{})
for rows.Next() { for rows.Next() {
@ -932,10 +897,7 @@ func (m *Mysql) GatherProcessListStatuses(db *sql.DB, serv string, acc telegraf.
var servtag string var servtag string
fields := make(map[string]interface{}) fields := make(map[string]interface{})
servtag, err = parseDSN(serv) servtag = getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
// mapping of state with its counts // mapping of state with its counts
stateCounts := make(map[string]uint32, len(generalThreadStates)) stateCounts := make(map[string]uint32, len(generalThreadStates))
@ -978,10 +940,7 @@ func (m *Mysql) gatherPerfTableIOWaits(db *sql.DB, serv string, acc telegraf.Acc
timeFetch, timeInsert, timeUpdate, timeDelete float64 timeFetch, timeInsert, timeUpdate, timeDelete float64
) )
servtag, err = parseDSN(serv) servtag = getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
for rows.Next() { for rows.Next() {
err = rows.Scan(&objSchema, &objName, err = rows.Scan(&objSchema, &objName,
@ -1030,10 +989,7 @@ func (m *Mysql) gatherPerfIndexIOWaits(db *sql.DB, serv string, acc telegraf.Acc
timeFetch, timeInsert, timeUpdate, timeDelete float64 timeFetch, timeInsert, timeUpdate, timeDelete float64
) )
servtag, err = parseDSN(serv) servtag = getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
for rows.Next() { for rows.Next() {
err = rows.Scan(&objSchema, &objName, &indexName, err = rows.Scan(&objSchema, &objName, &indexName,
@ -1085,10 +1041,7 @@ func (m *Mysql) gatherInfoSchemaAutoIncStatuses(db *sql.DB, serv string, acc tel
incValue, maxInt uint64 incValue, maxInt uint64
) )
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
for rows.Next() { for rows.Next() {
if err := rows.Scan(&schema, &table, &column, &incValue, &maxInt); err != nil { if err := rows.Scan(&schema, &table, &column, &incValue, &maxInt); err != nil {
@ -1132,10 +1085,7 @@ func (m *Mysql) gatherPerfTableLockWaits(db *sql.DB, serv string, acc telegraf.A
} }
defer rows.Close() defer rows.Close()
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
var ( var (
objectSchema string objectSchema string
@ -1257,10 +1207,7 @@ func (m *Mysql) gatherPerfEventWaits(db *sql.DB, serv string, acc telegraf.Accum
starCount, timeWait float64 starCount, timeWait float64
) )
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{ tags := map[string]string{
"server": servtag, "server": servtag,
} }
@ -1295,10 +1242,7 @@ func (m *Mysql) gatherPerfFileEventsStatuses(db *sql.DB, serv string, acc telegr
sumNumBytesRead, sumNumBytesWrite float64 sumNumBytesRead, sumNumBytesWrite float64
) )
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{ tags := map[string]string{
"server": servtag, "server": servtag,
} }
@ -1365,10 +1309,7 @@ func (m *Mysql) gatherPerfEventsStatements(db *sql.DB, serv string, acc telegraf
noIndexUsed float64 noIndexUsed float64
) )
servtag, err := parseDSN(serv) servtag := getDSNTag(serv)
if err != nil {
servtag = "localhost"
}
tags := map[string]string{ tags := map[string]string{
"server": servtag, "server": servtag,
} }
@ -1412,14 +1353,8 @@ func (m *Mysql) gatherPerfEventsStatements(db *sql.DB, serv string, acc telegraf
// gatherTableSchema can be used to gather stats on each schema // gatherTableSchema can be used to gather stats on each schema
func (m *Mysql) gatherTableSchema(db *sql.DB, serv string, acc telegraf.Accumulator) error { func (m *Mysql) gatherTableSchema(db *sql.DB, serv string, acc telegraf.Accumulator) error {
var ( var dbList []string
dbList []string servtag := getDSNTag(serv)
servtag string
)
servtag, err := parseDSN(serv)
if err != nil {
servtag = "localhost"
}
// if the list of databases if empty, then get all databases // if the list of databases if empty, then get all databases
if len(m.TableSchemaDatabases) == 0 { if len(m.TableSchemaDatabases) == 0 {
@ -1575,6 +1510,27 @@ func copyTags(in map[string]string) map[string]string {
return out return out
} }
func dsnAddTimeout(dsn string) (string, error) {
conf, err := mysql.ParseDSN(dsn)
if err != nil {
return "", err
}
if conf.Timeout == 0 {
conf.Timeout = time.Second * 5
}
return conf.FormatDSN(), nil
}
func getDSNTag(dsn string) string {
conf, err := mysql.ParseDSN(dsn)
if err != nil {
return "127.0.0.1:3306"
}
return conf.Addr
}
func init() { func init() {
inputs.Add("mysql", func() telegraf.Input { inputs.Add("mysql", func() telegraf.Input {
return &Mysql{} return &Mysql{}

View File

@ -26,7 +26,7 @@ func TestMysqlDefaultsToLocal(t *testing.T) {
assert.True(t, acc.HasMeasurement("mysql")) assert.True(t, acc.HasMeasurement("mysql"))
} }
func TestMysqlParseDSN(t *testing.T) { func TestMysqlGetDSNTag(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
output string output string
@ -78,9 +78,9 @@ func TestMysqlParseDSN(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
output, _ := parseDSN(test.input) output := getDSNTag(test.input)
if output != test.output { if output != test.output {
t.Errorf("Expected %s, got %s\n", test.output, output) t.Errorf("Input: %s Expected %s, got %s\n", test.input, test.output, output)
} }
} }
} }
@ -92,7 +92,7 @@ func TestMysqlDNSAddTimeout(t *testing.T) {
}{ }{
{ {
"", "",
"/?timeout=5s", "tcp(127.0.0.1:3306)/?timeout=5s",
}, },
{ {
"tcp(192.168.1.1:3306)/", "tcp(192.168.1.1:3306)/",
@ -104,7 +104,19 @@ func TestMysqlDNSAddTimeout(t *testing.T) {
}, },
{ {
"root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s", "root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s",
"root:passwd@tcp(192.168.1.1:3306)/?tls=false&timeout=10s", "root:passwd@tcp(192.168.1.1:3306)/?timeout=10s&tls=false",
},
{
"tcp(10.150.1.123:3306)/",
"tcp(10.150.1.123:3306)/?timeout=5s",
},
{
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/",
"root:@!~(*&$#%(&@#(@&#Password@tcp(10.150.1.123:3306)/?timeout=5s",
},
{
"root:Test3a#@!@tcp(10.150.1.123:3306)/",
"root:Test3a#@!@tcp(10.150.1.123:3306)/?timeout=5s",
}, },
} }

View File

@ -1,85 +0,0 @@
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"errors"
"strings"
)
// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (string, error) {
//var user, passwd string
var addr, net string
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
//passwd = dsn[k+1 : j]
break
}
}
//user = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return "", errors.New("Invalid DSN unescaped")
}
return "", errors.New("Invalid DSN Addr")
}
addr = dsn[k+1 : i-1]
break
}
}
net = dsn[j+1 : k]
}
break
}
}
// Set default network if empty
if net == "" {
net = "tcp"
}
// Set default address if empty
if addr == "" {
switch net {
case "tcp":
addr = "127.0.0.1:3306"
case "unix":
addr = "/tmp/mysql.sock"
default:
return "", errors.New("Default addr for network '" + net + "' unknown")
}
}
return addr, nil
}