blob: 7ff60fda403ff67177524c6c336e1325bbf34821 [file] [log] [blame]
package middleware
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
dcontext "github.com/docker/distribution/context"
"reflect" // used as a replacement for testify
)
// Rather than pull in all of testify
func assertEqual(t *testing.T, x, y interface{}) {
if !reflect.DeepEqual(x, y) {
t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y)
t.FailNow()
}
}
type mockIPRangeHandler struct {
data awsIPResponse
}
func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bytes, err := json.Marshal(m.data)
if err != nil {
w.WriteHeader(500)
return
}
w.Write(bytes)
}
func newTestHandler(data awsIPResponse) *httptest.Server {
return httptest.NewServer(mockIPRangeHandler{
data: data,
})
}
func serverIPRanges(server *httptest.Server) string {
return fmt.Sprintf("%s/", server.URL)
}
func setupTest(data awsIPResponse) *httptest.Server {
// This is a basic schema which only claims the exact ip
// is in aws.
server := newTestHandler(data)
return server
}
func TestS3TryUpdate(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "123.231.123.231/32"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
assertEqual(t, 1, len(ips.ipv4))
assertEqual(t, 0, len(ips.ipv6))
}
func TestMatchIPV6(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
V6Prefixes: []prefixEntry{
{IPV6Prefix: "ff00::/16"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
assertEqual(t, 1, len(ips.ipv6))
assertEqual(t, 0, len(ips.ipv4))
}
func TestMatchIPV4(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "192.168.0.0/24"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4_2(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionMatched(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
ips.tryUpdate()
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{
IPV4Prefix: "192.168.0.0/24",
Region: "us-east-1",
},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
ips.tryUpdate()
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}
func TestInvalidData(t *testing.T) {
t.Parallel()
// Invalid entries from aws should be ignored.
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "9000"},
{IPV4Prefix: "192.168.0.0/24"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
assertEqual(t, 1, len(ips.ipv4))
}
func TestInvalidNetworkType(t *testing.T) {
t.Parallel()
server := setupTest(awsIPResponse{
Prefixes: []prefixEntry{
{IPV4Prefix: "192.168.0.0/24"},
},
V6Prefixes: []prefixEntry{
{IPV6Prefix: "ff00::/8"},
{IPV6Prefix: "fe00::/8"},
},
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
}
func TestParsing(t *testing.T) {
var data = `{
"prefixes": [{
"ip_prefix": "192.168.0.0",
"region": "someregion",
"service": "s3"}],
"ipv6_prefixes": [{
"ipv6_prefix": "2001:4860:4860::8888",
"region": "anotherregion",
"service": "ec2"}]
}`
rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) })
t.Parallel()
server := httptest.NewServer(rawMockHandler)
defer server.Close()
schema, err := fetchAWSIPs(server.URL)
assertEqual(t, nil, err)
assertEqual(t, 1, len(schema.Prefixes))
assertEqual(t, prefixEntry{
IPV4Prefix: "192.168.0.0",
Region: "someregion",
Service: "s3",
}, schema.Prefixes[0])
assertEqual(t, 1, len(schema.V6Prefixes))
assertEqual(t, prefixEntry{
IPV6Prefix: "2001:4860:4860::8888",
Region: "anotherregion",
Service: "ec2",
}, schema.V6Prefixes[0])
}
func TestUpdateCalledRegularly(t *testing.T) {
t.Parallel()
updateCount := 0
server := httptest.NewServer(http.HandlerFunc(
func(rw http.ResponseWriter, req *http.Request) {
updateCount++
rw.Write([]byte("ok"))
}))
defer server.Close()
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
time.Sleep(time.Second*4 + time.Millisecond*500)
if updateCount < 4 {
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
}
}
func TestEligibleForS3(t *testing.T) {
awsIPs := &awsIPs{
ipv4: []net.IPNet{{
IP: net.ParseIP("192.168.1.1"),
Mask: net.IPv4Mask(255, 255, 255, 0),
}},
initialized: true,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
cases := []struct {
Context context.Context
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: true},
{Context: makeContext("192.168.0.2"), Expected: false},
}
for _, testCase := range cases {
name := fmt.Sprintf("Client IP = %v",
testCase.Context.Value("http.request.ip"))
t.Run(name, func(t *testing.T) {
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
})
}
}
func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
awsIPs := &awsIPs{
ipv4: []net.IPNet{{
IP: net.ParseIP("192.168.1.1"),
Mask: net.IPv4Mask(255, 255, 255, 0),
}},
initialized: false,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
cases := []struct {
Context context.Context
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: false},
{Context: makeContext("192.168.0.2"), Expected: false},
}
for _, testCase := range cases {
name := fmt.Sprintf("Client IP = %v",
testCase.Context.Value("http.request.ip"))
t.Run(name, func(t *testing.T) {
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
})
}
}
// populate ips with a number of different ipv4 and ipv6 networks, for the purposes
// of benchmarking contains() performance.
func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) {
generateNetworks := func(dest *[]net.IPNet, bytes int, count int) {
for i := 0; i < count; i++ {
ip := make([]byte, bytes)
_, err := rand.Read(ip)
if err != nil {
b.Fatalf("failed to generate network for test : %s", err.Error())
}
mask := make([]byte, bytes)
for i := 0; i < bytes; i++ {
mask[i] = 0xff
}
*dest = append(*dest, net.IPNet{
IP: ip,
Mask: mask,
})
}
}
generateNetworks(&ips.ipv4, 4, ipv4Count)
generateNetworks(&ips.ipv6, 16, ipv6Count)
}
func BenchmarkContainsRandom(b *testing.B) {
// Generate a random network configuration, of size comparable to
// aws official networks list
// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
// 941
numNetworksPerType := 1000 // keep in sync with the above
// intentionally skip constructor when creating awsIPs, to avoid updater routine.
// This benchmark is only concerned with contains() performance.
awsIPs := awsIPs{}
populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType)
ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ {
ipv4[i] = make([]byte, 4)
ipv6[i] = make([]byte, 16)
rand.Read(ipv4[i])
rand.Read(ipv6[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
awsIPs.contains(ipv4[i])
awsIPs.contains(ipv6[i])
}
}
func BenchmarkContainsProd(b *testing.B) {
awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ {
ipv4[i] = make([]byte, 4)
ipv6[i] = make([]byte, 16)
rand.Read(ipv4[i])
rand.Read(ipv6[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
awsIPs.contains(ipv4[i])
awsIPs.contains(ipv6[i])
}
}