Simplify testing with TLS (#4095)
This commit is contained in:
130
internal/tls/config.go
Normal file
130
internal/tls/config.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// ClientConfig represents the standard client TLS config.
|
||||
type ClientConfig struct {
|
||||
TLSCA string `toml:"tls_ca"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
|
||||
|
||||
// Deprecated in 1.7; use TLS variables above
|
||||
SSLCA string `toml:"ssl_ca"`
|
||||
SSLCert string `toml:"ssl_cert"`
|
||||
SSLKey string `toml:"ssl_ca"`
|
||||
}
|
||||
|
||||
// ServerConfig represents the standard server TLS config.
|
||||
type ServerConfig struct {
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"`
|
||||
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
// configured.
|
||||
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
|
||||
// Support deprecated variable names
|
||||
if c.TLSCA == "" && c.SSLCA != "" {
|
||||
c.TLSCA = c.SSLCA
|
||||
}
|
||||
if c.TLSCert == "" && c.SSLCert != "" {
|
||||
c.TLSCert = c.SSLCert
|
||||
}
|
||||
if c.TLSKey == "" && c.SSLKey != "" {
|
||||
c.TLSKey = c.SSLKey
|
||||
}
|
||||
|
||||
// TODO: return default tls.Config; plugins should not call if they don't
|
||||
// want TLS, this will require using another option to determine. In the
|
||||
// case of an HTTP plugin, you could use `https`. Other plugins may need
|
||||
// the dedicated option `TLSEnable`.
|
||||
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
Renegotiation: tls.RenegotiateNever,
|
||||
}
|
||||
|
||||
if c.TLSCA != "" {
|
||||
pool, err := makeCertPool([]string{c.TLSCA})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
if c.TLSCert != "" && c.TLSKey != "" {
|
||||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
// configured.
|
||||
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
|
||||
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if len(c.TLSAllowedCACerts) != 0 {
|
||||
pool, err := makeCertPool(c.TLSAllowedCACerts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
if c.TLSCert != "" && c.TLSKey != "" {
|
||||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
for _, certFile := range certFiles {
|
||||
pem, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not read certificate %q: %v", certFile, err)
|
||||
}
|
||||
ok := pool.AppendCertsFromPEM(pem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse any PEM certificates %q: %v", certFile, err)
|
||||
}
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func loadCertificate(config *tls.Config, certFile, keyFile string) error {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"could not load keypair %s:%s: %v", certFile, keyFile, err)
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
config.BuildNameToCertificate()
|
||||
return nil
|
||||
}
|
||||
226
internal/tls/config_test.go
Normal file
226
internal/tls/config_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package tls_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf/internal/tls"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var pki = testutil.NewPKI("../../testutil/pki")
|
||||
|
||||
func TestClientConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
client tls.ClientConfig
|
||||
expNil bool
|
||||
expErr bool
|
||||
}{
|
||||
{
|
||||
name: "unset",
|
||||
client: tls.ClientConfig{},
|
||||
expNil: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
client: tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ca",
|
||||
client: tls.ClientConfig{
|
||||
TLSCA: pki.ClientKeyPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing ca is okay",
|
||||
client: tls.ClientConfig{
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid cert",
|
||||
client: tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientKeyPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing cert skips client keypair",
|
||||
client: tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
},
|
||||
expNil: false,
|
||||
expErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing key skips client keypair",
|
||||
client: tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
},
|
||||
expNil: false,
|
||||
expErr: false,
|
||||
},
|
||||
{
|
||||
name: "support deprecated ssl field names",
|
||||
client: tls.ClientConfig{
|
||||
SSLCA: pki.CACertPath(),
|
||||
SSLCert: pki.ClientCertPath(),
|
||||
SSLKey: pki.ClientKeyPath(),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tlsConfig, err := tt.client.TLSConfig()
|
||||
if !tt.expNil {
|
||||
require.NotNil(t, tlsConfig)
|
||||
} else {
|
||||
require.Nil(t, tlsConfig)
|
||||
}
|
||||
|
||||
if !tt.expErr {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
server tls.ServerConfig
|
||||
expNil bool
|
||||
expErr bool
|
||||
}{
|
||||
{
|
||||
name: "unset",
|
||||
server: tls.ServerConfig{},
|
||||
expNil: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ca",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.ServerKeyPath()},
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing allowed ca is okay",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid cert",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerKeyPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing cert",
|
||||
server: tls.ServerConfig{
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
},
|
||||
expNil: true,
|
||||
expErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tlsConfig, err := tt.server.TLSConfig()
|
||||
if !tt.expNil {
|
||||
require.NotNil(t, tlsConfig)
|
||||
}
|
||||
if !tt.expErr {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
clientConfig := tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
}
|
||||
|
||||
serverConfig := tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
}
|
||||
|
||||
serverTLSConfig, err := serverConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
ts.TLS = serverTLSConfig
|
||||
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
|
||||
clientTLSConfig, err := clientConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(ts.URL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
Reference in New Issue
Block a user