diff --git a/ratelimit/rate.go b/ratelimit/rate.go new file mode 100644 index 0000000..6aebdce --- /dev/null +++ b/ratelimit/rate.go @@ -0,0 +1,36 @@ +package ratelimit + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +// ParseRate parses rate in form of "N/D", e.g "10/1s" or "100/1ms" +func ParseRate(rate string) (count int, interval time.Duration, err error) { + parts := strings.SplitN(rate, "/", 2) + if len(parts) != 2 { + err = fmt.Errorf("invalid rate format in %q: missing slash", rate) + return + } + + count, err = strconv.Atoi(parts[0]) + if err != nil { + err = fmt.Errorf("invalid rate format in %q: %v", rate, err) + return + } + + interval, err = time.ParseDuration(parts[1]) + if err != nil { + err = fmt.Errorf("invalid rate format in %q: %v", rate, err) + return + } + + if interval < 0 { + err = fmt.Errorf("invalid rate format in %q: negative interval", rate) + return + } + + return +} diff --git a/ratelimit/rate_test.go b/ratelimit/rate_test.go new file mode 100644 index 0000000..297619a --- /dev/null +++ b/ratelimit/rate_test.go @@ -0,0 +1,32 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestParseRate(t *testing.T) { + count, dt, err := ParseRate("10/1s") + require.NoError(t, err) + require.Equal(t, 10, count) + require.Equal(t, time.Second, dt) + + count, dt, err = ParseRate("1/1ms") + require.NoError(t, err) + require.Equal(t, 1, count) + require.Equal(t, time.Millisecond, dt) +} + +func TestInvalidRate(t *testing.T) { + for _, invalid := range []string{ + "", + "1/2", + "/10m", + "abc", + } { + _, _, err := ParseRate(invalid) + require.Errorf(t, err, "rate %q is invalid", invalid) + } +} diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..9641aa2 --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,22 @@ +// +build !solution + +package ratelimit + +import ( + "context" + "time" +) + +// Limiter is precise rate limiter with context support. +type Limiter struct { +} + +// NewLimiter returns limiter that throttles rate of successful Acquire() calls +// to maxSize events at any given interval. +func NewLimiter(maxCount int, interval time.Duration) *Limiter { + panic("not implemented") +} + +func (l *Limiter) Acquire(ctx context.Context) error { + panic("not implemented") +} diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..0095944 --- /dev/null +++ b/ratelimit/ratelimit_test.go @@ -0,0 +1,120 @@ +package ratelimit + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestNoRateLimit(t *testing.T) { + limit := NewLimiter(1, 0) + + ctx := context.Background() + + require.NoError(t, limit.Acquire(ctx)) + require.NoError(t, limit.Acquire(ctx)) +} + +func TestBlockedRateLimit(t *testing.T) { + limit := NewLimiter(0, time.Minute) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + err := limit.Acquire(ctx) + require.Equal(t, context.DeadlineExceeded, err) +} + +func TestSimpleLimitCancel(t *testing.T) { + limit := NewLimiter(1, time.Minute) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + require.NoError(t, limit.Acquire(ctx)) + + err := limit.Acquire(ctx) + require.Equal(t, context.DeadlineExceeded, err) +} + +func TestStressBlocking(t *testing.T) { + const ( + N = 100 + G = 100 + ) + + limit := NewLimiter(N, time.Millisecond*10) + + var eg errgroup.Group + for i := 0; i < G; i++ { + eg.Go(func() error { + for j := 0; j < N; j++ { + if err := limit.Acquire(context.Background()); err != nil { + return err + } + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) +} + +func TestStressNoBlocking(t *testing.T) { + const ( + N = 100 + G = 100 + ) + + limit := NewLimiter(N, time.Millisecond*10) + + var eg errgroup.Group + for i := 0; i < G; i++ { + eg.Go(func() error { + for j := 0; j < N; j++ { + if err := limit.Acquire(context.Background()); err != nil { + return err + } + + time.Sleep(time.Millisecond * 11) + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) +} + +func BenchmarkNoBlocking(b *testing.B) { + b.ReportAllocs() + b.SetBytes(1) + + limit := NewLimiter(1, 0) + + ctx := context.Background() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := limit.Acquire(ctx); err != nil { + b.Errorf("acquire failed: %v", err) + } + } + }) +} + +func BenchmarkReferenceMutex(b *testing.B) { + var mu sync.Mutex + + var j int + for i := 0; i < b.N; i++ { + mu.Lock() + j++ + mu.Unlock() + } +}