diff --git a/plugins/postgresql/postgresql.go b/plugins/postgresql/postgresql.go index 1a467fee9..a7a8d1acd 100644 --- a/plugins/postgresql/postgresql.go +++ b/plugins/postgresql/postgresql.go @@ -1,7 +1,10 @@ package postgresql import ( + "bytes" "database/sql" + "fmt" + "strings" "github.com/influxdb/telegraf/plugins" @@ -9,8 +12,9 @@ import ( ) type Server struct { - Address string - Databases []string + Address string + Databases []string + OrderedColumns []string } type Postgresql struct { @@ -51,6 +55,7 @@ func (p *Postgresql) Description() string { } var localhost = &Server{Address: "sslmode=disable"} +var ignoredColumns = map[string]bool{"datid": true, "datname": true, "stats_reset": true} func (p *Postgresql) Gather(acc plugins.Accumulator) error { if len(p.Servers) == 0 { @@ -69,6 +74,8 @@ func (p *Postgresql) Gather(acc plugins.Accumulator) error { } func (p *Postgresql) gatherServer(serv *Server, acc plugins.Accumulator) error { + var query string + if serv.Address == "" || serv.Address == "localhost" { serv = localhost } @@ -81,77 +88,69 @@ func (p *Postgresql) gatherServer(serv *Server, acc plugins.Accumulator) error { defer db.Close() if len(serv.Databases) == 0 { - rows, err := db.Query(`SELECT * FROM pg_stat_database`) + query = `SELECT * FROM pg_stat_database` + } else { + query = fmt.Sprintf(`SELECT * FROM pg_stat_database WHERE datname IN ('%s')`, strings.Join(serv.Databases, "','")) + } + + rows, err := db.Query(query) + if err != nil { + return err + } + + defer rows.Close() + + serv.OrderedColumns, err = rows.Columns() + if err != nil { + return err + } + + for rows.Next() { + err := p.accRow(rows, acc, serv) if err != nil { return err } - - defer rows.Close() - - for rows.Next() { - err := p.accRow(rows, acc, serv.Address) - if err != nil { - return err - } - } - - return rows.Err() - } else { - for _, name := range serv.Databases { - row := db.QueryRow(`SELECT * FROM pg_stat_database WHERE datname=$1`, name) - - err := p.accRow(row, acc, serv.Address) - if err != nil { - return err - } - } } - return nil + return rows.Err() } type scanner interface { Scan(dest ...interface{}) error } -func (p *Postgresql) accRow(row scanner, acc plugins.Accumulator, server string) error { - var ignore interface{} - var name string - var commit, rollback, read, hit int64 - var returned, fetched, inserted, updated, deleted int64 - var conflicts, temp_files, temp_bytes, deadlocks int64 - var read_time, write_time float64 +func (p *Postgresql) accRow(row scanner, acc plugins.Accumulator, serv *Server) error { + var columnVars []interface{} + var dbname bytes.Buffer - err := row.Scan(&ignore, &name, &ignore, - &commit, &rollback, - &read, &hit, - &returned, &fetched, &inserted, &updated, &deleted, - &conflicts, &temp_files, &temp_bytes, - &deadlocks, &read_time, &write_time, - &ignore, - ) + columnMap := make(map[string]*interface{}) + + for _, column := range serv.OrderedColumns { + columnMap[column] = new(interface{}) + } + + for i := 0; i < len(columnMap); i++ { + columnVars = append(columnVars, columnMap[serv.OrderedColumns[i]]) + } + + err := row.Scan(columnVars...) if err != nil { return err } - tags := map[string]string{"server": server, "db": name} + dbnameChars := (*columnMap["datname"]).([]uint8) + for i := 0; i < len(dbnameChars); i++ { + dbname.WriteString(string(dbnameChars[i])) + } - acc.Add("xact_commit", commit, tags) - acc.Add("xact_rollback", rollback, tags) - acc.Add("blks_read", read, tags) - acc.Add("blks_hit", hit, tags) - acc.Add("tup_returned", returned, tags) - acc.Add("tup_fetched", fetched, tags) - acc.Add("tup_inserted", inserted, tags) - acc.Add("tup_updated", updated, tags) - acc.Add("tup_deleted", deleted, tags) - acc.Add("conflicts", conflicts, tags) - acc.Add("temp_files", temp_files, tags) - acc.Add("temp_bytes", temp_bytes, tags) - acc.Add("deadlocks", deadlocks, tags) - acc.Add("blk_read_time", read_time, tags) - acc.Add("blk_write_time", read_time, tags) + tags := map[string]string{"server": serv.Address, "db": dbname.String()} + + for col, val := range columnMap { + if !ignoredColumns[col] { + acc.Add(col, *val, tags) + } + } return nil } diff --git a/plugins/postgresql/postgresql_test.go b/plugins/postgresql/postgresql_test.go index 363d289f9..7910425f5 100644 --- a/plugins/postgresql/postgresql_test.go +++ b/plugins/postgresql/postgresql_test.go @@ -117,3 +117,34 @@ func TestPostgresqlDefaultsToAllDatabases(t *testing.T) { assert.True(t, found) } + +func TestPostgresqlIgnoresUnwantedColumns(t *testing.T) { + // if testing.Short() { + // t.Skip("Skipping integration test in short mode") + // } + + p := &Postgresql{ + Servers: []*Server{ + { + Address: fmt.Sprintf("host=%s user=postgres sslmode=disable", + testutil.GetLocalHost()), + }, + }, + } + + var acc testutil.Accumulator + + err := p.Gather(&acc) + require.NoError(t, err) + + var found bool + + for _, pnt := range acc.Points { + if pnt.Measurement == "datname" || pnt.Measurement == "datid" || pnt.Measurement == "stats_reset" { + found = true + break + } + } + + assert.False(t, found) +}