blob: 8ed96a287b8793ab4fa0f61f174feaf1b1c1c64d [file] [log] [blame]
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_prometheus
import (
"net"
"testing"
"time"
"io"
pb_testproto "github.com/grpc-ecosystem/go-grpc-prometheus/examples/testproto"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
// client metrics must satisfy the Collector interface
_ prometheus.Collector = NewClientMetrics()
)
func TestClientInterceptorSuite(t *testing.T) {
suite.Run(t, &ClientInterceptorTestSuite{})
}
type ClientInterceptorTestSuite struct {
suite.Suite
serverListener net.Listener
server *grpc.Server
clientConn *grpc.ClientConn
testClient pb_testproto.TestServiceClient
ctx context.Context
}
func (s *ClientInterceptorTestSuite) SetupSuite() {
var err error
EnableClientHandlingTimeHistogram()
s.serverListener, err = net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
// This is the point where we hook up the interceptor
s.server = grpc.NewServer()
pb_testproto.RegisterTestServiceServer(s.server, &testService{t: s.T()})
go func() {
s.server.Serve(s.serverListener)
}()
s.clientConn, err = grpc.Dial(
s.serverListener.Addr().String(),
grpc.WithInsecure(),
grpc.WithBlock(),
grpc.WithUnaryInterceptor(UnaryClientInterceptor),
grpc.WithStreamInterceptor(StreamClientInterceptor),
grpc.WithTimeout(2*time.Second))
require.NoError(s.T(), err, "must not error on client Dial")
s.testClient = pb_testproto.NewTestServiceClient(s.clientConn)
}
func (s *ClientInterceptorTestSuite) SetupTest() {
// Make all RPC calls last at most 2 sec, meaning all async issues or deadlock will not kill tests.
s.ctx, _ = context.WithTimeout(context.TODO(), 2*time.Second)
}
func (s *ClientInterceptorTestSuite) TearDownSuite() {
if s.serverListener != nil {
s.server.Stop()
s.T().Logf("stopped grpc.Server at: %v", s.serverListener.Addr().String())
s.serverListener.Close()
}
if s.clientConn != nil {
s.clientConn.Close()
}
}
func (s *ClientInterceptorTestSuite) TestUnaryIncrementsStarted() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary")
s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{})
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary")
assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingEmpty")
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary")
s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.Unavailable)})
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary")
assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingError")
}
func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHandled() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK")
s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty")
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition")
s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingError")
}
func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHistograms() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary")
s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty")
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary")
s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingError")
}
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsStarted() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream")
s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{})
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream")
assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingList")
}
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHistograms() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
// Do a read, just for kicks.
for {
_, err := ss.Recv()
if err == io.EOF {
break
}
require.NoError(s.T(), err, "reading pingList shouldn't fail")
}
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList OK")
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
require.NoError(s.T(), err, "PingList must not fail immediately")
// Do a read, just to progate errors.
_, err = ss.Recv()
st, _ := status.FromError(err)
require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong")
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList FailedPrecondition")
}
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHandled() {
var before int
var after int
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK")
ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
// Do a read, just for kicks.
for {
_, err := ss.Recv()
if err == io.EOF {
break
}
require.NoError(s.T(), err, "reading pingList shouldn't fail")
}
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList OK")
before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition")
ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
require.NoError(s.T(), err, "PingList must not fail immediately")
// Do a read, just to progate errors.
_, err = ss.Recv()
st, _ := status.FromError(err)
require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong")
after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition")
assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList FailedPrecondition")
}
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsMessageCounts() {
beforeRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream")
beforeSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream")
ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
// Do a read, just for kicks.
count := 0
for {
_, err := ss.Recv()
if err == io.EOF {
break
}
require.NoError(s.T(), err, "reading pingList shouldn't fail")
count++
}
require.EqualValues(s.T(), countListResponses, count, "Number of received msg on the wire must match")
afterSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream")
afterRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream")
assert.EqualValues(s.T(), beforeSent+1, afterSent, "grpc_client_msg_sent_total should be incremented 20 times for PingList")
assert.EqualValues(s.T(), beforeRecv+countListResponses, afterRecv, "grpc_client_msg_sent_total should be incremented ones for PingList ")
}