| // Copyright Istio Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package tokenmanager |
| |
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net" |
| "net/http" |
| "net/http/httputil" |
| "net/url" |
| "strings" |
| "testing" |
| "time" |
| ) |
| |
| import ( |
| "github.com/apache/dubbo-go-pixiu/security/pkg/stsservice" |
| stsServer "github.com/apache/dubbo-go-pixiu/security/pkg/stsservice/server" |
| "github.com/apache/dubbo-go-pixiu/security/pkg/stsservice/tokenmanager/google" |
| "github.com/apache/dubbo-go-pixiu/security/pkg/stsservice/tokenmanager/google/mock" |
| ) |
| |
// numClient is how many independent HTTP clients each test creates and
// drives against the STS server.
const numClient = 10

// stsServerAddress is the "127.0.0.1:<port>" address of the STS server
// started by the most recent setUpTestComponents call; the request builders
// use it to form request URLs.
var stsServerAddress string
| |
| // TestStsService sets up a STS server and token manager which has enabled Google |
| // token exchange plugin, a mock authorization server for token service, and |
| // verifies STS flows. |
| func TestStsFlow(t *testing.T) { |
| stsServer, mockBackend, clients := setUpTestComponents(t, testSetUp{}) |
| defer tearDownTest(t, stsServer, mockBackend) |
| |
| federatedTokenReceivedTime := time.Time{} |
| accessTokenReceivedTime := time.Time{} |
| for i := 0; i < numClient; i++ { |
| resp, err := sendHTTPRequestWithRetry(clients[i], genStsReq(t)) |
| if err != nil { |
| t.Fatalf("client %d: failure in sending STS request: %v", i, err) |
| } |
| verifyStsResponse(t, resp) |
| |
| resp, err = sendHTTPRequestWithRetry(clients[i], genDumpReq(t)) |
| if err != nil { |
| t.Fatalf("client %d: failure in sending STS request: %v", i, err) |
| } |
| federatedTokenReceivedTime, accessTokenReceivedTime = verifyDumpResponse(t, resp, federatedTokenReceivedTime, accessTokenReceivedTime) |
| } |
| } |
| |
| // TestStsCache enables caching at token exchange plugin, which will return cached token if that token |
| // is not going to expire soon. |
| func TestStsCache(t *testing.T) { |
| stsServer, mockBackend, clients := setUpTestComponents(t, testSetUp{enableCache: true, enableDynamicToken: true}) |
| defer tearDownTest(t, stsServer, mockBackend) |
| |
| accessToken := "" |
| for i := 0; i < numClient; i++ { |
| resp, err := sendHTTPRequestWithRetry(clients[i], genStsReq(t)) |
| if err != nil { |
| t.Fatalf("client %d: failure in sending STS request: %v", i, err) |
| } |
| if resp.StatusCode != http.StatusOK { |
| t.Errorf("response HTTP status code does not match, get %d vs expected %d", |
| resp.StatusCode, http.StatusOK) |
| } |
| body, _ := io.ReadAll(resp.Body) |
| respStsParam := &stsservice.StsResponseParameters{} |
| json.Unmarshal(body, respStsParam) |
| if i == 0 { |
| accessToken = respStsParam.AccessToken |
| } |
| if i > 0 { |
| if accessToken != respStsParam.AccessToken { |
| t.Errorf("cached token is not in use") |
| } |
| } |
| resp.Body.Close() |
| } |
| if mockBackend.NumGetFederatedTokenCalls() != 1 { |
| t.Errorf("Number of get federated token API calls does not match, expected 1 but got %d", mockBackend.NumGetFederatedTokenCalls()) |
| } |
| if mockBackend.NumGetAccessTokenCalls() != 1 { |
| t.Errorf("Number of get access token API calls does not match, expected 1 but got %d", mockBackend.NumGetAccessTokenCalls()) |
| } |
| } |
| |
| func genDumpReq(t *testing.T) (req *http.Request) { |
| dumpURL := "http://" + stsServerAddress + stsServer.StsStatusPath |
| req, _ = http.NewRequest("GET", dumpURL, nil) |
| |
| reqDump, _ := httputil.DumpRequest(req, true) |
| t.Logf("status dump request:\n%s", string(reqDump)) |
| return req |
| } |
| |
| // verifyDumpResponse parses token info from dump response, and verifies that |
| // issue time of federated token and access token have updated by comparing them |
| // with oldFTime and oldATime, and returns new issue time of federated token and access token. |
| func verifyDumpResponse(t *testing.T, resp *http.Response, oldFTime, oldATime time.Time) (newFTime, newATime time.Time) { |
| if resp.StatusCode != http.StatusOK { |
| t.Errorf("response HTTP status code does not match, get %d vs expected %d", |
| resp.StatusCode, http.StatusOK) |
| } |
| |
| defer resp.Body.Close() |
| body, _ := io.ReadAll(resp.Body) |
| tokenDump := &stsservice.TokensDump{} |
| if err := json.Unmarshal(body, tokenDump); err != nil { |
| t.Errorf("failed to unmarshal status dump response: %v", err) |
| } |
| for _, info := range tokenDump.Tokens { |
| if info.TokenType == "access token" { |
| newFTime = info.IssueTime |
| if newFTime == oldFTime { |
| t.Errorf("federated token issue time does not change: %s", newFTime) |
| } |
| } else { |
| newATime = info.IssueTime |
| if newATime == oldATime { |
| t.Errorf("access token issue time does not change: %s", newATime) |
| } |
| } |
| } |
| return newFTime, newATime |
| } |
| |
| // verifyStsResponse verifies that received STS response has valid parameter values. |
| func verifyStsResponse(t *testing.T, resp *http.Response) { |
| if resp.StatusCode != http.StatusOK { |
| t.Errorf("response HTTP status code does not match, get %d vs expected %d", |
| resp.StatusCode, http.StatusOK) |
| } |
| ctVal := resp.Header.Get("Content-Type") |
| if ctVal != "application/json" { |
| t.Errorf("response header Content-Type does not match, get %s vs expected application/json", |
| ctVal) |
| } |
| defer resp.Body.Close() |
| body, _ := io.ReadAll(resp.Body) |
| respStsParam := &stsservice.StsResponseParameters{} |
| if err := json.Unmarshal(body, respStsParam); err != nil { |
| t.Errorf("failed to unmarshal STS success response: %v", err) |
| } |
| if respStsParam.AccessToken == "" { |
| t.Errorf("failed to get access token from STS response parameters: %+v", respStsParam) |
| } |
| if respStsParam.IssuedTokenType != "urn:ietf:params:oauth:token-type:access_token" { |
| t.Errorf("unexpected issued token type from STS response parameters: %+v", respStsParam) |
| } |
| if respStsParam.TokenType != "Bearer" { |
| t.Errorf("unexpected token type from STS response parameters: %+v", respStsParam) |
| } |
| } |
| |
| func genStsReq(t *testing.T) (req *http.Request) { |
| stsQuery := url.Values{} |
| stsQuery.Set("grant_type", stsServer.TokenExchangeGrantType) |
| stsQuery.Set("resource", "https//:backend.example.com") |
| stsQuery.Set("audience", "audience") |
| stsQuery.Set("scope", "https://www.googleapis.com/auth/cloud-platform") |
| stsQuery.Set("requested_token_type", "urn:ietf:params:oauth:token-type:access_token") |
| stsQuery.Set("subject_token", mock.FakeSubjectToken) |
| stsQuery.Set("subject_token_type", stsServer.SubjectTokenType) |
| stsQuery.Set("actor_token", "") |
| stsQuery.Set("actor_token_type", "") |
| stsURL := "http://" + stsServerAddress + stsServer.TokenPath |
| req, _ = http.NewRequest("POST", stsURL, strings.NewReader(stsQuery.Encode())) |
| req.Header.Set("Content-Type", stsServer.URLEncodedForm) |
| reqDump, _ := httputil.DumpRequest(req, true) |
| t.Logf("STS request:\n%s", string(reqDump)) |
| return req |
| } |
| |
// sendHTTPRequestWithRetry sends req with client, making up to 10 attempts
// with a 100ms pause between them while client.Do returns an error (for
// example, while the server is not yet accepting connections). It returns the
// first successful response; HTTP error status codes are not retried. If every
// attempt fails, the last attempt's result and error are returned.
//
// A failed attempt may already have consumed the request body, so before each
// retry the body is re-created via req.GetBody when the request provides it
// (http.NewRequest does for strings/bytes readers).
func sendHTTPRequestWithRetry(client *http.Client, req *http.Request) (resp *http.Response, err error) {
	const (
		maxAttempts  = 10
		retryBackoff = 100 * time.Millisecond
	)
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			// Pause only between attempts, never after the final failure.
			time.Sleep(retryBackoff)
			if req.GetBody != nil {
				body, bodyErr := req.GetBody()
				if bodyErr != nil {
					return nil, fmt.Errorf("rewinding request body for retry: %w", bodyErr)
				}
				req.Body = body
			}
		}
		resp, err = client.Do(req)
		if err == nil {
			return resp, nil
		}
	}
	return resp, err
}
| |
// testSetUp selects optional behavior for setUpTestComponents. The zero value
// leaves both options off.
type testSetUp struct {
	// enableCache is forwarded to google.CreateTokenManagerPlugin to turn on
	// token caching in the token exchange plugin.
	enableCache bool
	// enableDynamicToken is forwarded to the mock authorization server via
	// EnableDynamicAccessToken.
	enableDynamicToken bool
}
| |
| // setUpTest sets up components for the STS flow, including a STS server, a |
| // token manager, and an authorization server. |
| func setUpTestComponents(t *testing.T, setup testSetUp) (*stsServer.Server, *mock.AuthorizationServer, []*http.Client) { |
| // Create mock authorization server |
| mockServer, err := mock.StartNewServer(t, mock.Config{Port: 0}) |
| mockServer.EnableDynamicAccessToken(setup.enableDynamicToken) |
| if err != nil { |
| t.Fatalf("failed to start a mock server: %v", err) |
| } |
| // Create token exchange Google plugin |
| tokenExchangePlugin, _ := google.CreateTokenManagerPlugin(nil, mock.FakeTrustDomain, mock.FakeProjectNum, |
| mock.FakeGKEClusterURL, setup.enableCache) |
| federatedTokenTestingEndpoint := mockServer.URL + "/v1/token" |
| accessTokenTestingEndpoint := mockServer.URL + "/v1/projects/-/serviceAccounts/service-%s@gcp-sa-meshdataplane.iam.gserviceaccount.com:generateAccessToken" |
| tokenExchangePlugin.SetEndpoints(federatedTokenTestingEndpoint, accessTokenTestingEndpoint) |
| // Create token manager |
| tokenManager := &TokenManager{} |
| tokenManager.SetPlugin(tokenExchangePlugin) |
| // Create STS server |
| server, _ := stsServer.NewServer(stsServer.Config{LocalHostAddr: "127.0.0.1", LocalPort: 0}, tokenManager) |
| // Create test client |
| stsServerAddress = fmt.Sprintf("127.0.0.1:%d", server.Port) |
| clients := []*http.Client{} |
| for i := 0; i < numClient; i++ { |
| hTTPClient := &http.Client{ |
| Transport: &http.Transport{ |
| DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
| t.Logf("set up server address to dial %s", addr) |
| addr = stsServerAddress |
| return net.Dial(network, addr) |
| }, |
| }, |
| } |
| clients = append(clients, hTTPClient) |
| } |
| return server, mockServer, clients |
| } |
| |
| func tearDownTest(t *testing.T, stsServer *stsServer.Server, backend *mock.AuthorizationServer) { |
| if err := backend.Stop(); err != nil { |
| t.Logf("failed to stop mock server: %v", err) |
| } |
| stsServer.Stop() |
| } |