blob: ff8f12131c4940f7a56d471ebf79248ce9a92fbe [file] [log] [blame]
// Copyright Istio Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google
import (
"encoding/json"
"errors"
"strings"
"testing"
)
import (
"github.com/apache/dubbo-go-pixiu/pkg/security"
"github.com/apache/dubbo-go-pixiu/security/pkg/stsservice"
"github.com/apache/dubbo-go-pixiu/security/pkg/stsservice/tokenmanager/google/mock"
)
// TestTokenExchangePlugin verifies that the token manager can call the mock
// authorization server and get an access token, that failures while fetching
// the federated token or the access token are surfaced as errors, and that the
// plugin status dump reflects the tokens fetched by each exchange.
func TestTokenExchangePlugin(t *testing.T) {
	tmPlugin, ms, originalFederatedTokenEndpoint, originalAccessTokenEndpoint := setUpTest(t, testSetUp{})
	lastStatusDumpMap := make(map[string]stsservice.TokenInfo)
	defer func() {
		if err := ms.Stop(); err != nil {
			t.Logf("failed to stop mock server: %v", err)
		}
		federatedTokenEndpoint = originalFederatedTokenEndpoint
		accessTokenEndpoint = originalAccessTokenEndpoint
	}()
	// All cases share one plugin and one mock server, and each case's status
	// dump is compared with the previous case's dump. They therefore run in a
	// fixed order; a map would randomize the order on every run.
	testCases := []struct {
		name                     string
		genFederatedTokenError   error
		genAccessTokenError      error
		expectedError            string
		expectedStatusDumpUpdate []string
	}{
		{
			name:                     "token manager returns valid STS success response",
			expectedStatusDumpUpdate: []string{federatedToken, accessToken},
		},
		{
			name:                     "token manager failed to return federated token",
			genFederatedTokenError:   errors.New("fake error in generating federated access token"),
			expectedError:            "failed to exchange federated token",
			expectedStatusDumpUpdate: []string{},
		},
		{
			name:                     "token manager failed to return access token",
			genAccessTokenError:      errors.New("fake error in generating access token"),
			expectedError:            "failed to exchange access token",
			expectedStatusDumpUpdate: []string{federatedToken},
		},
	}
	for _, tc := range testCases {
		if tc.genAccessTokenError != nil {
			ms.SetGenAcsTokenError(tc.genAccessTokenError)
		}
		if tc.genFederatedTokenError != nil {
			ms.SetGenFedTokenError(tc.genFederatedTokenError)
		}
		stsRespJSON, err := tmPlugin.ExchangeToken(defaultSTSRequest())
		verifyToken(t, tc.name, stsRespJSON, err, tc.expectedError)
		stsDumpJSON, err := tmPlugin.DumpPluginStatus()
		if err != nil {
			t.Errorf("(Test case %s), failed to dump plugin status: %v", tc.name, err)
		}
		lastStatusDumpMap = verifyDumpStatus(t, tc.name, stsDumpJSON, lastStatusDumpMap, tc.expectedStatusDumpUpdate)
		// Clear injected failures so they do not leak into the next case.
		ms.SetGenAcsTokenError(nil)
		ms.SetGenFedTokenError(nil)
	}
}
// verifyDumpStatus decodes a plugin status dump and checks it against the
// previous dump.
//
// Every token type listed in expected must appear in the new dump. If that
// token type was also present in lastStatus, both its ExpireTime and its
// IssueTime must differ from the previous values; otherwise the token is
// reported as not updated. tCase labels the failure messages.
//
// It returns the decoded dump keyed by token type, so the caller can pass it
// as lastStatus for the next comparison.
func verifyDumpStatus(t *testing.T, tCase string, dumpJSON []byte, lastStatus map[string]stsservice.TokenInfo,
	expected []string) map[string]stsservice.TokenInfo {
	dump := &stsservice.TokensDump{}
	if err := json.Unmarshal(dumpJSON, dump); err != nil {
		t.Errorf("(Test case %s), failed to unmarshal status dump: %v", tCase, err)
	}
	current := extractTokenDumpToMap(dump)
	t.Logf("Dump newStatusMap:\n%+v", current)
	t.Logf("Dump lastStatus:\n%+v", lastStatus)

	for _, tokenType := range expected {
		fresh, present := current[tokenType]
		if !present {
			t.Errorf("(Test case %s), failed to find expected token %s in status dump", tCase, tokenType)
			continue
		}
		previous, seenBefore := lastStatus[tokenType]
		if !seenBefore {
			// Nothing to compare against yet.
			continue
		}
		unchanged := fresh.ExpireTime == previous.ExpireTime || fresh.IssueTime == previous.IssueTime
		if unchanged {
			t.Errorf("(Test case %s), expected status update for %s (%v) in status dump", tCase, tokenType, fresh)
		}
	}
	return current
}
// extractTokenDumpToMap indexes the tokens in a status dump by their
// TokenType. If several entries share a TokenType, the one iterated last wins.
func extractTokenDumpToMap(newStatus *stsservice.TokensDump) map[string]stsservice.TokenInfo {
	byType := map[string]stsservice.TokenInfo{}
	for i := range newStatus.Tokens {
		entry := newStatus.Tokens[i]
		byType[entry.TokenType] = entry
	}
	return byType
}
// verifyToken checks the outcome of one token exchange against expErr.
//
// When expErr is non-empty, actualErr must be non-nil and its message must
// contain expErr. When expErr is empty, actualErr must be nil, and
// stsRespJSON must decode into STS response parameters whose AccessToken is
// mock.FakeAccessToken. Any other combination is reported as a mismatch.
// tCase labels the failure messages.
func verifyToken(t *testing.T, tCase string, stsRespJSON []byte, actualErr error, expErr string) {
	wantErr := expErr != ""
	gotErr := actualErr != nil
	switch {
	case wantErr && gotErr:
		if !strings.Contains(actualErr.Error(), expErr) {
			t.Errorf("(Test case %s), error does not match, want: %v vs get: %v",
				tCase, expErr, actualErr)
		}
	case !wantErr && !gotErr:
		resp := &stsservice.StsResponseParameters{}
		if err := json.Unmarshal(stsRespJSON, resp); err != nil {
			t.Errorf("(Test case %s), failed to unmarshal STS response: %v", tCase, err)
		}
		if resp.AccessToken != mock.FakeAccessToken {
			t.Errorf("(Test case %s), access token got: %q, expected: %q",
				tCase, resp.AccessToken, mock.FakeAccessToken)
		}
	default:
		t.Errorf("(Test case %s), error does not match: want %s vs get: %v",
			tCase, expErr, actualErr)
	}
}
// defaultSTSRequest builds the token-exchange request used by these tests.
// The request exchanges mock.FakeSubjectToken, typed as a JWT, with
// mock.FakeTrustDomain as the audience and the package's scope value.
func defaultSTSRequest() security.StsRequestParameters {
	var req security.StsRequestParameters
	req.GrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
	req.Audience = mock.FakeTrustDomain
	req.Scope = scope
	req.SubjectToken = mock.FakeSubjectToken
	req.SubjectTokenType = "urn:ietf:params:oauth:token-type:jwt"
	return req
}
// testSetUp holds the knobs setUpTest uses to configure the plugin and the
// mock server.
type testSetUp struct {
	// enableCache is passed to CreateTokenManagerPlugin.
	enableCache bool

	// enableDynamicToken is passed to the mock server's
	// EnableDynamicAccessToken.
	enableDynamicToken bool
}
// setUpTest creates a token manager plugin and starts a mock authorization
// server. It then points the package-level federatedTokenEndpoint and
// accessTokenEndpoint at that server.
//
// The plugin's caching and the server's dynamic access tokens are configured
// from setup. It returns the plugin, the mock server, and the original
// endpoint values, so the caller can stop the server and restore the endpoints
// when the test finishes. Any setup failure aborts the test.
func setUpTest(t *testing.T, setup testSetUp) (*Plugin, *mock.AuthorizationServer, string, string) {
	t.Helper()
	tm, err := CreateTokenManagerPlugin(nil, mock.FakeTrustDomain, mock.FakeProjectNum, mock.FakeGKEClusterURL, setup.enableCache)
	if err != nil {
		t.Fatalf("failed to create token manager plugin: %v", err)
	}
	ms, err := mock.StartNewServer(t, mock.Config{Port: 0})
	if err != nil {
		t.Fatalf("failed to start a mock server: %v", err)
	}
	// Configure the server only after confirming it started; on error ms is
	// not guaranteed to be usable.
	ms.EnableDynamicAccessToken(setup.enableDynamicToken)
	originalFederatedTokenEndpoint := federatedTokenEndpoint
	federatedTokenEndpoint = ms.URL + "/v1/token"
	originalAccessTokenEndpoint := accessTokenEndpoint
	accessTokenEndpoint = ms.URL + "/v1/projects/-/serviceAccounts/service-%s@gcp-sa-meshdataplane.iam.gserviceaccount.com:generateAccessToken"
	return tm, ms, originalFederatedTokenEndpoint, originalAccessTokenEndpoint
}
// TestTokenExchangePluginWithCache verifies that, with caching enabled, the
// token manager serves a cached token to the client. It also verifies that the
// manager calls the backend again after the cache is cleared, and when the
// cached token is close to expiring.
func TestTokenExchangePluginWithCache(t *testing.T) {
	tmPlugin, ms, originalFederatedTokenEndpoint, originalAccessTokenEndpoint := setUpTest(t, testSetUp{enableCache: true, enableDynamicToken: true})
	defer func() {
		if err := ms.Stop(); err != nil {
			t.Logf("failed to stop mock server: %v", err)
		}
		federatedTokenEndpoint = originalFederatedTokenEndpoint
		accessTokenEndpoint = originalAccessTokenEndpoint
	}()

	// exchange performs one token exchange through the plugin and returns the
	// access token from the STS response. It also checks that, by that point,
	// the mock server has received exactly wantCalls calls to each of the
	// get-federated-token and get-access-token APIs.
	exchange := func(wantCalls int) string {
		t.Helper()
		stsRespJSON, err := tmPlugin.ExchangeToken(defaultSTSRequest())
		if err != nil {
			t.Fatalf("failed to exchange token: %v", err)
		}
		stsResp := &stsservice.StsResponseParameters{}
		if err := json.Unmarshal(stsRespJSON, stsResp); err != nil {
			t.Errorf("failed to unmarshal STS response: %v", err)
		}
		if got := int(ms.NumGetFederatedTokenCalls()); got != wantCalls {
			t.Errorf("number of get federated token API calls does not match, expected %d got %d", wantCalls, got)
		}
		if got := int(ms.NumGetAccessTokenCalls()); got != wantCalls {
			t.Errorf("number of get access token API calls does not match, expected %d got %d", wantCalls, got)
		}
		return stsResp.AccessToken
	}

	// First exchange: nothing is cached yet, so the plugin should call the backend.
	firstToken := exchange(1)

	// Second exchange: the plugin should return the cached token without
	// calling the backend.
	secondToken := exchange(1)
	if firstToken != secondToken {
		t.Errorf("cached token is not used")
	}

	// Delete the cached token.
	tmPlugin.ClearCache()
	// Set token life time to 4 min, which is shorter than token grace period.
	ms.SetTokenLifeTime(4 * 60)

	// Third exchange: the cache was cleared, so the plugin should call the backend.
	thirdToken := exchange(2)
	if secondToken == thirdToken {
		t.Errorf("should not return cached token")
	}

	// Fourth exchange: the cached token is about to expire, so the plugin
	// should call the backend again.
	fourthToken := exchange(3)
	if thirdToken == fourthToken {
		t.Errorf("should not return cached token")
	}
}
// TestAccessTokenRequestToJson verifies that an accessTokenRequest marshals to
// the expected JSON document.
func TestAccessTokenRequestToJson(t *testing.T) {
	tests := []struct {
		name      string
		delegates []string
		scope     []string
		lifetime  Duration
		want      string
	}{
		{
			name:      "OneHourInSecondLifetime",
			delegates: []string{},
			scope:     []string{"https://www.googleapis.com/auth/cloud-platform"},
			lifetime:  Duration{Seconds: 3600},
			want:      `{"name":"OneHourInSecondLifetime","delegates":[],"scope":["https://www.googleapis.com/auth/cloud-platform"],"lifetime":{"seconds":3600}}`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			query := accessTokenRequest{
				Name:      tt.name,
				Delegates: tt.delegates,
				Scope:     tt.scope,
				LifeTime:  tt.lifetime,
			}
			jsonQuery, err := json.Marshal(query)
			if err != nil {
				// Without a valid encoding there is nothing meaningful to compare.
				t.Fatalf("%s: query: %v, err in json.Marshal: %v", tt.name, query, err)
			}
			if got := string(jsonQuery); got != tt.want {
				t.Errorf("%s: got: %v, want: %v", tt.name, got, tt.want)
			}
		})
	}
}