blob: b08f87878d4ce1a711a8049ed68ae4d54d40b697 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package oauth2
import (
"bytes"
"context"
"encoding/json"
"io/ioutil"
"net/http"
"strings"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// MockTransport is a scripted HTTPAuthTransport for tests. Each call to Do
// pops and returns the next entry of Responses; once that queue is empty it
// returns ReturnError instead.
type MockTransport struct {
	// Responses are handed out in order, one per Do call.
	Responses []*http.Response
	// ReturnError is returned once Responses is exhausted. When it is nil,
	// Do returns errMockTransportExhausted rather than a (nil, nil) pair.
	ReturnError error
}

// Compile-time check that *MockTransport satisfies HTTPAuthTransport.
var _ HTTPAuthTransport = (*MockTransport)(nil)

// mockTransportError is a string-backed error type, which lets the mock
// declare a constant sentinel error without pulling in extra imports.
type mockTransportError string

func (e mockTransportError) Error() string { return string(e) }

// errMockTransportExhausted is returned by Do when no scripted response is
// left and no ReturnError was configured.
const errMockTransportExhausted = mockTransportError("mock transport: no queued responses and no ReturnError configured")

// Do ignores req and replays the next scripted response.
//
// It always returns either a non-nil response or a non-nil error. A caller
// that checks only err therefore never receives a nil *http.Response, which
// would otherwise show up as a nil-pointer panic in the code under test
// rather than as a readable failure.
func (t *MockTransport) Do(req *http.Request) (*http.Response, error) {
	if len(t.Responses) == 0 {
		if t.ReturnError != nil {
			return nil, t.ReturnError
		}
		return nil, errMockTransportExhausted
	}
	r := t.Responses[0]
	t.Responses = t.Responses[1:]
	return r, nil
}
// The specs below exercise TokenRetriever's request builders, its token
// response handling, and the device-code polling loop in ExchangeDeviceCode.
// HTTP traffic in the ExchangeDeviceCode specs is served by MockTransport.
var _ = Describe("CodetokenExchanger", func() {
	Describe("newExchangeCodeRequest", func() {
		It("creates the request", func() {
			tokenRetriever := TokenRetriever{}
			exchangeRequest := AuthorizationCodeExchangeRequest{
				TokenEndpoint: "https://issuer/oauth/token",
				ClientID:      "clientID",
				CodeVerifier:  "Verifier",
				Code:          "code",
				RedirectURI:   "https://redirect",
			}

			result, err := tokenRetriever.newExchangeCodeRequest(exchangeRequest)

			// Check err before touching result: on failure result is nil and
			// ParseForm would panic instead of reporting a clean failure.
			Expect(err).To(BeNil())
			Expect(result.ParseForm()).To(BeNil())
			Expect(result.FormValue("grant_type")).To(Equal("authorization_code"))
			Expect(result.FormValue("client_id")).To(Equal("clientID"))
			Expect(result.FormValue("code_verifier")).To(Equal("Verifier"))
			Expect(result.FormValue("code")).To(Equal("code"))
			Expect(result.FormValue("redirect_uri")).To(Equal("https://redirect"))
			Expect(result.URL.String()).To(Equal("https://issuer/oauth/token"))
			Expect(result.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded"))
			Expect(result.Header.Get("Content-Length")).To(Equal("117"))
		})

		It("returns an error when NewRequest returns an error", func() {
			tokenRetriever := TokenRetriever{}

			result, err := tokenRetriever.newExchangeCodeRequest(AuthorizationCodeExchangeRequest{
				TokenEndpoint: "://issuer/oauth/token",
			})

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			// Match only the stable part of the url.Error text: Go 1.14+
			// quotes the URL in the message, older releases do not.
			Expect(err.Error()).To(ContainSubstring("missing protocol scheme"))
		})
	})

	Describe("handleAuthTokensResponse", func() {
		It("handles the response", func() {
			tokenRetriever := TokenRetriever{}
			response := buildResponse(200, AuthorizationTokenResponse{
				ExpiresIn:    1,
				AccessToken:  "myAccessToken",
				RefreshToken: "myRefreshToken",
			})

			result, err := tokenRetriever.handleAuthTokensResponse(response)

			Expect(err).To(BeNil())
			Expect(result).To(Equal(&TokenResult{
				ExpiresIn:    1,
				AccessToken:  "myAccessToken",
				RefreshToken: "myRefreshToken",
			}))
		})

		It("returns error when status code is not successful", func() {
			tokenRetriever := TokenRetriever{}
			response := buildResponse(500, nil)

			result, err := tokenRetriever.handleAuthTokensResponse(response)

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(Equal("a non-success status code was received: 500"))
		})

		It("returns typed error when response body contains error information", func() {
			errorBody := TokenErrorResponse{Error: "test", ErrorDescription: "test description"}
			tokenRetriever := TokenRetriever{}
			response := buildResponse(400, errorBody)

			result, err := tokenRetriever.handleAuthTokensResponse(response)

			Expect(result).To(BeNil())
			Expect(err).To(Equal(&TokenError{ErrorCode: "test", ErrorDescription: "test description"}))
			Expect(err.Error()).To(Equal("test description (test)"))
		})

		It("returns error when deserialization fails", func() {
			tokenRetriever := TokenRetriever{}
			// A JSON string cannot be decoded into the response struct.
			response := buildResponse(200, "")

			result, err := tokenRetriever.handleAuthTokensResponse(response)

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(Equal(
				"json: cannot unmarshal string into Go value of type oauth2.AuthorizationTokenResponse"))
		})
	})

	Describe("newRefreshTokenRequest", func() {
		It("creates the request", func() {
			tokenRetriever := TokenRetriever{}
			exchangeRequest := RefreshTokenExchangeRequest{
				TokenEndpoint: "https://issuer/oauth/token",
				ClientID:      "clientID",
				RefreshToken:  "refreshToken",
			}

			result, err := tokenRetriever.newRefreshTokenRequest(exchangeRequest)

			Expect(err).To(BeNil())
			Expect(result.ParseForm()).To(BeNil())
			Expect(result.FormValue("grant_type")).To(Equal("refresh_token"))
			Expect(result.FormValue("client_id")).To(Equal("clientID"))
			Expect(result.FormValue("refresh_token")).To(Equal("refreshToken"))
			Expect(result.URL.String()).To(Equal("https://issuer/oauth/token"))
			Expect(result.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded"))
			Expect(result.Header.Get("Content-Length")).To(Equal("70"))
		})

		It("returns an error when NewRequest returns an error", func() {
			tokenRetriever := TokenRetriever{}

			result, err := tokenRetriever.newRefreshTokenRequest(RefreshTokenExchangeRequest{
				TokenEndpoint: "://issuer/oauth/token",
			})

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(ContainSubstring("missing protocol scheme"))
		})
	})

	Describe("newClientCredentialsRequest", func() {
		It("creates the request", func() {
			tokenRetriever := TokenRetriever{}
			exchangeRequest := ClientCredentialsExchangeRequest{
				TokenEndpoint: "https://issuer/oauth/token",
				ClientID:      "clientID",
				ClientSecret:  "clientSecret",
				Audience:      "audience",
			}

			result, err := tokenRetriever.newClientCredentialsRequest(exchangeRequest)

			Expect(err).To(BeNil())
			Expect(result.ParseForm()).To(BeNil())
			Expect(result.FormValue("grant_type")).To(Equal("client_credentials"))
			Expect(result.FormValue("client_id")).To(Equal("clientID"))
			Expect(result.FormValue("client_secret")).To(Equal("clientSecret"))
			Expect(result.FormValue("audience")).To(Equal("audience"))
			Expect(result.URL.String()).To(Equal("https://issuer/oauth/token"))
			Expect(result.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded"))
			Expect(result.Header.Get("Content-Length")).To(Equal("93"))
		})

		It("returns an error when NewRequest returns an error", func() {
			tokenRetriever := TokenRetriever{}

			result, err := tokenRetriever.newClientCredentialsRequest(ClientCredentialsExchangeRequest{
				TokenEndpoint: "://issuer/oauth/token",
			})

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(ContainSubstring("missing protocol scheme"))
		})
	})

	Describe("newDeviceCodeExchangeRequest", func() {
		It("creates the request", func() {
			tokenRetriever := TokenRetriever{}
			exchangeRequest := DeviceCodeExchangeRequest{
				TokenEndpoint: "https://issuer/oauth/token",
				ClientID:      "clientID",
				DeviceCode:    "deviceCode",
				PollInterval:  5 * time.Second,
			}

			result, err := tokenRetriever.newDeviceCodeExchangeRequest(exchangeRequest)

			Expect(err).To(BeNil())
			Expect(result.ParseForm()).To(BeNil())
			Expect(result.FormValue("grant_type")).To(Equal("urn:ietf:params:oauth:grant-type:device_code"))
			Expect(result.FormValue("client_id")).To(Equal("clientID"))
			Expect(result.FormValue("device_code")).To(Equal("deviceCode"))
			Expect(result.URL.String()).To(Equal("https://issuer/oauth/token"))
			Expect(result.Header.Get("Content-Type")).To(Equal("application/x-www-form-urlencoded"))
			Expect(result.Header.Get("Content-Length")).To(Equal("107"))
		})

		It("returns an error when NewRequest returns an error", func() {
			tokenRetriever := TokenRetriever{}

			// Exercise the device-code builder itself (previously this spec
			// called newClientCredentialsRequest by mistake).
			result, err := tokenRetriever.newDeviceCodeExchangeRequest(DeviceCodeExchangeRequest{
				TokenEndpoint: "://issuer/oauth/token",
			})

			Expect(result).To(BeNil())
			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(ContainSubstring("missing protocol scheme"))
		})
	})

	Describe("ExchangeDeviceCode", func() {
		var (
			mockTransport   *MockTransport
			tokenRetriever  *TokenRetriever
			exchangeRequest DeviceCodeExchangeRequest
			tokenResult     TokenResult
		)

		BeforeEach(func() {
			mockTransport = &MockTransport{}
			tokenRetriever = &TokenRetriever{
				transport: mockTransport,
			}
			exchangeRequest = DeviceCodeExchangeRequest{
				TokenEndpoint: "https://issuer/oauth/token",
				ClientID:      "clientID",
				DeviceCode:    "deviceCode",
				PollInterval:  1 * time.Second,
			}
			tokenResult = TokenResult{
				ExpiresIn:    1,
				AccessToken:  "myAccessToken",
				RefreshToken: "myRefreshToken",
			}
		})

		It("returns a token", func() {
			mockTransport.Responses = []*http.Response{
				buildResponse(200, &tokenResult),
			}

			token, err := tokenRetriever.ExchangeDeviceCode(context.Background(), exchangeRequest)

			Expect(err).To(BeNil())
			Expect(token).To(Equal(&tokenResult))
		})

		It("supports cancellation", func() {
			mockTransport.Responses = []*http.Response{
				buildResponse(400, &TokenErrorResponse{Error: "authorization_pending"}),
			}
			// Cancel up front so the polling loop must observe ctx.Done().
			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			_, err := tokenRetriever.ExchangeDeviceCode(ctx, exchangeRequest)

			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(Equal("cancelled"))
		})

		It("implements authorization_pending and slow_down", func() {
			startTime := time.Now()
			mockTransport.Responses = []*http.Response{
				buildResponse(400, &TokenErrorResponse{Error: "authorization_pending"}),
				buildResponse(400, &TokenErrorResponse{Error: "authorization_pending"}),
				buildResponse(400, &TokenErrorResponse{Error: "slow_down"}),
				buildResponse(200, &tokenResult),
			}

			token, err := tokenRetriever.ExchangeDeviceCode(context.Background(), exchangeRequest)

			Expect(err).To(BeNil())
			Expect(token).To(Equal(&tokenResult))
			// Three retryable responses preceded the success; the spec expects
			// each of them to delay the next poll by at least PollInterval.
			Expect(time.Since(startTime)).To(BeNumerically(">", exchangeRequest.PollInterval*3))
		})

		It("implements expired_token", func() {
			mockTransport.Responses = []*http.Response{
				buildResponse(400, &TokenErrorResponse{Error: "expired_token"}),
			}

			_, err := tokenRetriever.ExchangeDeviceCode(context.Background(), exchangeRequest)

			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(Equal("the device code has expired"))
		})

		It("implements access_denied", func() {
			mockTransport.Responses = []*http.Response{
				buildResponse(400, &TokenErrorResponse{Error: "access_denied"}),
			}

			_, err := tokenRetriever.ExchangeDeviceCode(context.Background(), exchangeRequest)

			Expect(err).ToNot(BeNil())
			Expect(err.Error()).To(Equal("the device was not authorized"))
		})
	})
})
// buildResponse fabricates an *http.Response whose body is the JSON encoding
// of body and whose status is statusCode. A Content-Type of application/json
// is attached only when the encoding is a JSON object (it starts with '{').
// It panics if body cannot be marshalled.
func buildResponse(statusCode int, body interface{}) *http.Response {
	payload, marshalErr := json.Marshal(body)
	if marshalErr != nil {
		panic(marshalErr)
	}

	header := http.Header{}
	if strings.HasPrefix(string(payload), "{") {
		header.Add("Content-Type", "application/json")
	}

	return &http.Response{
		StatusCode: statusCode,
		Header:     header,
		Body:       ioutil.NopCloser(bytes.NewReader(payload)),
	}
}