blob: e83437831d9b58c7ae7585235625089fea70fcbd [file] [log] [blame]
// Copyright Istio Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"reflect"
"strings"
"testing"
"time"
)
import (
"istio.io/pkg/log"
)
import (
"github.com/apache/dubbo-go-pixiu/security/pkg/stsservice"
"github.com/apache/dubbo-go-pixiu/security/pkg/stsservice/mock"
)
// stsReqType selects which request genStsRequest builds: a well-formed token
// exchange request, a request that breaks exactly one validation rule, or a
// token status dump request.
type stsReqType int

const (
	// validStsReq is a well-formed, form-encoded POST token exchange request.
	validStsReq stsReqType = iota
	// emptyGrantType sends grant_type as an empty string.
	emptyGrantType
	// incorrectGrantType sends an unsupported grant_type value.
	incorrectGrantType
	// emptySubjectToken sends subject_token as an empty string.
	emptySubjectToken
	// emptySubjectTokenType sends subject_token_type as an empty string.
	emptySubjectTokenType
	// incorrectSubjectTokenType sends an unsupported subject_token_type value.
	incorrectSubjectTokenType
	// incorrectRequestMethod sends the token exchange form with GET.
	incorrectRequestMethod
	// incorrectContentType sends the form with an application/json content type.
	incorrectContentType
	// tokenStatusDump is a body-less GET against the status endpoint.
	tokenStatusDump
)
// stsRespType selects which expected response genStsResponse builds and which
// comparison verifyResponse applies to the actual response.
type stsRespType int

const (
	// successStsResp is a 200 response carrying STS response parameters.
	successStsResp stsRespType = iota
	// validationFailure is a 400 response carrying an STS error body.
	validationFailure
	// tokenGenerationFailure is a 500 response carrying an STS error body.
	tokenGenerationFailure
	// StatusDumpSuccess is a 200 response carrying a JSON token dump.
	StatusDumpSuccess
	// StatusDumpFailure is a 500 plain-text response from the status endpoint.
	StatusDumpFailure
)
// TestStsService verifies that the STS server handles STS requests properly.
// Each case configures the fake token manager, sends one request to either
// the token exchange path or the status dump path, and compares the reply
// with the response built by genStsResponse. Every case runs as a subtest.
func TestStsService(t *testing.T) {
	tokenManager, httpClient, stsAddr, stsServer := setUpServerAndClient(t)
	// Deferred so the server (bound to a fixed port) is stopped even when a
	// case aborts the test early via t.Fatalf.
	defer stsServer.Stop()

	tokenURL := "http://" + stsAddr.String() + TokenPath
	statusURL := "http://" + stsAddr.String() + StsStatusPath
	emptyStsParam := stsservice.StsResponseParameters{}
	mockToken := &stsservice.TokenInfo{
		TokenType:  "type",
		IssueTime:  time.Now(),
		ExpireTime: time.Now().Add(1 * time.Hour),
	}
	testCases := map[string]struct {
		genTokenError        error
		dumpTokenError       error
		stsRequest           *http.Request
		stsRespParam         stsservice.StsResponseParameters
		expectedStsResponse  *http.Response
		expectedResponseType stsRespType
		expectedToken        *stsservice.TokenInfo
	}{
		"Send a valid STS request and get STS success response": {
			stsRequest:           genStsRequest(validStsReq, tokenURL),
			stsRespParam:         genSuccessStsRespParam(),
			expectedStsResponse:  genStsResponse(successStsResp, genSuccessStsRespParam(), nil, nil),
			expectedResponseType: successStsResp,
		},
		"Send an invalid STS request (empty grant type) and get STS error response": {
			stsRequest:           genStsRequest(emptyGrantType, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("request query grant_type is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (incorrect grant type) and get STS error response": {
			stsRequest:           genStsRequest(incorrectGrantType, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("request query grant_type is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (empty subject token) and get STS error response": {
			stsRequest:           genStsRequest(emptySubjectToken, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("subject_token is empty"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (empty subject token type) and get STS error response": {
			stsRequest:           genStsRequest(emptySubjectTokenType, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("subject_token_type is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (incorrect subject token type) and get STS error response": {
			stsRequest:           genStsRequest(incorrectSubjectTokenType, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("subject_token_type is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (incorrect request method) and get STS error response": {
			stsRequest:           genStsRequest(incorrectRequestMethod, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("request method is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send an invalid STS request (incorrect content type) and get STS error response": {
			stsRequest:           genStsRequest(incorrectContentType, tokenURL),
			expectedStsResponse:  genStsResponse(validationFailure, genSuccessStsRespParam(), errors.New("request content type is invalid"), nil),
			expectedResponseType: validationFailure,
		},
		"Send a valid STS request and get STS error response": {
			stsRequest:           genStsRequest(validStsReq, tokenURL),
			genTokenError:        errors.New("failed to generate token"),
			expectedStsResponse:  genStsResponse(tokenGenerationFailure, emptyStsParam, errors.New("failed to generate token"), nil),
			expectedResponseType: tokenGenerationFailure,
		},
		"Send a dump request and get dump information in response": {
			stsRequest:           genStsRequest(tokenStatusDump, statusURL),
			expectedStsResponse:  genStsResponse(StatusDumpSuccess, emptyStsParam, nil, mockToken),
			expectedResponseType: StatusDumpSuccess,
			expectedToken:        mockToken,
		},
		"Send a dump request and get error response": {
			stsRequest:           genStsRequest(tokenStatusDump, statusURL),
			dumpTokenError:       errors.New("failed to dump token"),
			expectedStsResponse:  genStsResponse(StatusDumpFailure, emptyStsParam, errors.New("failed to dump token"), nil),
			expectedResponseType: StatusDumpFailure,
		},
	}

	for name, tc := range testCases {
		// t.Run is synchronous here (no t.Parallel), so capturing name/tc is safe.
		t.Run(name, func(t *testing.T) {
			// Set both error knobs unconditionally (nil clears them) so no
			// state leaks in from an earlier case, even one that failed early.
			tokenManager.SetGenerateTokenError(tc.genTokenError)
			tokenManager.SetDumpTokenError(tc.dumpTokenError)
			if tc.expectedToken != nil {
				tokenManager.SetToken(*tc.expectedToken)
			}
			tokenManager.SetRespStsParam(tc.stsRespParam)

			resp, err := sendStsRequestWithRetry(httpClient, tc.stsRequest)
			if err != nil {
				t.Fatalf("(Test case %s), failure in sending STS request: %v", name, err)
			}
			verifyResponse(t, name, tc.expectedResponseType, resp, tc.expectedStsResponse)
		})
	}
}
// setUpServerAndClient creates a fake token manager, an STS server configured
// for 127.0.0.1:3333, and an HTTP client whose dialer always connects to that
// address regardless of the host in the request URL. It returns the token
// manager, the client, the resolved server address and the server. It does
// not stop the server; the caller should do so. The test is aborted if the
// address cannot be resolved or the server cannot be created.
func setUpServerAndClient(t *testing.T) (*mock.FakeTokenManager, *http.Client, *net.TCPAddr, *Server) {
	t.Helper()
	tokenManager := mock.CreateFakeTokenManager()
	addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3333")
	if err != nil {
		t.Fatalf("failed to create address %v", err)
	}
	config := Config{LocalHostAddr: addr.IP.String(), LocalPort: addr.Port}
	serverAddr := addr.String()
	server, err := NewServer(config, tokenManager)
	if err != nil {
		t.Fatalf("failed to create STS server: %v", err)
	}
	var dialer net.Dialer
	httpClient := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, network, requested string) (net.Conn, error) {
				t.Logf("set up server address to dial %s", requested)
				// Always dial the STS server, and honor the request context
				// so cancellation/timeouts are respected.
				return dialer.DialContext(ctx, network, serverAddr)
			},
		},
	}
	return tokenManager, httpClient, addr, server
}
// sendStsRequestWithRetry sends req with client, making up to 10 attempts
// spaced 100ms apart (e.g. while the server is still starting to listen).
// It returns the first successful response, or the last error.
//
// http.Client.Do consumes and closes the request body even when it fails,
// so before each retry the body is rebuilt from req.GetBody (set by
// http.NewRequest for bytes/strings readers). Without this, every retry
// would send an empty body.
func sendStsRequestWithRetry(client *http.Client, req *http.Request) (resp *http.Response, err error) {
	const attempts = 10
	for i := 0; i < attempts; i++ {
		if i > 0 {
			time.Sleep(100 * time.Millisecond)
			if req.GetBody != nil {
				body, bodyErr := req.GetBody()
				if bodyErr != nil {
					return nil, bodyErr
				}
				req.Body = body
			}
		}
		resp, err = client.Do(req)
		if err == nil {
			return resp, nil
		}
	}
	return resp, err
}
// verifyResponse compares resp with expectedResp for the test case tCase.
// The status code is always compared. The Content-Type header is compared
// for every respType except StatusDumpFailure. The bodies are compared with
// the checker that matches respType. Both bodies are read and closed, and
// read failures are reported instead of being ignored.
func verifyResponse(t *testing.T, tCase string, respType stsRespType, resp, expectedResp *http.Response) {
	t.Helper()
	defer resp.Body.Close()
	defer expectedResp.Body.Close()

	if resp.StatusCode != expectedResp.StatusCode {
		t.Errorf("(Test case %s): response HTTP status code does not match, get %d vs expected %d",
			tCase, resp.StatusCode, expectedResp.StatusCode)
	}
	if respType != StatusDumpFailure {
		got, want := resp.Header.Get("Content-Type"), expectedResp.Header.Get("Content-Type")
		if got != want {
			t.Errorf("(Test case %s): response HTTP Header Content-Type does not match, get %s vs expected %s",
				tCase, got, want)
		}
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Errorf("(Test case %s): failed to read response body: %v", tCase, err)
		return
	}
	expectedBody, err := io.ReadAll(expectedResp.Body)
	if err != nil {
		t.Errorf("(Test case %s): failed to read expected response body: %v", tCase, err)
		return
	}

	switch respType {
	case successStsResp:
		verifyResponseBody(t, tCase, body, expectedBody)
	case validationFailure, tokenGenerationFailure:
		verifyErrorResponse(t, tCase, body, expectedBody)
	case StatusDumpSuccess:
		verifyDumpResponse(t, tCase, body, expectedBody)
	case StatusDumpFailure:
		// The dump failure body is plain text, so compare it verbatim.
		if string(body) != string(expectedBody) {
			t.Errorf("(Test case %s): dump status failure response does not match, get %s but expect %s",
				tCase, body, expectedBody)
		}
	default:
		t.Errorf("(Test case %s): unknown response type %d", tCase, respType)
	}
}
// verifyResponseBody decodes body and expBody as STS success response
// parameters and reports a test error if either fails to decode or if the
// decoded values are not deeply equal. It stops at the first decode failure
// rather than comparing a partially-populated value.
func verifyResponseBody(t *testing.T, tCase string, body, expBody []byte) {
	t.Helper()
	respStsParam := &stsservice.StsResponseParameters{}
	if err := json.Unmarshal(body, respStsParam); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal STS success response: %v", tCase, err)
		return
	}
	expRespStsParam := &stsservice.StsResponseParameters{}
	if err := json.Unmarshal(expBody, expRespStsParam); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal expected STS success response: %v", tCase, err)
		return
	}
	if !reflect.DeepEqual(respStsParam, expRespStsParam) {
		t.Errorf("(Test case %s): STS response parameter does not match, get %v vs expected %v",
			tCase, respStsParam, expRespStsParam)
	}
}
// verifyErrorResponse decodes body and expBody as STS error responses. It
// reports a test error if either fails to decode, if the error codes differ,
// or if the actual description does not start with the expected one. Only a
// prefix match is required, so extra text after the expected message is
// tolerated. It stops at the first decode failure; otherwise an undecodable
// expected body would leave an empty description that every actual value
// trivially "matches".
func verifyErrorResponse(t *testing.T, tCase string, body, expBody []byte) {
	t.Helper()
	respErr := &stsservice.StsErrorResponse{}
	if err := json.Unmarshal(body, respErr); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal error response: %v", tCase, err)
		return
	}
	expRespErr := &stsservice.StsErrorResponse{}
	if err := json.Unmarshal(expBody, expRespErr); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal expected error response: %v", tCase, err)
		return
	}
	if respErr.Error != expRespErr.Error {
		t.Errorf("(Test case %s): STS response error code does not match, get %s vs expected %s",
			tCase, respErr.Error, expRespErr.Error)
	}
	if !strings.HasPrefix(respErr.ErrorDescription, expRespErr.ErrorDescription) {
		t.Errorf("(Test case %s): STS response error message does not match, get %s vs expected %s",
			tCase, respErr.ErrorDescription, expRespErr.ErrorDescription)
	}
}
// verifyDumpResponse decodes body and expBody as token dumps and reports a
// test error if either fails to decode or if the decoded dumps are not deeply
// equal. Both sides go through the same JSON decoding, so the time fields
// are compared in the same decoded form. It stops at the first decode failure.
func verifyDumpResponse(t *testing.T, tCase string, body, expBody []byte) {
	t.Helper()
	tokenDump := &stsservice.TokensDump{}
	if err := json.Unmarshal(body, tokenDump); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal token dump: %v", tCase, err)
		return
	}
	expTokenDump := &stsservice.TokensDump{}
	if err := json.Unmarshal(expBody, expTokenDump); err != nil {
		t.Errorf("(Test case %s): failed to unmarshal expected token dump: %v", tCase, err)
		return
	}
	if !reflect.DeepEqual(tokenDump, expTokenDump) {
		t.Errorf("(Test case %s): token dump does not match, get %v vs expected %v",
			tCase, tokenDump, expTokenDump)
	}
}
// genStsRequest builds an HTTP request to serverAddr for reqType.
//
// Token exchange variants send a form-encoded POST whose fields start from a
// valid request. Each invalid variant changes exactly one thing: a form
// field, the method (GET), or the Content-Type (application/json).
// tokenStatusDump produces a body-less GET with no Content-Type. The request
// is logged before being returned.
//
// It panics if the request cannot be constructed (e.g. serverAddr is not a
// valid URL), since that is a bug in the test itself. Previously the error
// was ignored and the nil request caused a nil dereference.
func genStsRequest(reqType stsReqType, serverAddr string) (req *http.Request) {
	stsQuery := url.Values{}
	stsQuery.Set("grant_type", TokenExchangeGrantType)
	stsQuery.Set("resource", "https://backend.example.com")
	stsQuery.Set("audience", "audience")
	stsQuery.Set("scope", "scope")
	stsQuery.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token")
	stsQuery.Set("subject_token", "subject token")
	stsQuery.Set("subject_token_type", SubjectTokenType)
	stsQuery.Set("actor_token", "")
	stsQuery.Set("actor_token_type", "")

	method, contentType, withBody := http.MethodPost, URLEncodedForm, true
	switch reqType {
	case emptyGrantType:
		stsQuery.Set("grant_type", "")
	case incorrectGrantType:
		stsQuery.Set("grant_type", "incorrect")
	case emptySubjectToken:
		stsQuery.Set("subject_token", "")
	case emptySubjectTokenType:
		stsQuery.Set("subject_token_type", "")
	case incorrectSubjectTokenType:
		stsQuery.Set("subject_token_type", "incorrect")
	case incorrectRequestMethod:
		method = http.MethodGet
	case incorrectContentType:
		contentType = "application/json"
	case tokenStatusDump:
		method, contentType, withBody = http.MethodGet, "", false
	}

	// Encode after the switch so any field mutation above is included.
	var body io.Reader
	if withBody {
		body = strings.NewReader(stsQuery.Encode())
	}
	req, err := http.NewRequest(method, serverAddr, body)
	if err != nil {
		panic("failed to create STS request: " + err.Error())
	}
	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}
	reqDump, _ := httputil.DumpRequest(req, true)
	log.Infof("STS request: %s", string(reqDump))
	return req
}
// genStsResponse builds the response the STS server is expected to return
// for respType:
//   - successStsResp: 200 with param as indented JSON.
//   - tokenGenerationFailure: 500 with an invalidTarget STS error carrying serverErr.
//   - validationFailure: 400 with an invalidRequest STS error carrying serverErr.
//   - StatusDumpSuccess: 200 with a JSON token dump holding *tokenInfo.
//   - StatusDumpFailure: 500 with a text/plain message carrying serverErr.
//
// The Content-Type is application/json unless overridden above. serverErr
// must be non-nil for the error types and tokenInfo non-nil for
// StatusDumpSuccess. The built response is logged before being returned.
func genStsResponse(respType stsRespType, param stsservice.StsResponseParameters,
	serverErr error, tokenInfo *stsservice.TokenInfo) (resp *http.Response) {
	resp = &http.Response{Header: make(http.Header)}
	resp.Header.Add("Content-Type", "application/json")

	withStatus := func(code int) {
		resp.StatusCode = code
		resp.Status = http.StatusText(code)
	}
	withJSONBody := func(v interface{}) {
		encoded, _ := json.MarshalIndent(v, "", " ")
		resp.Body = io.NopCloser(bytes.NewBuffer(encoded))
	}

	switch respType {
	case successStsResp:
		withStatus(http.StatusOK)
		withJSONBody(param)
	case tokenGenerationFailure:
		withStatus(http.StatusInternalServerError)
		withJSONBody(stsservice.StsErrorResponse{
			Error:            invalidTarget,
			ErrorDescription: serverErr.Error(),
		})
	case validationFailure:
		withStatus(http.StatusBadRequest)
		withJSONBody(stsservice.StsErrorResponse{
			Error:            invalidRequest,
			ErrorDescription: serverErr.Error(),
		})
	case StatusDumpSuccess:
		withStatus(http.StatusOK)
		withJSONBody(stsservice.TokensDump{Tokens: []stsservice.TokenInfo{*tokenInfo}})
	case StatusDumpFailure:
		withStatus(http.StatusInternalServerError)
		resp.Header.Set("Content-Type", "text/plain")
		resp.Body = io.NopCloser(bytes.NewBufferString("failure in dumping STS server status: " + serverErr.Error()))
	}

	dumped, _ := httputil.DumpResponse(resp, true)
	log.Infof("Dump response: %s", string(dumped))
	return resp
}
// genSuccessStsRespParam returns the STS response parameters used in the
// success case. The same value is both configured on the fake token manager
// and used to build the expected response.
func genSuccessStsRespParam() stsservice.StsResponseParameters {
	var params stsservice.StsResponseParameters
	params.AccessToken = "accesstoken"
	params.IssuedTokenType = "urn:ietf:params:oauth:token-type:access_token"
	params.TokenType = "Bearer"
	params.ExpiresIn = 60
	params.Scope = "example.com"
	return params
}