// Source: shad-go/firewall/cmd/firewall/main_test.go
// (repository web-view listing: 284 lines, 6.3 KiB, Go; last change 2020-04-02 19:25:32 +00:00)

package main
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path"
"testing"
"time"
"github.com/go-resty/resty/v2"
"github.com/stretchr/testify/require"
"gitlab.com/slon/shad-go/tools/testtool"
)
// importPath is the Go import path of the firewall command; startServer passes
// it to binCache.GetBinary to obtain the binary under test.
const importPath = "gitlab.com/slon/shad-go/firewall/cmd/firewall"
// binCache is assigned in TestMain (from testtool.NewBinCache) and read by
// startServer.
var binCache testtool.BinCache
// TestMain sets up the package-level binary cache, runs all tests and exits
// with the code reported by m.Run.
func TestMain(m *testing.M) {
	// os.Exit terminates the process without running deferred calls, so the
	// cache teardown is deferred inside a separate function that returns
	// before os.Exit is reached.
	run := func() int {
		cache, closeCache := testtool.NewBinCache()
		binCache = cache
		defer closeCache()

		return m.Run()
	}

	os.Exit(run())
}
// storeConfig writes conf to a new ".yaml" file in the OS temporary directory;
// the base name comes from testtool.RandomName.
//
// It returns the path of the file and a cleanup function that removes it. The
// removal is best effort: its error is ignored. If the file cannot be written,
// the test is failed immediately via require.NoError.
func storeConfig(t *testing.T, conf string) (filename string, cleanup func()) {
	t.Helper()

	filename = path.Join(os.TempDir(), testtool.RandomName()+".yaml")
	cleanup = func() { _ = os.Remove(filename) }

	// The config is plain data, so it needs neither the execute bit nor
	// group/world access.
	if err := ioutil.WriteFile(filename, []byte(conf), 0600); err != nil {
		// WriteFile may have created the file before failing; the caller never
		// receives cleanup on this path, so remove it here.
		cleanup()
		require.NoError(t, err)
	}

	return filename, cleanup
}
// startServer runs the firewall binary as a child process.
//
// The process is started with -service-addr set to serviceURL, -addr set to
// "localhost:<port>" for a port obtained from testtool.GetFreePort, and -conf
// pointing at a temporary file holding conf. The child's stderr is forwarded
// to the test's stderr; its stdout is discarded.
//
// It returns the chosen port and a stop function that kills the process and
// blocks until it has been waited for. stop may be called more than once.
//
// The test is failed immediately if the binary cannot be obtained, no free
// port is available, the process cannot be started, or testtool.WaitForPort
// (called with a five-second limit) reports an error; in the last case the
// process is stopped first.
func startServer(t *testing.T, serviceURL string, conf string) (port string, stop func()) {
	t.Helper()

	binary, err := binCache.GetBinary(importPath)
	require.NoError(t, err)

	confPath, removeConf := storeConfig(t, conf)
	// The config file is removed when startServer returns, i.e. only after
	// WaitForPort has finished (or an earlier step has failed the test).
	defer removeConf()

	port, err = testtool.GetFreePort()
	require.NoError(t, err, "unable to get free port")

	addr := fmt.Sprintf("localhost:%s", port)
	cmd := exec.Command(binary, "-service-addr", serviceURL, "-addr", addr, "-conf", confPath)
	cmd.Stdout = nil
	cmd.Stderr = os.Stderr
	require.NoError(t, cmd.Start())

	// done is closed (rather than sent on) once the process has been reaped.
	// This way the waiter goroutine always terminates, even if stop is never
	// called, and stop never blocks when invoked repeatedly.
	done := make(chan struct{})
	go func() {
		defer close(done)
		_ = cmd.Wait()
	}()

	stop = func() {
		// Kill returns an error if the process has already exited; that is
		// fine here, the goal is only that it is gone.
		_ = cmd.Process.Kill()
		<-done
	}

	if err = testtool.WaitForPort(t, time.Second*5, port); err != nil {
		stop()
	}
	require.NoError(t, err)

	return port, stop
}
// TestFirewall is a table-driven end-to-end test of the firewall binary.
//
// For every case it starts a backend httptest server running the case's
// handler, starts the firewall in front of it with the case's YAML config
// (see startServer), POSTs the case's request to the firewall root URL and
// compares the response status code and body with the expected result.
func TestFirewall(t *testing.T) {
	// echoService replies with the request body and, when the request length
	// is known, the same Content-Length.
	echoService := func(w http.ResponseWriter, r *http.Request) {
		defer func() { _ = r.Body.Close() }()

		// Headers have to be set before the first body write; changes made
		// afterwards are not sent. ContentLength is -1 when the length is
		// unknown, which is not a valid header value, so skip it then.
		if r.ContentLength >= 0 {
			w.Header().Set("Content-Length", fmt.Sprintf("%d", r.ContentLength))
		}

		_, _ = io.Copy(w, r.Body)
	}

	c := resty.New()

	type result struct {
		code int
		body string
	}

	for _, tc := range []struct {
		name        string                // subtest name
		conf        string                // firewall YAML configuration
		service     http.HandlerFunc      // backend handler behind the firewall
		makeRequest func() *resty.Request // builds the request POSTed to the firewall
		expected    result                // expected status code and body
	}{
		{
			name:    "empty",
			conf:    ``,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello")
			},
			expected: result{code: http.StatusOK, body: "hello"},
		},
		{
			name: "simple",
			conf: `
rules:
  - endpoint: "/"
    forbidden_user_agents:
      - 'python-requests.*'
    forbidden_headers:
      - 'Content-Type: text/html'
    required_headers:
      - "Content-Type"
    max_request_length_bytes: 50
    max_response_length_bytes: 50
    forbidden_response_codes: [201]
    forbidden_request_re:
      - '.*(\.\./){3,}.*'
    forbidden_response_re:
      - '.*admin.*'
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().
					SetBody(`{"user_id": 123, "path": "../../user"}`).
					SetHeaders(map[string]string{
						"User-Agent":   "Mozilla/5.0",
						"Content-Type": "application/json",
					})
			},
			expected: result{code: http.StatusOK, body: `{"user_id": 123, "path": "../../user"}`},
		},
		{
			name: "bad-user-agent",
			conf: `
rules:
  - endpoint: "/"
    forbidden_user_agents:
      - 'python-requests.*'
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello").SetHeader("User-Agent", "python-requests/2.22.0")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			name: "forbidden-header",
			conf: `
rules:
  - endpoint: "/"
    forbidden_headers:
      - 'Content-Type: text/html'
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello").SetHeader("Content-Type", "text/html")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			name: "missing-required-header",
			conf: `
rules:
  - endpoint: "/"
    required_headers:
      - "Content-Type"
      - "Content-Length"
      - "NoOneUsesThisHeader"
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello").SetHeader("Content-Length", "5")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			name: "max-request-length-exceeded",
			conf: `
rules:
  - endpoint: "/"
    max_request_length_bytes: 4
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			name: "max-response-length-exceeded",
			conf: `
rules:
  - endpoint: "/"
    max_response_length_bytes: 4
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			name: "bad-status-code",
			conf: `
rules:
  - endpoint: "/"
    forbidden_response_codes: [200]
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody("hello")
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			// The request body itself matches the forbidden pattern, so the
			// backend handler is a no-op.
			name: "bad-request",
			conf: `
rules:
  - endpoint: "/"
    forbidden_request_re:
      - '.*(\.\./){3,}.*'
`,
			service: func(w http.ResponseWriter, r *http.Request) {},
			makeRequest: func() *resty.Request {
				return c.R().SetBody(`{"path": "../../../../etc.passwd"}`)
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
		{
			// The echoed body matches the forbidden response pattern.
			name: "bad-response",
			conf: `
rules:
  - endpoint: "/"
    forbidden_response_re:
      - '.*admin.*'
`,
			service: echoService,
			makeRequest: func() *resty.Request {
				return c.R().SetBody(`{"user": "admin", "password": "1234"}`)
			},
			expected: result{code: http.StatusForbidden, body: "Forbidden"},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			service := httptest.NewServer(tc.service)
			defer service.Close()

			port, cleanup := startServer(t, service.URL, tc.conf)
			defer cleanup()

			u := fmt.Sprintf("http://localhost:%s", port)
			resp, err := tc.makeRequest().Post(u)
			require.NoError(t, err)

			require.Equal(t, tc.expected.code, resp.StatusCode())
			require.Equal(t, tc.expected.body, resp.String())
		})
	}
}