59
internal/limiter/limiter.go
Normal file
59
internal/limiter/limiter.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewRateLimiter returns a rate limiter that will will emit from the C
|
||||
// channel only 'n' times every 'rate' seconds.
|
||||
func NewRateLimiter(n int, rate time.Duration) *rateLimiter {
|
||||
r := &rateLimiter{
|
||||
C: make(chan bool),
|
||||
rate: rate,
|
||||
n: n,
|
||||
shutdown: make(chan bool),
|
||||
}
|
||||
r.wg.Add(1)
|
||||
go r.limiter()
|
||||
return r
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
C chan bool
|
||||
rate time.Duration
|
||||
n int
|
||||
|
||||
shutdown chan bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (r *rateLimiter) Stop() {
|
||||
close(r.shutdown)
|
||||
r.wg.Wait()
|
||||
close(r.C)
|
||||
}
|
||||
|
||||
func (r *rateLimiter) limiter() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.rate)
|
||||
defer ticker.Stop()
|
||||
counter := 0
|
||||
for {
|
||||
select {
|
||||
case <-r.shutdown:
|
||||
return
|
||||
case <-ticker.C:
|
||||
counter = 0
|
||||
default:
|
||||
if counter < r.n {
|
||||
select {
|
||||
case r.C <- true:
|
||||
counter++
|
||||
case <-r.shutdown:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
54
internal/limiter/limiter_test.go
Normal file
54
internal/limiter/limiter_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
r := NewRateLimiter(5, time.Second)
|
||||
ticker := time.NewTicker(time.Millisecond * 75)
|
||||
|
||||
// test that we can only get 5 receives from the rate limiter
|
||||
counter := 0
|
||||
outer:
|
||||
for {
|
||||
select {
|
||||
case <-r.C:
|
||||
counter++
|
||||
case <-ticker.C:
|
||||
break outer
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, counter)
|
||||
r.Stop()
|
||||
// verify that the Stop function closes the channel.
|
||||
_, ok := <-r.C
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestRateLimiterMultipleIterations(t *testing.T) {
|
||||
r := NewRateLimiter(5, time.Millisecond*50)
|
||||
ticker := time.NewTicker(time.Millisecond * 250)
|
||||
|
||||
// test that we can get 15 receives from the rate limiter
|
||||
counter := 0
|
||||
outer:
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
break outer
|
||||
case <-r.C:
|
||||
counter++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, counter > 10)
|
||||
r.Stop()
|
||||
// verify that the Stop function closes the channel.
|
||||
_, ok := <-r.C
|
||||
assert.False(t, ok)
|
||||
}
|
||||
Reference in New Issue
Block a user