package tokenmanager
import (
stsServer ""
// numClient is the number of test clients to create for testing.
const numClient = 10

// stsServerAddress is the address of the STS server under test. It is
// assigned by setUpTestComponents and read by the request generators and by
// the test clients' dialer.
var stsServerAddress string
// TestStsService sets up a STS server and token manager which has enabled Google
// token exchange plugin, a mock authorization server for token service, and
// verifies STS flows.
func TestStsFlow(t *testing.T) {
stsServer, mockBackend, clients := setUpTestComponents(t, testSetUp{})
defer tearDownTest(t, stsServer, mockBackend)
federatedTokenReceivedTime := time.Time{}
accessTokenReceivedTime := time.Time{}
for i := 0; i < numClient; i++ {
resp, err := sendHTTPRequestWithRetry(clients[i], genStsReq(t))
if err != nil {
t.Fatalf("client %d: failure in sending STS request: %v", i, err)
verifyStsResponse(t, resp)
resp, err = sendHTTPRequestWithRetry(clients[i], genDumpReq(t))
if err != nil {
t.Fatalf("client %d: failure in sending STS request: %v", i, err)
federatedTokenReceivedTime, accessTokenReceivedTime = verifyDumpResponse(t, resp, federatedTokenReceivedTime, accessTokenReceivedTime)
// TestStsCache enables caching at token exchange plugin, which will return cached token if that token
// is not going to expire soon.
func TestStsCache(t *testing.T) {
stsServer, mockBackend, clients := setUpTestComponents(t, testSetUp{enableCache: true, enableDynamicToken: true})
defer tearDownTest(t, stsServer, mockBackend)
accessToken := ""
for i := 0; i < numClient; i++ {
resp, err := sendHTTPRequestWithRetry(clients[i], genStsReq(t))
if err != nil {
t.Fatalf("client %d: failure in sending STS request: %v", i, err)
if resp.StatusCode != http.StatusOK {
t.Errorf("response HTTP status code does not match, get %d vs expected %d",
resp.StatusCode, http.StatusOK)
body, _ := io.ReadAll(resp.Body)
respStsParam := &stsservice.StsResponseParameters{}
json.Unmarshal(body, respStsParam)
if i == 0 {
accessToken = respStsParam.AccessToken
if i > 0 {
if accessToken != respStsParam.AccessToken {
t.Errorf("cached token is not in use")
if mockBackend.NumGetFederatedTokenCalls() != 1 {
t.Errorf("Number of get federated token API calls does not match, expected 1 but got %d", mockBackend.NumGetFederatedTokenCalls())
if mockBackend.NumGetAccessTokenCalls() != 1 {
t.Errorf("Number of get access token API calls does not match, expected 1 but got %d", mockBackend.NumGetAccessTokenCalls())
func genDumpReq(t *testing.T) (req *http.Request) {
dumpURL := "http://" + stsServerAddress + stsServer.StsStatusPath
req, _ = http.NewRequest("GET", dumpURL, nil)
reqDump, _ := httputil.DumpRequest(req, true)
t.Logf("status dump request:\n%s", string(reqDump))
return req
// verifyDumpResponse parses token info from dump response, and verifies that
// issue time of federated token and access token have updated by comparing them
// with oldFTime and oldATime, and returns new issue time of federated token and access token.
func verifyDumpResponse(t *testing.T, resp *http.Response, oldFTime, oldATime time.Time) (newFTime, newATime time.Time) {
if resp.StatusCode != http.StatusOK {
t.Errorf("response HTTP status code does not match, get %d vs expected %d",
resp.StatusCode, http.StatusOK)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
tokenDump := &stsservice.TokensDump{}
if err := json.Unmarshal(body, tokenDump); err != nil {
t.Errorf("failed to unmarshal status dump response: %v", err)
for _, info := range tokenDump.Tokens {
if info.TokenType == "access token" {
newFTime = info.IssueTime
if newFTime == oldFTime {
t.Errorf("federated token issue time does not change: %s", newFTime)
} else {
newATime = info.IssueTime
if newATime == oldATime {
t.Errorf("access token issue time does not change: %s", newATime)
return newFTime, newATime
// verifyStsResponse verifies that received STS response has valid parameter values.
func verifyStsResponse(t *testing.T, resp *http.Response) {
if resp.StatusCode != http.StatusOK {
t.Errorf("response HTTP status code does not match, get %d vs expected %d",
resp.StatusCode, http.StatusOK)
ctVal := resp.Header.Get("Content-Type")
if ctVal != "application/json" {
t.Errorf("response header Content-Type does not match, get %s vs expected application/json",
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
respStsParam := &stsservice.StsResponseParameters{}
if err := json.Unmarshal(body, respStsParam); err != nil {
t.Errorf("failed to unmarshal STS success response: %v", err)
if respStsParam.AccessToken == "" {
t.Errorf("failed to get access token from STS response parameters: %+v", respStsParam)
if respStsParam.IssuedTokenType != "urn:ietf:params:oauth:token-type:access_token" {
t.Errorf("unexpected issued token type from STS response parameters: %+v", respStsParam)
if respStsParam.TokenType != "Bearer" {
t.Errorf("unexpected token type from STS response parameters: %+v", respStsParam)
func genStsReq(t *testing.T) (req *http.Request) {
stsQuery := url.Values{}
stsQuery.Set("grant_type", stsServer.TokenExchangeGrantType)
stsQuery.Set("resource", "https//")
stsQuery.Set("audience", "audience")
stsQuery.Set("scope", "")
stsQuery.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token")
stsQuery.Set("subject_token", mock.FakeSubjectToken)
stsQuery.Set("subject_token_type", stsServer.SubjectTokenType)
stsQuery.Set("actor_token", "")
stsQuery.Set("actor_token_type", "")
stsURL := "http://" + stsServerAddress + stsServer.TokenPath
req, _ = http.NewRequest("POST", stsURL, strings.NewReader(stsQuery.Encode()))
req.Header.Set("Content-Type", stsServer.URLEncodedForm)
reqDump, _ := httputil.DumpRequest(req, true)
t.Logf("STS request:\n%s", string(reqDump))
return req
// sendHTTPRequestWithRetry sends req through client, making up to 10 attempts
// with a 100ms pause between consecutive attempts. It returns the first
// response obtained without a transport error; if every attempt fails, the
// error of the last attempt is returned.
//
// Before each retry the request body is rewound through req.GetBody when that
// is available, because a failed attempt may already have consumed or closed
// the body. Requests without GetBody are re-sent as they are.
func sendHTTPRequestWithRetry(client *http.Client, req *http.Request) (resp *http.Response, err error) {
	const (
		maxAttempts   = 10
		retryInterval = 100 * time.Millisecond
	)
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			// Pause only between attempts, never after the final one.
			time.Sleep(retryInterval)
			if req.GetBody != nil {
				body, bodyErr := req.GetBody()
				if bodyErr != nil {
					return nil, bodyErr
				}
				req.Body = body
			}
		}
		resp, err = client.Do(req)
		if err == nil {
			return resp, nil
		}
	}
	return resp, err
}
// testSetUp holds the options a test passes to setUpTestComponents.
type testSetUp struct {
	// enableCache is forwarded to the token exchange plugin constructor.
	enableCache bool
	// enableDynamicToken is set by TestStsCache.
	// NOTE(review): nothing visible in this file reads this field — confirm
	// where the mock authorization server is expected to pick it up.
	enableDynamicToken bool
}
// setUpTest sets up components for the STS flow, including a STS server, a
// token manager, and an authorization server.
func setUpTestComponents(t *testing.T, setup testSetUp) (*stsServer.Server, *mock.AuthorizationServer, []*http.Client) {
// Create mock authorization server
mockServer, err := mock.StartNewServer(t, mock.Config{Port: 0})
if err != nil {
t.Fatalf("failed to start a mock server: %v", err)
// Create token exchange Google plugin
tokenExchangePlugin, _ := google.CreateTokenManagerPlugin(nil, mock.FakeTrustDomain, mock.FakeProjectNum,
mock.FakeGKEClusterURL, setup.enableCache)
federatedTokenTestingEndpoint := mockServer.URL + "/v1/token"
accessTokenTestingEndpoint := mockServer.URL + "/v1/projects/-/serviceAccounts/"
tokenExchangePlugin.SetEndpoints(federatedTokenTestingEndpoint, accessTokenTestingEndpoint)
// Create token manager
tokenManager := &TokenManager{}
// Create STS server
server, _ := stsServer.NewServer(stsServer.Config{LocalHostAddr: "", LocalPort: 0}, tokenManager)
// Create test client
stsServerAddress = fmt.Sprintf("", server.Port)
clients := []*http.Client{}
for i := 0; i < numClient; i++ {
hTTPClient := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
t.Logf("set up server address to dial %s", addr)
addr = stsServerAddress
return net.Dial(network, addr)
clients = append(clients, hTTPClient)
return server, mockServer, clients
func tearDownTest(t *testing.T, stsServer *stsServer.Server, backend *mock.AuthorizationServer) {
if err := backend.Stop(); err != nil {
t.Logf("failed to stop mock server: %v", err)