shad-go/consistenthash/hash_test.go

142 lines
2.5 KiB
Go
Raw Normal View History

2023-04-18 10:48:06 +00:00
package consistenthash
import (
"fmt"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)
// node is the node type used throughout these tests: a plain string whose
// value doubles as its identifier. The tests pass *node values to AddNode.
type node string

// ID reports the node's identifier, i.e. its underlying string value.
func (nd node) ID() string {
	id := string(nd)
	return id
}
// TestHash_SingleNode verifies that when exactly one node is registered,
// a lookup returns that very node (the same pointer that was added).
func TestHash_SingleNode(t *testing.T) {
	only := node("1")
	ring := New[node]()
	ring.AddNode(&only)

	got := ring.GetNode("key0")
	require.Equal(t, &only, got)
}
// TestHash_TwoNodes registers two nodes and checks that:
//   - a key is owned by one of the registered nodes,
//   - repeated lookups of the same key return the same node, and
//   - not all of key0..key31 map to the same node.
func TestHash_TwoNodes(t *testing.T) {
	ring := New[node]()
	first, second := node("1"), node("2")
	ring.AddNode(&first)
	ring.AddNode(&second)

	owner := ring.GetNode("key0")
	require.True(t, owner == &first || owner == &second)

	// Lookups must be deterministic for a fixed key.
	for attempt := 0; attempt < 32; attempt++ {
		require.Equal(t, owner, ring.GetNode("key0"))
	}

	// With two nodes, at least one of these keys should land elsewhere.
	sawOther := false
	for k := 0; k < 32; k++ {
		key := fmt.Sprintf("key%d", k)
		if ring.GetNode(key) != owner {
			sawOther = true
		}
	}
	require.True(t, sawOther)
}
// TestHash_EvenDistribution checks that N keys spread over K nodes roughly
// like independent uniform draws. The observed standard deviation of the
// per-node key counts is compared against the binomial ideal
// sqrt(N·P·(1-P)) with P = 1/K. The relative difference must stay below 4,
// i.e. the real stddev may be at most 5x the ideal one.
func TestHash_EvenDistribution(t *testing.T) {
	h := New[node]()

	const K = 32
	// Keep every node pointer so the dispersion below accounts for all K
	// nodes, including any node that receives no keys at all.
	nodes := make([]*node, 0, K)
	for i := 0; i < K; i++ {
		n := node(fmt.Sprint(i))
		h.AddNode(&n)
		nodes = append(nodes, &n)
	}

	const N = 1 << 16
	counts := make(map[*node]float64, K)
	for i := 0; i < N; i++ {
		counts[h.GetNode(fmt.Sprintf("key%d", i))]++
	}

	// Ideal model: each node's count is Binomial(N, P).
	const P = 1 / float64(K)
	const variance = N * P * (1 - P)
	idealStddev := math.Sqrt(variance)
	t.Logf("P = %v, var = %v, stddev = %v", P, variance, idealStddev)
	t.Logf("counts = %v", maps.Values(counts))

	mean := float64(N) / K
	var dispersion float64
	for _, n := range nodes {
		// A node missing from counts got zero keys. Indexing yields 0, so it
		// still contributes (0-mean)^2 instead of being silently skipped.
		d := counts[n] - mean
		dispersion += d * d
	}
	realStddev := math.Sqrt(dispersion / K)
	t.Logf("real stddev = %v", realStddev)
	require.Less(t, math.Abs(realStddev-idealStddev)/idealStddev, float64(4))
}
// TestHash_ConsistentDistribution checks the core consistent-hashing
// property. After adding one node to K existing nodes:
//   - fewer than 2·N/K of the N keys change owner,
//   - every key that changes owner moves to the new node, and
//   - the new node actually takes over at least one key.
func TestHash_ConsistentDistribution(t *testing.T) {
	h := New[node]()
	const K = 32
	for i := 0; i < K; i++ {
		n := node(fmt.Sprint(i))
		h.AddNode(&n)
	}

	// Snapshot each key's owner before the node set changes.
	const N = 1 << 16
	nodes := make(map[string]*node, N)
	for i := 0; i < N; i++ {
		key := fmt.Sprintf("key%d", i)
		nodes[key] = h.GetNode(key)
	}

	newNode := node("new_node")
	h.AddNode(&newNode)

	changed := 0
	movedToNewNode := 0
	for key, oldNode := range nodes {
		n := h.GetNode(key)
		if n != oldNode {
			changed++
		}
		if n == &newNode {
			movedToNewNode++
		}
	}
	t.Logf("changed = %d, movedToNewNode = %d", changed, movedToNewNode)

	assert.Less(t, changed, N/K*2)
	// Without this check, an implementation that ignores nodes added later
	// (changed == movedToNewNode == 0) would pass the test.
	assert.Greater(t, movedToNewNode, 0)
	// Keys may only move to the new node, never between the old ones.
	assert.Equal(t, changed, movedToNewNode)
}
// benchNodeSink receives GetNode results in BenchmarkHashSpeed so the
// compiler cannot treat the lookups as dead code.
var benchNodeSink *node

// BenchmarkHashSpeed measures GetNode on rings of K nodes for several K.
// Lookup keys are formatted once, outside the timed sub-benchmarks, so the
// timed loop measures the lookup itself rather than fmt.Sprintf and its
// per-iteration allocations.
func BenchmarkHashSpeed(b *testing.B) {
	const numKeys = 1 << 16
	keys := make([]string, numKeys)
	for i := range keys {
		keys[i] = fmt.Sprintf("key%d", i)
	}

	for _, K := range []int{32, 1024, 4096} {
		h := New[node]()
		for i := 0; i < K; i++ {
			n := node(fmt.Sprint(i))
			h.AddNode(&n)
		}
		b.Run(fmt.Sprintf("K=%d", K), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				benchNodeSink = h.GetNode(keys[i%numKeys])
			}
		})
	}
}