Fix DC/OS URL creation race (#3932)

This commit is contained in:
Daniel Nelson 2018-03-23 19:14:07 -07:00 committed by GitHub
parent 519e0274a0
commit 338def3524
2 changed files with 49 additions and 31 deletions

View File

@ -31,6 +31,7 @@ type Client interface {
} }
type APIError struct { type APIError struct {
URL string
StatusCode int StatusCode int
Title string Title string
Description string Description string
@ -105,9 +106,9 @@ type claims struct {
func (e APIError) Error() string { func (e APIError) Error() string {
if e.Description != "" { if e.Description != "" {
return fmt.Sprintf("%s: %s", e.Title, e.Description) return fmt.Sprintf("[%s] %s: %s", e.URL, e.Title, e.Description)
} }
return e.Title return fmt.Sprintf("[%s] %s", e.URL, e.Title)
} }
func NewClusterClient( func NewClusterClient(
@ -156,7 +157,8 @@ func (c *ClusterClient) Login(ctx context.Context, sa *ServiceAccount) (*AuthTok
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", c.url("/acs/api/v1/auth/login"), bytes.NewBuffer(octets)) loc := c.url("/acs/api/v1/auth/login")
req, err := http.NewRequest("POST", loc, bytes.NewBuffer(octets))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,6 +191,7 @@ func (c *ClusterClient) Login(ctx context.Context, sa *ServiceAccount) (*AuthTok
err = dec.Decode(loginError) err = dec.Decode(loginError)
if err != nil { if err != nil {
err := &APIError{ err := &APIError{
URL: loc,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Title: resp.Status, Title: resp.Status,
} }
@ -196,6 +199,7 @@ func (c *ClusterClient) Login(ctx context.Context, sa *ServiceAccount) (*AuthTok
} }
err = &APIError{ err = &APIError{
URL: loc,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Title: loginError.Title, Title: loginError.Title,
Description: loginError.Description, Description: loginError.Description,
@ -301,6 +305,7 @@ func (c *ClusterClient) doGet(ctx context.Context, url string, v interface{}) er
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return &APIError{ return &APIError{
URL: url,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Title: resp.Status, Title: resp.Status,
} }
@ -315,7 +320,7 @@ func (c *ClusterClient) doGet(ctx context.Context, url string, v interface{}) er
} }
func (c *ClusterClient) url(path string) string { func (c *ClusterClient) url(path string) string {
url := c.clusterURL url := *c.clusterURL
url.Path = path url.Path = path
return url.String() return url.String()
} }

View File

@ -31,6 +31,9 @@ P0a+YZUeHNRqT2pPN9lMTAZGGi3CtcF2XScbLNEBeXge
) )
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
ts := httptest.NewServer(http.NotFoundHandler())
defer ts.Close()
var tests = []struct { var tests = []struct {
name string name string
responseCode int responseCode int
@ -40,16 +43,21 @@ func TestLogin(t *testing.T) {
}{ }{
{ {
name: "Login successful", name: "Login successful",
responseCode: 200, responseCode: http.StatusOK,
responseBody: `{"token": "XXX.YYY.ZZZ"}`, responseBody: `{"token": "XXX.YYY.ZZZ"}`,
expectedError: nil, expectedError: nil,
expectedToken: "XXX.YYY.ZZZ", expectedToken: "XXX.YYY.ZZZ",
}, },
{ {
name: "Unauthorized Error", name: "Unauthorized Error",
responseCode: http.StatusUnauthorized, responseCode: http.StatusUnauthorized,
responseBody: `{"title": "x", "description": "y"}`, responseBody: `{"title": "x", "description": "y"}`,
expectedError: &APIError{http.StatusUnauthorized, "x", "y"}, expectedError: &APIError{
URL: ts.URL + "/acs/api/v1/auth/login",
StatusCode: http.StatusUnauthorized,
Title: "x",
Description: "y",
},
expectedToken: "", expectedToken: "",
}, },
} }
@ -59,11 +67,11 @@ func TestLogin(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.responseCode) w.WriteHeader(tt.responseCode)
fmt.Fprintln(w, tt.responseBody) fmt.Fprintln(w, tt.responseBody)
}) })
ts := httptest.NewServer(handler)
u, err := url.Parse(ts.URL) u, err := url.Parse(ts.URL)
require.NoError(t, err) require.NoError(t, err)
@ -82,13 +90,14 @@ func TestLogin(t *testing.T) {
} else { } else {
require.Nil(t, auth) require.Nil(t, auth)
} }
ts.Close()
}) })
} }
} }
func TestGetSummary(t *testing.T) { func TestGetSummary(t *testing.T) {
ts := httptest.NewServer(http.NotFoundHandler())
defer ts.Close()
var tests = []struct { var tests = []struct {
name string name string
responseCode int responseCode int
@ -98,7 +107,7 @@ func TestGetSummary(t *testing.T) {
}{ }{
{ {
name: "No nodes", name: "No nodes",
responseCode: 200, responseCode: http.StatusOK,
responseBody: `{"cluster": "a", "slaves": []}`, responseBody: `{"cluster": "a", "slaves": []}`,
expectedValue: &Summary{Cluster: "a", Slaves: []Slave{}}, expectedValue: &Summary{Cluster: "a", Slaves: []Slave{}},
expectedError: nil, expectedError: nil,
@ -108,11 +117,15 @@ func TestGetSummary(t *testing.T) {
responseCode: http.StatusUnauthorized, responseCode: http.StatusUnauthorized,
responseBody: `<html></html>`, responseBody: `<html></html>`,
expectedValue: nil, expectedValue: nil,
expectedError: &APIError{StatusCode: http.StatusUnauthorized, Title: "401 Unauthorized"}, expectedError: &APIError{
URL: ts.URL + "/mesos/master/state-summary",
StatusCode: http.StatusUnauthorized,
Title: "401 Unauthorized",
},
}, },
{ {
name: "Has nodes", name: "Has nodes",
responseCode: 200, responseCode: http.StatusOK,
responseBody: `{"cluster": "a", "slaves": [{"id": "a"}, {"id": "b"}]}`, responseBody: `{"cluster": "a", "slaves": [{"id": "a"}, {"id": "b"}]}`,
expectedValue: &Summary{ expectedValue: &Summary{
Cluster: "a", Cluster: "a",
@ -127,12 +140,12 @@ func TestGetSummary(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check the path // check the path
w.WriteHeader(tt.responseCode) w.WriteHeader(tt.responseCode)
fmt.Fprintln(w, tt.responseBody) fmt.Fprintln(w, tt.responseBody)
}) })
ts := httptest.NewServer(handler)
u, err := url.Parse(ts.URL) u, err := url.Parse(ts.URL)
require.NoError(t, err) require.NoError(t, err)
@ -142,14 +155,15 @@ func TestGetSummary(t *testing.T) {
require.Equal(t, tt.expectedError, err) require.Equal(t, tt.expectedError, err)
require.Equal(t, tt.expectedValue, summary) require.Equal(t, tt.expectedValue, summary)
ts.Close()
}) })
} }
} }
func TestGetNodeMetrics(t *testing.T) { func TestGetNodeMetrics(t *testing.T) {
ts := httptest.NewServer(http.NotFoundHandler())
defer ts.Close()
var tests = []struct { var tests = []struct {
name string name string
responseCode int responseCode int
@ -159,7 +173,7 @@ func TestGetNodeMetrics(t *testing.T) {
}{ }{
{ {
name: "Empty Body", name: "Empty Body",
responseCode: 200, responseCode: http.StatusOK,
responseBody: `{}`, responseBody: `{}`,
expectedValue: &Metrics{}, expectedValue: &Metrics{},
expectedError: nil, expectedError: nil,
@ -168,12 +182,12 @@ func TestGetNodeMetrics(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check the path // check the path
w.WriteHeader(tt.responseCode) w.WriteHeader(tt.responseCode)
fmt.Fprintln(w, tt.responseBody) fmt.Fprintln(w, tt.responseBody)
}) })
ts := httptest.NewServer(handler)
u, err := url.Parse(ts.URL) u, err := url.Parse(ts.URL)
require.NoError(t, err) require.NoError(t, err)
@ -183,14 +197,15 @@ func TestGetNodeMetrics(t *testing.T) {
require.Equal(t, tt.expectedError, err) require.Equal(t, tt.expectedError, err)
require.Equal(t, tt.expectedValue, m) require.Equal(t, tt.expectedValue, m)
ts.Close()
}) })
} }
} }
func TestGetContainerMetrics(t *testing.T) { func TestGetContainerMetrics(t *testing.T) {
ts := httptest.NewServer(http.NotFoundHandler())
defer ts.Close()
var tests = []struct { var tests = []struct {
name string name string
responseCode int responseCode int
@ -199,8 +214,8 @@ func TestGetContainerMetrics(t *testing.T) {
expectedError error expectedError error
}{ }{
{ {
name: "204 No Contents", name: "204 No Content",
responseCode: 204, responseCode: http.StatusNoContent,
responseBody: ``, responseBody: ``,
expectedValue: &Metrics{}, expectedValue: &Metrics{},
expectedError: nil, expectedError: nil,
@ -209,12 +224,12 @@ func TestGetContainerMetrics(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check the path // check the path
w.WriteHeader(tt.responseCode) w.WriteHeader(tt.responseCode)
fmt.Fprintln(w, tt.responseBody) fmt.Fprintln(w, tt.responseBody)
}) })
ts := httptest.NewServer(handler)
u, err := url.Parse(ts.URL) u, err := url.Parse(ts.URL)
require.NoError(t, err) require.NoError(t, err)
@ -224,8 +239,6 @@ func TestGetContainerMetrics(t *testing.T) {
require.Equal(t, tt.expectedError, err) require.Equal(t, tt.expectedError, err)
require.Equal(t, tt.expectedValue, m) require.Equal(t, tt.expectedValue, m)
ts.Close()
}) })
} }