diff --git a/rwmutex/rwmutex_test.go b/rwmutex/rwmutex_test.go index 6917345..df9761a 100644 --- a/rwmutex/rwmutex_test.go +++ b/rwmutex/rwmutex_test.go @@ -2,9 +2,12 @@ package rwmutex import ( "fmt" + "math/rand" "runtime" + "sync" "sync/atomic" "testing" + "time" ) func parallelReader(m *RWMutex, clocked, cunlock, cdone chan bool) { @@ -114,3 +117,163 @@ func TestRWMutex(t *testing.T) { HammerRWMutex(10, 10, n) HammerRWMutex(10, 5, n) } + +type CriticalSection struct { + mu sync.Mutex + readersCount, writersCount int +} + +func (cs *CriticalSection) AddToVariable(value *int, count int) { + cs.mu.Lock() + *value += count + cs.mu.Unlock() +} + +func (cs *CriticalSection) Reader(t *testing.T, duration time.Duration) { + cs.AddToVariable(&cs.readersCount, 1) + cs.Check(t) + time.Sleep(duration) // do some work + cs.AddToVariable(&cs.readersCount, -1) +} + +func (cs *CriticalSection) Writer(t *testing.T, duration time.Duration) { + cs.AddToVariable(&cs.writersCount, 1) + cs.Check(t) + time.Sleep(duration) // do some work + cs.AddToVariable(&cs.writersCount, -1) +} + +func (cs *CriticalSection) Check(t *testing.T) { + cs.mu.Lock() + defer cs.mu.Unlock() + if cs.writersCount > 1 { + t.Errorf("To much writers: %d", cs.writersCount) + } + if cs.writersCount == 1 && cs.readersCount > 0 { + t.Errorf("We have %d readers and %d writers", cs.readersCount, cs.writersCount) + } +} + +func TestAFewReaders(t *testing.T) { + var wg sync.WaitGroup + readersCount := 100 + rwm := New() + cs := new(CriticalSection) + ch := make(chan struct{}) + wg.Add(readersCount) + for i := 0; i < readersCount; i++ { + go func() { + rwm.RLock() + cs.Reader(t, 20*time.Millisecond) + rwm.RUnlock() + defer wg.Done() + }() + } + go func() { + wg.Wait() + ch <- struct{}{} + }() + select { + case <-ch: //ok + case <-time.After(25 * time.Millisecond): + t.Error("too slow, your readers are blocked") + } +} + +func TestAFewWriters(t *testing.T) { + var wg sync.WaitGroup + writersCount := 10 + rwm := New() + cs := new(CriticalSection) + wg.Add(writersCount) + for i := 0; i < writersCount; i++ { + go func() { + rwm.Lock() + cs.Writer(t, 10*time.Millisecond) + rwm.Unlock() + defer wg.Done() + }() + } + wg.Wait() +} + +func TestWriterAfterReaders(t *testing.T) { + var wg sync.WaitGroup + rwm := New() + cs := new(CriticalSection) + readersCount := 10 + wg.Add(readersCount + 1) + for i := 0; i < readersCount; i++ { + go func() { + rwm.RLock() + cs.Reader(t, 100*time.Millisecond) + rwm.RUnlock() + defer wg.Done() + }() + } + + time.Sleep(10 * time.Millisecond) + + go func() { + rwm.Lock() + cs.Writer(t, 10*time.Millisecond) + rwm.Unlock() + defer wg.Done() + }() + wg.Wait() +} + +func TestReadersAfterWriters(t *testing.T) { + var wg sync.WaitGroup + rwm := New() + cs := new(CriticalSection) + RWCount := 10 + wg.Add(2 * RWCount) + for i := 0; i < RWCount; i++ { + go func() { + rwm.Lock() + cs.Writer(t, 100*time.Millisecond) + rwm.Unlock() + defer wg.Done() + }() + } + + time.Sleep(20 * time.Millisecond) + + for i := 0; i < RWCount; i++ { + go func() { + rwm.RLock() + cs.Reader(t, 10*time.Millisecond) + rwm.RUnlock() + defer wg.Done() + }() + } + wg.Wait() +} + +func TestRWStress(t *testing.T) { + var wg sync.WaitGroup + rwm := New() + cs := new(CriticalSection) + RWCount := 20 + for j := 0; j < 100; j++ { + wg.Add(2 * RWCount) + for i := 0; i < RWCount; i++ { + go func() { + time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond) // some delay + rwm.Lock() + cs.Writer(t, time.Millisecond) + rwm.Unlock() + defer wg.Done() + }() + go func() { + time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond) // some delay + rwm.RLock() + cs.Reader(t, time.Millisecond) + rwm.RUnlock() + defer wg.Done() + }() + } + wg.Wait() + } +}