diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible.go b/plugins/inputs/postgresql_extensible/postgresql_extensible.go index a04382888..a247b603a 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible.go @@ -19,14 +19,8 @@ type Postgresql struct { postgresql.Service Databases []string AdditionalTags []string - Query []struct { - Sqlquery string - Version int - Withdbname bool - Tagvalue string - Measurement string - } - Debug bool + Query query + Debug bool } type query []struct { @@ -127,7 +121,6 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { ) // Retreiving the database version - query = `select substring(setting from 1 for 3) as version from pg_settings where name='server_version_num'` if err = p.DB.QueryRow(query).Scan(&db_version); err != nil { db_version = 0 @@ -135,7 +128,6 @@ func (p *Postgresql) Gather(acc telegraf.Accumulator) error { // We loop in order to process each query // Query is not run if Database version does not match the query version. - for i := range p.Query { sql_query = p.Query[i].Sqlquery tag_value = p.Query[i].Tagvalue @@ -221,9 +213,14 @@ func (p *Postgresql) accRow(meas_name string, row scanner, acc telegraf.Accumula return err } - if columnMap["datname"] != nil { + if c, ok := columnMap["datname"]; ok && *c != nil { // extract the database name from the column map - dbname.WriteString((*columnMap["datname"]).(string)) + switch datname := (*c).(type) { + case string: + dbname.WriteString(datname) + default: + dbname.WriteString("postgres") + } } else { dbname.WriteString("postgres") } diff --git a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go index 0f9358da6..1ed62a1cd 100644 --- a/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go +++ b/plugins/inputs/postgresql_extensible/postgresql_extensible_test.go @@ -1,6 +1,7 @@ package postgresql_extensible import ( + "errors" "fmt" "testing" @@ -223,3 +224,41 @@ func TestPostgresqlIgnoresUnwantedColumns(t *testing.T) { assert.False(t, acc.HasMeasurement(col)) } } + +func TestAccRow(t *testing.T) { + p := Postgresql{} + var acc testutil.Accumulator + columns := []string{"datname", "cat"} + + testRows := []fakeRow{ + {fields: []interface{}{1, "gato"}}, + {fields: []interface{}{nil, "gato"}}, + {fields: []interface{}{"name", "gato"}}, + } + for i := range testRows { + err := p.accRow("pgTEST", testRows[i], &acc, columns) + if err != nil { + t.Fatalf("Scan failed: %s", err) + } + } +} + +type fakeRow struct { + fields []interface{} +} + +func (f fakeRow) Scan(dest ...interface{}) error { + if len(f.fields) != len(dest) { + return errors.New("Nada matchy buddy") + } + + for i, d := range dest { + switch d.(type) { + case (*interface{}): + *d.(*interface{}) = f.fields[i] + default: + return fmt.Errorf("Bad type %T", d) + } + } + return nil +}