diff --git a/services/enterprise/enterprise.go b/services/enterprise/enterprise.go index a4905ab56..d44073286 100644 --- a/services/enterprise/enterprise.go +++ b/services/enterprise/enterprise.go @@ -3,8 +3,10 @@ package enterprise import ( "fmt" "log" + "net" "net/http" "os" + "time" "github.com/influxdata/enterprise-client/v2" "github.com/influxdata/enterprise-client/v2/admin" @@ -20,14 +22,17 @@ type Service struct { logger *log.Logger hostname string adminPort string + + shutdown chan struct{} } -func NewEnterprise(c Config, hostname string) *Service { +func NewEnterprise(c Config, hostname string, shutdown chan struct{}) *Service { return &Service{ hosts: c.Hosts, hostname: hostname, logger: log.New(os.Stdout, "[enterprise]", log.Ldate|log.Ltime), adminPort: fmt.Sprintf(":%d", c.AdminPort), + shutdown: shutdown, } } @@ -37,22 +42,51 @@ func (s *Service) Open() { s.logger.Printf("Unable to contact one or more Enterprise hosts. err: %s", err.Error()) return } - go s.registerProduct(cl) - go s.startAdminInterface() + go func() { + token, secret, err := s.registerProduct(cl) + if err == nil { + s.startAdminInterface(token, secret) + } + }() } -func (s *Service) registerProduct(cl *client.Client) { +func (s *Service) registerProduct(cl *client.Client) (token string, secret string, err error) { p := client.Product{ ProductID: "telegraf", Host: s.hostname, } - _, err := cl.Register(&p) + _, err = cl.Register(&p) if err != nil { s.logger.Println("Unable to register Telegraf with Enterprise") + return } + + for _, host := range cl.Hosts { + if host.Primary { + token = host.Token + secret = host.SecretKey + } + } + return } -func (s *Service) startAdminInterface() { - go http.ListenAndServe(s.adminPort, admin.App("foo", []byte("bar"))) +func (s *Service) startAdminInterface(token, secret string) { + srv := &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Handler: admin.App(token, []byte(secret)), + } + l, err := net.Listen("tcp", s.adminPort) + if err != nil { + s.logger.Printf("Unable to bind to admin interface port: err: %s", err.Error()) + return + } + go srv.Serve(l) + select { + case <-s.shutdown: + s.logger.Printf("Shutting down enterprise admin interface") + l.Close() + } + return } diff --git a/services/enterprise/enterprise_test.go b/services/enterprise/enterprise_test.go index 5cc1a9a62..92c2cb45e 100644 --- a/services/enterprise/enterprise_test.go +++ b/services/enterprise/enterprise_test.go @@ -44,7 +44,10 @@ func Test_RegistersWithEnterprise(t *testing.T) { &client.Host{URL: srv.URL}, }, } - e := enterprise.NewEnterprise(c, expected) + + shutdown := make(chan struct{}) + defer close(shutdown) + e := enterprise.NewEnterprise(c, expected, shutdown) e.Open() timeout := time.After(1 * time.Millisecond) @@ -75,7 +78,9 @@ func Test_StartsAdminInterface(t *testing.T) { AdminPort: 2300, } - e := enterprise.NewEnterprise(c, hostname) + shutdown := make(chan struct{}) + defer close(shutdown) + e := enterprise.NewEnterprise(c, hostname, shutdown) e.Open() timeout := time.After(1 * time.Millisecond) @@ -92,3 +97,44 @@ func Test_StartsAdminInterface(t *testing.T) { } } } + +func Test_ClosesAdminInterface(t *testing.T) { + hostname := "localhost" + adminPort := 2300 + + success, srv := mockEnterprise(func(c *client.Product, err error) {}) + defer srv.Close() + + c := enterprise.Config{ + Hosts: []*client.Host{ + &client.Host{URL: srv.URL}, + }, + AdminPort: 2300, + } + + shutdown := make(chan struct{}) + e := enterprise.NewEnterprise(c, hostname, shutdown) + e.Open() + + timeout := time.After(1 * time.Millisecond) + for { + select { + case <-success: + // Ensure that the admin interface is running + _, err := http.Get(fmt.Sprintf("http://%s:%d", hostname, adminPort)) + if err != nil { + t.Errorf("Unable to connect to admin interface: err: %s", err) + } + close(shutdown) + + // ...and that it's not running after we shut it down + _, err = http.Get(fmt.Sprintf("http://%s:%d", hostname, adminPort)) + if err == nil { + t.Errorf("Admin interface continued running after shutdown") + } + return + case <-timeout: + t.Fatal("Expected to receive call to Enterprise API, but received none") + } + } +}