diff --git a/CHANGELOG.md b/CHANGELOG.md index 232ccd940..ea9b8721a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,10 @@ `use_random_partitionkey` options has been deprecated in favor of the `partition` subtable. This allows for more flexible methods to set the partition key such as by metric name or by tag. +- `postgresql` plugins will now default to using a persistent connection to the database. + `Important` In environments TCP connections are terminated when idle for periods shorter than 15 minutes + and the collection interval is longer than the termination period then max_lifetime + should be set to be less than the collection interval to pervent errors when collecting metrics. - With the release of the new improved `jolokia2` input, the legacy `jolokia` plugin is deprecated and will be removed in a future release. Users of this diff --git a/plugins/inputs/postgresql/connect.go b/plugins/inputs/postgresql/connect.go deleted file mode 100644 index 011ae32e0..000000000 --- a/plugins/inputs/postgresql/connect.go +++ /dev/null @@ -1,77 +0,0 @@ -package postgresql - -import ( - "fmt" - "net" - "net/url" - "sort" - "strings" -) - -// pulled from lib/pq -// ParseURL no longer needs to be used by clients of this library since supplying a URL as a -// connection string to sql.Open() is now supported: -// -// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") -// -// It remains exported here for backwards-compatibility. -// -// ParseURL converts a url to a connection string for driver.Open. -// Example: -// -// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" -// -// converts to: -// -// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" -// -// A minimal example: -// -// "postgres://" -// -// This will be blank, causing driver.Open to use all of the defaults -func ParseURL(uri string) (string, error) { - u, err := url.Parse(uri) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - if host, port, err := net.SplitHostPort(u.Host); err != nil { - accrue("host", u.Host) - } else { - accrue("host", host) - accrue("port", port) - } - - if u.Path != "" { - accrue("dbname", u.Path[1:]) - } - - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil -} diff --git a/plugins/inputs/postgresql/postgresql.go b/plugins/inputs/postgresql/postgresql.go index 832c433ed..19c9db9ce 100644 --- a/plugins/inputs/postgresql/postgresql.go +++ b/plugins/inputs/postgresql/postgresql.go @@ -2,26 +2,21 @@ package postgresql import ( "bytes" - "database/sql" "fmt" - "regexp" - "sort" "strings" // register in driver. _ "github.com/jackc/pgx/stdlib" "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/internal" "github.com/influxdata/telegraf/plugins/inputs" ) type Postgresql struct { - Address string + Service Databases []string IgnoredDatabases []string - OrderedColumns []string - AllColumns []string - sanitizedAddress string } var ignoredColumns = map[string]bool{"stats_reset": true} @@ -41,6 +36,15 @@ var sampleConfig = ` ## to grab metrics for. ## address = "host=localhost user=postgres sslmode=disable" + ## A custom name for the database that will be used as the "server" tag in the + ## measurement output. If not specified, a default one generated from + ## the connection address is used. + # outputaddress = "db01" + + ## connection configuration. + ## maxlifetime - specify the maximum lifetime of a connection. + ## default is forever (0s) + max_lifetime = "0s" ## A list of databases to explicitly ignore. If not specified, metrics for all ## databases are gathered. Do NOT use with the 'databases' option. @@ -63,24 +67,13 @@ func (p *Postgresql) IgnoredColumns() map[string]bool { return ignoredColumns } -var localhost = "host=localhost sslmode=disable" - func (p *Postgresql) Gather(acc telegraf.Accumulator) error { var ( - err error - db *sql.DB - query string + err error + query string + columns []string ) - if p.Address == "" || p.Address == "localhost" { - p.Address = localhost - } - - if db, err = sql.Open("pgx", p.Address); err != nil { - return err - } - defer db.Close() - if len(p.Databases) == 0 && len(p.IgnoredDatabases) == 0 { query = `SELECT * FROM pg_stat_database` } else if len(p.IgnoredDatabases) != 0 { @@ -91,7 +84,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { strings.Join(p.Databases, "','")) } - rows, err := db.Query(query) + rows, err := p.DB.Query(query) if err != nil { return err } @@ -99,16 +92,12 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { defer rows.Close() // grab the column information from the result - p.OrderedColumns, err = rows.Columns() - if err != nil { + if columns, err = rows.Columns(); err != nil { return err - } else { - p.AllColumns = make([]string, len(p.OrderedColumns)) - copy(p.AllColumns, p.OrderedColumns) } for rows.Next() { - err = p.accRow(rows, acc) + err = p.accRow(rows, acc, columns) if err != nil { return err } @@ -116,7 +105,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { query = `SELECT * FROM pg_stat_bgwriter` - bg_writer_row, err := db.Query(query) + bg_writer_row, err := p.DB.Query(query) if err != nil { return err } @@ -124,22 +113,17 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { defer bg_writer_row.Close() // grab the column information from the result - p.OrderedColumns, err = bg_writer_row.Columns() - if err != nil { + if columns, err = bg_writer_row.Columns(); err != nil { return err - } else { - for _, v := range p.OrderedColumns { - p.AllColumns = append(p.AllColumns, v) - } } for bg_writer_row.Next() { - err = p.accRow(bg_writer_row, acc) + err = p.accRow(bg_writer_row, acc, columns) if err != nil { return err } } - sort.Strings(p.AllColumns) + return bg_writer_row.Err() } @@ -147,37 +131,20 @@ type scanner interface { Scan(dest ...interface{}) error } -var passwordKVMatcher, _ = regexp.Compile("password=\\S+ ?") - -func (p *Postgresql) SanitizedAddress() (_ string, err error) { - var canonicalizedAddress string - if strings.HasPrefix(p.Address, "postgres://") || strings.HasPrefix(p.Address, "postgresql://") { - canonicalizedAddress, err = ParseURL(p.Address) - if err != nil { - return p.sanitizedAddress, err - } - } else { - canonicalizedAddress = p.Address - } - p.sanitizedAddress = passwordKVMatcher.ReplaceAllString(canonicalizedAddress, "") - - return p.sanitizedAddress, err -} - -func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator) error { +func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator, columns []string) error { var columnVars []interface{} var dbname bytes.Buffer // this is where we'll store the column name with its *interface{} columnMap := make(map[string]*interface{}) - for _, column := range p.OrderedColumns { + for _, column := range columns { columnMap[column] = new(interface{}) } // populate the array of interface{} with the pointers in the right order for i := 0; i < len(columnMap); i++ { - columnVars = append(columnVars, columnMap[p.OrderedColumns[i]]) + columnVars = append(columnVars, columnMap[columns[i]]) } // deconstruct array of variables and send to Scan @@ -215,6 +182,14 @@ func (p *Postgresql) accRow(row scanner, acc telegraf.Accumulator) error { func init() { inputs.Add("postgresql", func() telegraf.Input { - return &Postgresql{} + return &Postgresql{ + Service: Service{ + MaxIdle: 1, + MaxOpen: 1, + MaxLifetime: internal.Duration{ + Duration: 0, + }, + }, + } }) } diff --git a/plugins/inputs/postgresql/postgresql_test.go b/plugins/inputs/postgresql/postgresql_test.go index 410b9b421..306dca3b6 100644 --- a/plugins/inputs/postgresql/postgresql_test.go +++ b/plugins/inputs/postgresql/postgresql_test.go @@ -15,19 +15,18 @@ func TestPostgresqlGeneratesMetrics(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, Databases: []string{"postgres"}, } var acc testutil.Accumulator - err := p.Gather(&acc) - require.NoError(t, err) - - availableColumns := make(map[string]bool) - for _, col := range p.AllColumns { - availableColumns[col] = true - } + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) intMetrics := []string{ "xact_commit", @@ -71,39 +70,27 @@ func TestPostgresqlGeneratesMetrics(t *testing.T) { metricsCounted := 0 for _, metric := range intMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasInt64Field("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasInt64Field("postgresql", metric)) + metricsCounted++ } for _, metric := range int32Metrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasInt32Field("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasInt32Field("postgresql", metric)) + metricsCounted++ } for _, metric := range floatMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasFloatField("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasFloatField("postgresql", metric)) + metricsCounted++ } for _, metric := range stringMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasStringField("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasStringField("postgresql", metric)) + metricsCounted++ } assert.True(t, metricsCounted > 0) - assert.Equal(t, len(availableColumns)-len(p.IgnoredColumns()), metricsCounted) + assert.Equal(t, len(floatMetrics)+len(intMetrics)+len(int32Metrics)+len(stringMetrics), metricsCounted) } func TestPostgresqlTagsMetricsWithDatabaseName(t *testing.T) { @@ -112,15 +99,19 @@ func TestPostgresqlTagsMetricsWithDatabaseName(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, Databases: []string{"postgres"}, } var acc testutil.Accumulator - err := p.Gather(&acc) - require.NoError(t, err) + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) point, ok := acc.Get("postgresql") require.True(t, ok) @@ -134,14 +125,18 @@ func TestPostgresqlDefaultsToAllDatabases(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, } var acc testutil.Accumulator - err := p.Gather(&acc) - require.NoError(t, err) + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) var found bool @@ -163,14 +158,17 @@ func TestPostgresqlIgnoresUnwantedColumns(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, } var acc testutil.Accumulator - - err := p.Gather(&acc) - require.NoError(t, err) + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) for col := range p.IgnoredColumns() { assert.False(t, acc.HasMeasurement(col)) @@ -183,15 +181,19 @@ func TestPostgresqlDatabaseWhitelistTest(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, Databases: []string{"template0"}, } var acc testutil.Accumulator - err := p.Gather(&acc) - require.NoError(t, err) + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) var foundTemplate0 = false var foundTemplate1 = false @@ -219,15 +221,18 @@ func TestPostgresqlDatabaseBlacklistTest(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, IgnoredDatabases: []string{"template0"}, } var acc testutil.Accumulator - - err := p.Gather(&acc) - require.NoError(t, err) + require.NoError(t, p.Start(&acc)) + require.NoError(t, p.Gather(&acc)) var foundTemplate0 = false var foundTemplate1 = false diff --git a/plugins/inputs/postgresql/service.go b/plugins/inputs/postgresql/service.go new file mode 100644 index 000000000..4f7b21e54 --- /dev/null +++ b/plugins/inputs/postgresql/service.go @@ -0,0 +1,142 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "net" + "net/url" + "regexp" + "sort" + "strings" + + "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/internal" +) + +// pulled from lib/pq +// ParseURL no longer needs to be used by clients of this library since supplying a URL as a +// connection string to sql.Open() is now supported: +// +// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") +// +// It remains exported here for backwards-compatibility. +// +// ParseURL converts a url to a connection string for driver.Open. +// Example: +// +// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" +// +// converts to: +// +// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" +// +// A minimal example: +// +// "postgres://" +// +// This will be blank, causing driver.Open to use all of the defaults +func parseURL(uri string) (string, error) { + u, err := url.Parse(uri) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"="+escaper.Replace(v)) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} + +// Service common functionality shared between the postgresql and postgresql_extensible +// packages. +type Service struct { + Address string + Outputaddress string + MaxIdle int + MaxOpen int + MaxLifetime internal.Duration + DB *sql.DB +} + +// Start starts the ServiceInput's service, whatever that may be +func (p *Service) Start(telegraf.Accumulator) (err error) { + const localhost = "host=localhost sslmode=disable" + + if p.Address == "" || p.Address == "localhost" { + p.Address = localhost + } + + if p.DB, err = sql.Open("pgx", p.Address); err != nil { + return err + } + + p.DB.SetMaxOpenConns(p.MaxOpen) + p.DB.SetMaxIdleConns(p.MaxIdle) + p.DB.SetConnMaxLifetime(p.MaxLifetime.Duration) + + return nil +} + +// Stop stops the services and closes any necessary channels and connections +func (p *Service) Stop() { + p.DB.Close() +} + +var kvMatcher, _ = regexp.Compile("(password|sslcert|sslkey|sslmode|sslrootcert)=\\S+ ?") + +// SanitizedAddress utility function to strip sensitive information from the connection string. +func (p *Service) SanitizedAddress() (sanitizedAddress string, err error) { + var ( + canonicalizedAddress string + ) + + if p.Outputaddress != "" { + return p.Outputaddress, nil + } + + if strings.HasPrefix(p.Address, "postgres://") || strings.HasPrefix(p.Address, "postgresql://") { + if canonicalizedAddress, err = parseURL(p.Address); err != nil { + return sanitizedAddress, err + } + } else { + canonicalizedAddress = p.Address + } + + sanitizedAddress = kvMatcher.ReplaceAllString(canonicalizedAddress, "") + + return sanitizedAddress, err +} diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible.go b/plugins/inputs/postgresql_extensible/postgresql_extensible.go index 07a782f89..056f4afc8 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible.go @@ -2,29 +2,24 @@ package postgresql_extensible import ( "bytes" - "database/sql" "fmt" "log" - "regexp" "strings" // register in driver. _ "github.com/jackc/pgx/stdlib" "github.com/influxdata/telegraf" + "github.com/influxdata/telegraf/internal" "github.com/influxdata/telegraf/plugins/inputs" "github.com/influxdata/telegraf/plugins/inputs/postgresql" ) type Postgresql struct { - Address string - Outputaddress string - Databases []string - OrderedColumns []string - AllColumns []string - AdditionalTags []string - sanitizedAddress string - Query []struct { + postgresql.Service + Databases []string + AdditionalTags []string + Query []struct { Sqlquery string Version int Withdbname bool @@ -58,14 +53,20 @@ var sampleConfig = ` ## to grab metrics for. # address = "host=localhost user=postgres sslmode=disable" + + ## connection configuration. + ## maxlifetime - specify the maximum lifetime of a connection. + ## default is forever (0s) + max_lifetime = "0s" + ## A list of databases to pull metrics about. If not specified, metrics for all ## databases are gathered. ## databases = ["app_production", "testing"] # - # outputaddress = "db01" ## A custom name for the database that will be used as the "server" tag in the ## measurement output. If not specified, a default one generated from ## the connection address is used. + # outputaddress = "db01" # ## Define the toml config where the sql queries are stored ## New queries can be added, if the withdbname is set to true and there is no @@ -113,36 +114,25 @@ func (p *Postgresql) IgnoredColumns() map[string]bool { return ignoredColumns } -var localhost = "host=localhost sslmode=disable" - func (p *Postgresql) Gather(acc telegraf.Accumulator) error { var ( err error - db *sql.DB sql_query string query_addon string db_version int query string tag_value string meas_name string + columns []string ) - if p.Address == "" || p.Address == "localhost" { - p.Address = localhost - } - - if db, err = sql.Open("pgx", p.Address); err != nil { - return err - } - defer db.Close() - // Retreiving the database version query = `select substring(setting from 1 for 3) as version from pg_settings where name='server_version_num'` - err = db.QueryRow(query).Scan(&db_version) - if err != nil { + if err = p.DB.QueryRow(query).Scan(&db_version); err != nil { db_version = 0 } + // We loop in order to process each query // Query is not run if Database version does not match the query version. @@ -168,7 +158,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { sql_query += query_addon if p.Query[i].Version <= db_version { - rows, err := db.Query(sql_query) + rows, err := p.DB.Query(sql_query) if err != nil { acc.AddError(err) continue @@ -177,15 +167,11 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { defer rows.Close() // grab the column information from the result - p.OrderedColumns, err = rows.Columns() - if err != nil { + if columns, err = rows.Columns(); err != nil { acc.AddError(err) continue - } else { - for _, v := range p.OrderedColumns { - p.AllColumns = append(p.AllColumns, v) - } } + p.AdditionalTags = nil if tag_value != "" { tag_list := strings.Split(tag_value, ",") @@ -195,7 +181,7 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { } for rows.Next() { - err = p.accRow(meas_name, rows, acc) + err = p.accRow(meas_name, rows, acc, columns) if err != nil { acc.AddError(err) break @@ -210,27 +196,7 @@ type scanner interface { Scan(dest ...interface{}) error } -var KVMatcher, _ = regexp.Compile("(password|sslcert|sslkey|sslmode|sslrootcert)=\\S+ ?") - -func (p *Postgresql) SanitizedAddress() (_ string, err error) { - if p.Outputaddress != "" { - return p.Outputaddress, nil - } - var canonicalizedAddress string - if strings.HasPrefix(p.Address, "postgres://") || strings.HasPrefix(p.Address, "postgresql://") { - canonicalizedAddress, err = postgresql.ParseURL(p.Address) - if err != nil { - return p.sanitizedAddress, err - } - } else { - canonicalizedAddress = p.Address - } - p.sanitizedAddress = KVMatcher.ReplaceAllString(canonicalizedAddress, "") - - return p.sanitizedAddress, err -} - -func (p *Postgresql) accRow(meas_name string, row scanner, acc telegraf.Accumulator) error { +func (p *Postgresql) accRow(meas_name string, row scanner, acc telegraf.Accumulator, columns []string) error { var ( err error columnVars []interface{} @@ -241,13 +207,13 @@ func (p *Postgresql) accRow(meas_name string, row scanner, acc telegraf.Accumula // this is where we'll store the column name with its *interface{} columnMap := make(map[string]*interface{}) - for _, column := range p.OrderedColumns { + for _, column := range columns { columnMap[column] = new(interface{}) } // populate the array of interface{} with the pointers in the right order for i := 0; i < len(columnMap); i++ { - columnVars = append(columnVars, columnMap[p.OrderedColumns[i]]) + columnVars = append(columnVars, columnMap[columns[i]]) } // deconstruct array of variables and send to Scan @@ -275,7 +241,7 @@ func (p *Postgresql) accRow(meas_name string, row scanner, acc telegraf.Accumula fields := make(map[string]interface{}) COLUMN: for col, val := range columnMap { - log.Printf("D! postgresql_extensible: column: %s = %T: %s\n", col, *val, *val) + log.Printf("D! postgresql_extensible: column: %s = %T: %v\n", col, *val, *val) _, ignore := ignoredColumns[col] if ignore || *val == nil { continue @@ -310,6 +276,14 @@ COLUMN: func init() { inputs.Add("postgresql_extensible", func() telegraf.Input { - return &Postgresql{} + return &Postgresql{ + Service: postgresql.Service{ + MaxIdle: 1, + MaxOpen: 1, + MaxLifetime: internal.Duration{ + Duration: 0, + }, + }, + } }) } diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go index 4545a2478..77db5feb5 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go @@ -4,22 +4,28 @@ import ( "fmt" "testing" + "github.com/influxdata/telegraf/plugins/inputs/postgresql" "github.com/influxdata/telegraf/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func queryRunner(t *testing.T, q query) (*Postgresql, *testutil.Accumulator) { +func queryRunner(t *testing.T, q query) *testutil.Accumulator { p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: postgresql.Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, Databases: []string{"postgres"}, Query: q, } var acc testutil.Accumulator + p.Start(&acc) require.NoError(t, acc.GatherError(p.Gather)) - return p, &acc + return &acc } func TestPostgresqlGeneratesMetrics(t *testing.T) { @@ -27,18 +33,13 @@ func TestPostgresqlGeneratesMetrics(t *testing.T) { t.Skip("Skipping integration test in short mode") } - p, acc := queryRunner(t, query{{ + acc := queryRunner(t, query{{ Sqlquery: "select * from pg_stat_database", Version: 901, Withdbname: false, Tagvalue: "", }}) - availableColumns := make(map[string]bool) - for _, col := range p.AllColumns { - availableColumns[col] = true - } - intMetrics := []string{ "xact_commit", "xact_rollback", @@ -71,39 +72,27 @@ func TestPostgresqlGeneratesMetrics(t *testing.T) { metricsCounted := 0 for _, metric := range intMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasInt64Field("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasInt64Field("postgresql", metric)) + metricsCounted++ } for _, metric := range int32Metrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasInt32Field("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasInt32Field("postgresql", metric)) + metricsCounted++ } for _, metric := range floatMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasFloatField("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasFloatField("postgresql", metric)) + metricsCounted++ } for _, metric := range stringMetrics { - _, ok := availableColumns[metric] - if ok { - assert.True(t, acc.HasStringField("postgresql", metric)) - metricsCounted++ - } + assert.True(t, acc.HasStringField("postgresql", metric)) + metricsCounted++ } assert.True(t, metricsCounted > 0) - assert.Equal(t, len(availableColumns)-len(p.IgnoredColumns()), metricsCounted) + assert.Equal(t, len(floatMetrics)+len(intMetrics)+len(int32Metrics)+len(stringMetrics), metricsCounted) } func TestPostgresqlQueryOutputTests(t *testing.T) { @@ -137,7 +126,7 @@ func TestPostgresqlQueryOutputTests(t *testing.T) { } for q, assertions := range examples { - _, acc := queryRunner(t, query{{ + acc := queryRunner(t, query{{ Sqlquery: q, Version: 901, Withdbname: false, @@ -153,7 +142,7 @@ func TestPostgresqlFieldOutput(t *testing.T) { t.Skip("Skipping integration test in short mode") } - _, acc := queryRunner(t, query{{ + acc := queryRunner(t, query{{ Sqlquery: "select * from pg_stat_database", Version: 901, Withdbname: false, @@ -216,13 +205,18 @@ func TestPostgresqlIgnoresUnwantedColumns(t *testing.T) { } p := &Postgresql{ - Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", - testutil.GetLocalHost()), + Service: postgresql.Service{ + Address: fmt.Sprintf( + "host=%s user=postgres sslmode=disable", + testutil.GetLocalHost(), + ), + }, } var acc testutil.Accumulator - require.NoError(t, acc.GatherError(p.Gather)) + require.NoError(t, p.Start(&acc)) + require.NoError(t, acc.GatherError(p.Gather)) assert.NotEmpty(t, p.IgnoredColumns()) for col := range p.IgnoredColumns() { assert.False(t, acc.HasMeasurement(col))