Simplify testing with TLS (#4095)

This commit is contained in:
Daniel Nelson
2018-05-04 16:33:23 -07:00
committed by GitHub
parent 6e10a4ea88
commit 55b4fcb40d
92 changed files with 1246 additions and 1360 deletions

View File

@@ -4,11 +4,7 @@ import (
"bufio"
"bytes"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"math/big"
"os"
@@ -112,94 +108,6 @@ func RandomString(n int) string {
return string(bytes)
}
// GetTLSConfig gets a tls.Config object from the given certs, key, and CA files
// for use with a client.
// The full path to each file must be provided.
// Returns a nil pointer if all files are blank and InsecureSkipVerify=false.
func GetTLSConfig(
SSLCert, SSLKey, SSLCA string,
InsecureSkipVerify bool,
) (*tls.Config, error) {
if SSLCert == "" && SSLKey == "" && SSLCA == "" && !InsecureSkipVerify {
return nil, nil
}
t := &tls.Config{
InsecureSkipVerify: InsecureSkipVerify,
}
if SSLCA != "" {
caCert, err := ioutil.ReadFile(SSLCA)
if err != nil {
return nil, errors.New(fmt.Sprintf("Could not load TLS CA: %s",
err))
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
t.RootCAs = caCertPool
}
if SSLCert != "" && SSLKey != "" {
cert, err := tls.LoadX509KeyPair(SSLCert, SSLKey)
if err != nil {
return nil, errors.New(fmt.Sprintf(
"Could not load TLS client key/certificate from %s:%s: %s",
SSLKey, SSLCert, err))
}
t.Certificates = []tls.Certificate{cert}
t.BuildNameToCertificate()
}
// will be nil by default if nothing is provided
return t, nil
}
// GetServerTLSConfig gets a tls.Config object from the given certs, key, and one or more CA files
// for use with a server.
// The full path to each file must be provided.
// Returns a nil pointer if all files are blank.
func GetServerTLSConfig(
TLSCert, TLSKey string,
TLSAllowedCACerts []string,
) (*tls.Config, error) {
if TLSCert == "" && TLSKey == "" && len(TLSAllowedCACerts) == 0 {
return nil, nil
}
t := &tls.Config{}
if len(TLSAllowedCACerts) != 0 {
caCertPool := x509.NewCertPool()
for _, cert := range TLSAllowedCACerts {
c, err := ioutil.ReadFile(cert)
if err != nil {
return nil, errors.New(fmt.Sprintf("Could not load TLS CA: %s",
err))
}
caCertPool.AppendCertsFromPEM(c)
}
t.ClientCAs = caCertPool
t.ClientAuth = tls.RequireAndVerifyClientCert
}
if TLSCert != "" && TLSKey != "" {
cert, err := tls.LoadX509KeyPair(TLSCert, TLSKey)
if err != nil {
return nil, errors.New(fmt.Sprintf(
"Could not load TLS client key/certificate from %s:%s: %s",
TLSKey, TLSCert, err))
}
t.Certificates = []tls.Certificate{cert}
}
t.BuildNameToCertificate()
return t, nil
}
// SnakeCase converts the given string to snake case following the Golang format:
// acronyms are converted to lower-case and preceded by an underscore.
func SnakeCase(in string) string {

130
internal/tls/config.go Normal file
View File

@@ -0,0 +1,130 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
)
// ClientConfig represents the standard client TLS config.
type ClientConfig struct {
TLSCA string `toml:"tls_ca"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
// Deprecated in 1.7; use TLS variables above
SSLCA string `toml:"ssl_ca"`
SSLCert string `toml:"ssl_cert"`
SSLKey string `toml:"ssl_ca"`
}
// ServerConfig represents the standard server TLS config.
type ServerConfig struct {
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"`
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
// Support deprecated variable names
if c.TLSCA == "" && c.SSLCA != "" {
c.TLSCA = c.SSLCA
}
if c.TLSCert == "" && c.SSLCert != "" {
c.TLSCert = c.SSLCert
}
if c.TLSKey == "" && c.SSLKey != "" {
c.TLSKey = c.SSLKey
}
// TODO: return default tls.Config; plugins should not call if they don't
// want TLS, this will require using another option to determine. In the
// case of an HTTP plugin, you could use `https`. Other plugins may need
// the dedicated option `TLSEnable`.
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify {
return nil, nil
}
tlsConfig := &tls.Config{
InsecureSkipVerify: c.InsecureSkipVerify,
Renegotiation: tls.RenegotiateNever,
}
if c.TLSCA != "" {
pool, err := makeCertPool([]string{c.TLSCA})
if err != nil {
return nil, err
}
tlsConfig.RootCAs = pool
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
if err != nil {
return nil, err
}
}
return tlsConfig, nil
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
return nil, nil
}
tlsConfig := &tls.Config{}
if len(c.TLSAllowedCACerts) != 0 {
pool, err := makeCertPool(c.TLSAllowedCACerts)
if err != nil {
return nil, err
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey)
if err != nil {
return nil, err
}
}
return tlsConfig, nil
}
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
for _, certFile := range certFiles {
pem, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, fmt.Errorf(
"could not read certificate %q: %v", certFile, err)
}
ok := pool.AppendCertsFromPEM(pem)
if !ok {
return nil, fmt.Errorf(
"could not parse any PEM certificates %q: %v", certFile, err)
}
}
return pool, nil
}
func loadCertificate(config *tls.Config, certFile, keyFile string) error {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf(
"could not load keypair %s:%s: %v", certFile, keyFile, err)
}
config.Certificates = []tls.Certificate{cert}
config.BuildNameToCertificate()
return nil
}

226
internal/tls/config_test.go Normal file
View File

@@ -0,0 +1,226 @@
package tls_test
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/influxdata/telegraf/internal/tls"
"github.com/influxdata/telegraf/testutil"
"github.com/stretchr/testify/require"
)
var pki = testutil.NewPKI("../../testutil/pki")
func TestClientConfig(t *testing.T) {
tests := []struct {
name string
client tls.ClientConfig
expNil bool
expErr bool
}{
{
name: "unset",
client: tls.ClientConfig{},
expNil: true,
},
{
name: "success",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
},
{
name: "invalid ca",
client: tls.ClientConfig{
TLSCA: pki.ClientKeyPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "missing ca is okay",
client: tls.ClientConfig{
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
},
{
name: "invalid cert",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientKeyPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "missing cert skips client keypair",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: false,
expErr: false,
},
{
name: "missing key skips client keypair",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
},
expNil: false,
expErr: false,
},
{
name: "support deprecated ssl field names",
client: tls.ClientConfig{
SSLCA: pki.CACertPath(),
SSLCert: pki.ClientCertPath(),
SSLKey: pki.ClientKeyPath(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsConfig, err := tt.client.TLSConfig()
if !tt.expNil {
require.NotNil(t, tlsConfig)
} else {
require.Nil(t, tlsConfig)
}
if !tt.expErr {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
func TestServerConfig(t *testing.T) {
tests := []struct {
name string
server tls.ServerConfig
expNil bool
expErr bool
}{
{
name: "unset",
server: tls.ServerConfig{},
expNil: true,
},
{
name: "success",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
},
{
name: "invalid ca",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.ServerKeyPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing allowed ca is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "invalid cert",
server: tls.ServerConfig{
TLSCert: pki.ServerKeyPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing cert",
server: tls.ServerConfig{
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing key",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsConfig, err := tt.server.TLSConfig()
if !tt.expNil {
require.NotNil(t, tlsConfig)
}
if !tt.expErr {
require.NoError(t, err)
}
})
}
}
func TestConnect(t *testing.T) {
clientConfig := tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
}
serverConfig := tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
}
serverTLSConfig, err := serverConfig.TLSConfig()
require.NoError(t, err)
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
ts.TLS = serverTLSConfig
ts.StartTLS()
defer ts.Close()
clientTLSConfig, err := clientConfig.TLSConfig()
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTLSConfig,
},
Timeout: 10 * time.Second,
}
resp, err := client.Get(ts.URL)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}