shad-go/wscat/main_test.go
2020-04-04 05:50:43 +03:00

140 lines
2.6 KiB
Go

package main
import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"os/exec"
	"strings"
	"sync"
	"syscall"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"gitlab.com/slon/shad-go/tools/testtool"
)
const (
	// importPath is the package path of the wscat program under test; it is
	// passed to binCache.GetBinary to obtain the executable.
	importPath = "gitlab.com/slon/shad-go/wscat"
)

var (
	// binCache provides the path of the wscat executable to startCommand.
	// It is initialised and torn down by TestMain.
	binCache testtool.BinCache
)
// TestMain prepares the shared binCache (used by startCommand to locate the
// wscat executable), runs the package's tests, and tears the cache down again.
//
// The work lives in a closure because os.Exit does not run deferred calls:
// deferring the teardown inside run guarantees it happens before the exit.
func TestMain(m *testing.M) {
	run := func() int {
		cache, cleanup := testtool.NewBinCache()
		defer cleanup()
		binCache = cache
		return m.Run()
	}
	os.Exit(run())
}
// Conn is the test's handle on a running wscat process started by
// startCommand.
type Conn struct {
	// in is the write end of the process's stdin pipe.
	in io.WriteCloser
	// out accumulates the process's stdout. The copy into it happens on a
	// goroutine owned by os/exec, so read it only after the process has been
	// waited for.
	out *bytes.Buffer
}
// startCommand launches the wscat binary with `-addr addr` and returns a handle
// on its stdin/stdout together with a stop function.
//
// stop sends SIGTERM and gives the process 100ms to exit on its own, failing
// the test if it does not. In every case it then kills the process and waits
// until it has been reaped, so calling stop more than once is harmless. The
// process is expected to exit cleanly: a non-nil error from Wait is reported
// as a test failure.
func startCommand(t *testing.T, addr string) (conn *Conn, stop func()) {
	t.Helper()

	binary, err := binCache.GetBinary(importPath)
	require.NoError(t, err)

	stdout := &bytes.Buffer{}
	cmd := exec.Command(binary, "-addr", addr)
	cmd.Stdout = stdout
	cmd.Stderr = os.Stderr

	stdin, err := cmd.StdinPipe()
	require.NoError(t, err)
	require.NoError(t, cmd.Start())

	// exited is closed once Wait has returned: the process has been reaped and
	// everything it wrote to stdout has been copied into the buffer.
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		assert.NoError(t, cmd.Wait())
	}()

	halt := func() {
		// Runs last, even when t.Fatalf below fires: make sure the process is
		// gone and that the Wait goroutine has finished with it.
		defer func() {
			_ = cmd.Process.Kill()
			<-exited
		}()

		// Ask the client to shut down gracefully and give it 100ms to comply.
		_ = cmd.Process.Signal(syscall.SIGTERM)
		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer cancel()

		select {
		case <-exited:
			// graceful shutdown
		case <-ctx.Done():
			t.Fatalf("client shutdown timed out")
		}
	}

	return &Conn{in: stdin, out: stdout}, halt
}
// TestWScat runs wscat against a local websocket server and checks both
// directions of the bridge:
//   - every line written to wscat's stdin arrives at the server, with the
//     trailing newline stripped, and
//   - every message the server sends back is printed to wscat's stdout.
//
// The server answers each incoming message with a fresh random name.
func TestWScat(t *testing.T) {
	const messages = 100

	upgrader := websocket.Upgrader{}

	// The handler runs on a server goroutine while the assertions run on the
	// test goroutine, so the recorded traffic is guarded by mu.
	var (
		mu             sync.Mutex
		received, sent [][]byte
	)

	// handlerDone is closed when a websocket handler returns. The test waits
	// for it so that no appends or t.Logf calls happen after the test has
	// finished; logging after completion panics. httptest.Server.Close does
	// not wait for hijacked (upgraded) connections. The test expects wscat to
	// open a single connection; the Once only guards against a double close.
	handlerDone := make(chan struct{})
	var handlerDoneOnce sync.Once

	h := func(w http.ResponseWriter, r *http.Request) {
		defer handlerDoneOnce.Do(func() { close(handlerDone) })

		c, err := upgrader.Upgrade(w, r, nil)
		// This is not the test goroutine, so require's FailNow must not be
		// used here: report the failure and return instead.
		if !assert.NoError(t, err) {
			return
		}
		defer func() { _ = c.Close() }()

		for {
			_, message, err := c.ReadMessage()
			if err != nil {
				t.Logf("error reading message: %s", err)
				return
			}
			mu.Lock()
			received = append(received, message)
			mu.Unlock()

			resp := []byte(testtool.RandomName())
			if !assert.NoError(t, c.WriteMessage(websocket.TextMessage, resp)) {
				return
			}
			mu.Lock()
			sent = append(sent, resp)
			mu.Unlock()
		}
	}

	s := httptest.NewServer(http.HandlerFunc(h))
	defer s.Close()

	wsURL := strings.Replace(s.URL, "http", "ws", 1)
	t.Logf("starting ws server %s", wsURL)

	conn, stop := startCommand(t, wsURL)
	defer stop()

	in := make([][]byte, 0, messages)
	for i := 0; i < messages; i++ {
		msg := []byte(testtool.RandomName())
		in = append(in, msg)
		_, err := conn.in.Write(append(msg, '\n'))
		require.NoError(t, err)
	}

	// give client time to make a request
	time.Sleep(time.Millisecond * 100)
	stop()

	// Once the client process is gone, the handler's ReadMessage fails and the
	// handler returns. Wait for that before inspecting what it recorded.
	select {
	case <-handlerDone:
	case <-time.After(time.Second):
		t.Fatalf("websocket handler did not exit within 1s of client shutdown (did wscat connect?)")
	}

	mu.Lock()
	defer mu.Unlock()
	require.Equal(t, bytes.Join(in, nil), bytes.Join(received, nil))
	// Safe to read: stop waited for the process, so all stdout has been copied.
	require.Equal(t, bytes.Join(sent, nil), conn.out.Bytes())
}