diff --git a/internal/internal.go b/internal/internal.go index af36460e3..12e4b3af2 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -16,6 +16,7 @@ import ( "runtime" "strconv" "strings" + "sync" "syscall" "time" "unicode" @@ -50,6 +51,11 @@ type Number struct { Value float64 } +type ReadWaitCloser struct { + pipeReader *io.PipeReader + wg sync.WaitGroup +} + // SetVersion sets the telegraf agent version func SetVersion(v string) error { if version != "" { @@ -281,14 +287,25 @@ func ExitStatus(err error) (int, bool) { return 0, false } +func (r *ReadWaitCloser) Close() error { + err := r.pipeReader.Close() + r.wg.Wait() // wait for the gzip goroutine finish + return err +} + // CompressWithGzip takes an io.Reader as input and pipes // it through a gzip.Writer returning an io.Reader containing // the gzipped data. // An error is returned if passing data to the gzip.Writer fails -func CompressWithGzip(data io.Reader) (io.Reader, error) { +func CompressWithGzip(data io.Reader) (io.ReadCloser, error) { pipeReader, pipeWriter := io.Pipe() gzipWriter := gzip.NewWriter(pipeWriter) + rc := &ReadWaitCloser{ + pipeReader: pipeReader, + } + + rc.wg.Add(1) var err error go func() { _, err = io.Copy(gzipWriter, data) @@ -296,6 +313,7 @@ func CompressWithGzip(data io.Reader) (io.Reader, error) { // subsequent reads from the read half of the pipe will // return no bytes and the error err, or EOF if err is nil. pipeWriter.CloseWithError(err) + rc.wg.Done() }() return pipeReader, err diff --git a/internal/internal_test.go b/internal/internal_test.go index f4627ee74..bb186f5fc 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -3,6 +3,8 @@ package internal import ( "bytes" "compress/gzip" + "crypto/rand" + "io" "io/ioutil" "log" "os/exec" @@ -232,6 +234,38 @@ func TestCompressWithGzip(t *testing.T) { assert.Equal(t, testData, string(output)) } +type mockReader struct { + readN uint64 // record the number of calls to Read +} + +func (r *mockReader) Read(p []byte) (n int, err error) { + r.readN++ + return rand.Read(p) +} + +func TestCompressWithGzipEarlyClose(t *testing.T) { + mr := &mockReader{} + + rc, err := CompressWithGzip(mr) + assert.NoError(t, err) + + n, err := io.CopyN(ioutil.Discard, rc, 10000) + assert.NoError(t, err) + assert.Equal(t, int64(10000), n) + + r1 := mr.readN + err = rc.Close() + assert.NoError(t, err) + + n, err = io.CopyN(ioutil.Discard, rc, 10000) + assert.Error(t, io.EOF, err) + assert.Equal(t, int64(0), n) + + r2 := mr.readN + // no more read to the source after closing + assert.Equal(t, r1, r2) +} + func TestVersionAlreadySet(t *testing.T) { err := SetVersion("foo") assert.Nil(t, err) diff --git a/plugins/inputs/http/http.go b/plugins/inputs/http/http.go index dc155f254..13c9cd170 100644 --- a/plugins/inputs/http/http.go +++ b/plugins/inputs/http/http.go @@ -153,6 +153,7 @@ func (h *HTTP) gatherURL( if err != nil { return err } + defer body.Close() request, err := http.NewRequest(h.Method, url, body) if err != nil { @@ -216,16 +217,16 @@ func (h *HTTP) gatherURL( return nil } -func makeRequestBodyReader(contentEncoding, body string) (io.Reader, error) { - var err error +func makeRequestBodyReader(contentEncoding, body string) (io.ReadCloser, error) { var reader io.Reader = strings.NewReader(body) if contentEncoding == "gzip" { - reader, err = internal.CompressWithGzip(reader) + rc, err := internal.CompressWithGzip(reader) if err != nil { return nil, err } + return rc, nil } - return reader, nil + return ioutil.NopCloser(reader), nil } func init() { diff --git a/plugins/outputs/http/http.go b/plugins/outputs/http/http.go index 1967b6ef9..746cba346 100644 --- a/plugins/outputs/http/http.go +++ b/plugins/outputs/http/http.go @@ -176,10 +176,12 @@ func (h *HTTP) write(reqBody []byte) error { var err error if h.ContentEncoding == "gzip" { - reqBodyBuffer, err = internal.CompressWithGzip(reqBodyBuffer) + rc, err := internal.CompressWithGzip(reqBodyBuffer) if err != nil { return err } + defer rc.Close() + reqBodyBuffer = rc } req, err := http.NewRequest(h.Method, h.URL, reqBodyBuffer) diff --git a/plugins/outputs/influxdb/http.go b/plugins/outputs/influxdb/http.go index b30a8206d..d449c9456 100644 --- a/plugins/outputs/influxdb/http.go +++ b/plugins/outputs/influxdb/http.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net" "net/http" "net/url" @@ -288,7 +289,12 @@ func (c *httpClient) writeBatch(ctx context.Context, db string, metrics []telegr return err } - reader := influx.NewReader(metrics, c.config.Serializer) + reader, err := c.requestBodyReader(metrics) + if err != nil { + return err + } + defer reader.Close() + req, err := c.makeWriteRequest(url, reader) if err != nil { return err @@ -386,12 +392,6 @@ func (c *httpClient) makeQueryRequest(query string) (*http.Request, error) { func (c *httpClient) makeWriteRequest(url string, body io.Reader) (*http.Request, error) { var err error - if c.config.ContentEncoding == "gzip" { - body, err = internal.CompressWithGzip(body) - if err != nil { - return nil, err - } - } req, err := http.NewRequest("POST", url, body) if err != nil { @@ -408,6 +408,23 @@ func (c *httpClient) makeWriteRequest(url string, body io.Reader) (*http.Request return req, nil } +// requestBodyReader warp io.Reader from influx.NewReader to io.ReadCloser, which is usefully to fast close the write +// side of the connection in case of error +func (c *httpClient) requestBodyReader(metrics []telegraf.Metric) (io.ReadCloser, error) { + reader := influx.NewReader(metrics, c.config.Serializer) + + if c.config.ContentEncoding == "gzip" { + rc, err := internal.CompressWithGzip(reader) + if err != nil { + return nil, err + } + + return rc, nil + } + + return ioutil.NopCloser(reader), nil +} + func (c *httpClient) addHeaders(req *http.Request) { if c.config.Username != "" || c.config.Password != "" { req.SetBasicAuth(c.config.Username, c.config.Password) diff --git a/plugins/outputs/influxdb/influxdb.go b/plugins/outputs/influxdb/influxdb.go index 01a09208a..50161e832 100644 --- a/plugins/outputs/influxdb/influxdb.go +++ b/plugins/outputs/influxdb/influxdb.go @@ -57,8 +57,7 @@ type InfluxDB struct { CreateHTTPClientF func(config *HTTPConfig) (Client, error) CreateUDPClientF func(config *UDPConfig) (Client, error) - serializer *influx.Serializer - Log telegraf.Logger + Log telegraf.Logger } var sampleConfig = ` @@ -145,11 +144,6 @@ func (i *InfluxDB) Connect() error { urls = append(urls, defaultURL) } - i.serializer = influx.NewSerializer() - if i.InfluxUintSupport { - i.serializer.SetFieldTypeSupport(influx.UintSupport) - } - for _, u := range urls { parts, err := url.Parse(u) if err != nil { @@ -237,7 +231,7 @@ func (i *InfluxDB) udpClient(url *url.URL) (Client, error) { config := &UDPConfig{ URL: url, MaxPayloadSize: int(i.UDPPayload.Size), - Serializer: i.serializer, + Serializer: i.newSerializer(), Log: i.Log, } @@ -271,7 +265,7 @@ func (i *InfluxDB) httpClient(ctx context.Context, url *url.URL, proxy *url.URL) SkipDatabaseCreation: i.SkipDatabaseCreation, RetentionPolicy: i.RetentionPolicy, Consistency: i.WriteConsistency, - Serializer: i.serializer, + Serializer: i.newSerializer(), Log: i.Log, } @@ -291,6 +285,15 @@ func (i *InfluxDB) httpClient(ctx context.Context, url *url.URL, proxy *url.URL) return c, nil } +func (i *InfluxDB) newSerializer() *influx.Serializer { + serializer := influx.NewSerializer() + if i.InfluxUintSupport { + serializer.SetFieldTypeSupport(influx.UintSupport) + } + + return serializer +} + func init() { outputs.Add("influxdb", func() telegraf.Output { return &InfluxDB{ diff --git a/plugins/outputs/influxdb_v2/http.go b/plugins/outputs/influxdb_v2/http.go index b8706c9a5..b94df889b 100644 --- a/plugins/outputs/influxdb_v2/http.go +++ b/plugins/outputs/influxdb_v2/http.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "log" "net" "net/http" @@ -214,7 +215,12 @@ func (c *httpClient) writeBatch(ctx context.Context, bucket string, metrics []te return err } - reader := influx.NewReader(metrics, c.serializer) + reader, err := c.requestBodyReader(metrics) + if err != nil { + return err + } + defer reader.Close() + req, err := c.makeWriteRequest(url, reader) if err != nil { return err @@ -282,12 +288,6 @@ func (c *httpClient) writeBatch(ctx context.Context, bucket string, metrics []te func (c *httpClient) makeWriteRequest(url string, body io.Reader) (*http.Request, error) { var err error - if c.ContentEncoding == "gzip" { - body, err = internal.CompressWithGzip(body) - if err != nil { - return nil, err - } - } req, err := http.NewRequest("POST", url, body) if err != nil { @@ -304,6 +304,23 @@ func (c *httpClient) makeWriteRequest(url string, body io.Reader) (*http.Request return req, nil } +// requestBodyReader warp io.Reader from influx.NewReader to io.ReadCloser, which is usefully to fast close the write +// side of the connection in case of error +func (c *httpClient) requestBodyReader(metrics []telegraf.Metric) (io.ReadCloser, error) { + reader := influx.NewReader(metrics, c.serializer) + + if c.ContentEncoding == "gzip" { + rc, err := internal.CompressWithGzip(reader) + if err != nil { + return nil, err + } + + return rc, nil + } + + return ioutil.NopCloser(reader), nil +} + func (c *httpClient) addHeaders(req *http.Request) { for header, value := range c.Headers { req.Header.Set(header, value) diff --git a/plugins/outputs/influxdb_v2/influxdb.go b/plugins/outputs/influxdb_v2/influxdb.go index 972773f79..4e2314691 100644 --- a/plugins/outputs/influxdb_v2/influxdb.go +++ b/plugins/outputs/influxdb_v2/influxdb.go @@ -96,8 +96,7 @@ type InfluxDB struct { UintSupport bool `toml:"influx_uint_support"` tls.ClientConfig - clients []Client - serializer *influx.Serializer + clients []Client } func (i *InfluxDB) Connect() error { @@ -107,11 +106,6 @@ func (i *InfluxDB) Connect() error { i.URLs = append(i.URLs, defaultURL) } - i.serializer = influx.NewSerializer() - if i.UintSupport { - i.serializer.SetFieldTypeSupport(influx.UintSupport) - } - for _, u := range i.URLs { parts, err := url.Parse(u) if err != nil { @@ -196,7 +190,7 @@ func (i *InfluxDB) getHTTPClient(ctx context.Context, url *url.URL, proxy *url.U UserAgent: i.UserAgent, ContentEncoding: i.ContentEncoding, TLSConfig: tlsConfig, - Serializer: i.serializer, + Serializer: i.newSerializer(), } c, err := NewHTTPClient(config) @@ -207,6 +201,15 @@ func (i *InfluxDB) getHTTPClient(ctx context.Context, url *url.URL, proxy *url.U return c, nil } +func (i *InfluxDB) newSerializer() *influx.Serializer { + serializer := influx.NewSerializer() + if i.UintSupport { + serializer.SetFieldTypeSupport(influx.UintSupport) + } + + return serializer +} + func init() { outputs.Add("influxdb_v2", func() telegraf.Output { return &InfluxDB{