blob: 84647dba11d6bbce0daf2bbae45be76bb0d2d76a [file] [log] [blame]
// Copyright (c) 2016 VMware, Inc. All Rights Reserved.
//
// This product is licensed to you under the Apache License, Version 2.0 (the "License").
// You may not use this product except in compliance with the License.
//
// This product may include a number of subcomponents with separate copyright notices and
// license terms. Your use of these subcomponents is subject to the terms and conditions
// of the subcomponent's license, as noted in the LICENSE file.
package lightwave
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
)
// tokenScope is the default scope requested by GetTokenByPasswordGrant. It is
// used whenever OIDCClientOptions.TokenScope is left empty (see buildOptions).
const tokenScope string = "openid offline_access"
// OIDCClient is an HTTP client for an OIDC token server rooted at Endpoint.
// Construct it with NewOIDCClient, which fills in the unexported fields.
type OIDCClient struct {
	// httpClient performs the requests; NewOIDCClient builds its transport
	// from the TLS settings in Options.
	httpClient *http.Client
	// logger is the client's logger. NewOIDCClient substitutes one that
	// discards output when nil is passed.
	logger *log.Logger
	// Endpoint is the server base URL. Request paths are appended to it
	// directly; NewOIDCClient strips any trailing "/".
	Endpoint string
	// Options holds the effective options, with defaults already applied by
	// NewOIDCClient.
	Options *OIDCClientOptions
}
// OIDCClientOptions configures an OIDCClient. Passing nil to NewOIDCClient
// selects all defaults.
type OIDCClientOptions struct {
	// Whether or not to ignore any TLS errors when talking to the OIDC server,
	// false by default. This becomes tls.Config.InsecureSkipVerify on the
	// client's transport.
	IgnoreCertificate bool
	// List of root CA's to use for server validation, nil by default. When
	// nil, crypto/tls falls back to the host's system root set.
	RootCAs *x509.CertPool
	// The scope values to use when requesting tokens with the password grant.
	// An empty string selects the default "openid offline_access".
	TokenScope string
}
// NewOIDCClient returns a client for the OIDC server at endpoint.
//
// Trailing slashes are trimmed from endpoint so that request paths can be
// appended directly. options may be nil, in which case defaults are used (see
// buildOptions). The caller's options value is never stored; a merged copy is.
// A nil logger is replaced by one that discards all output. The TLS behaviour
// of every request is taken from the effective options.
func NewOIDCClient(endpoint string, options *OIDCClientOptions, logger *log.Logger) (c *OIDCClient) {
	opts := buildOptions(options)

	if logger == nil {
		logger = log.New(ioutil.Discard, "", log.LstdFlags)
	}

	tlsConfig := &tls.Config{
		InsecureSkipVerify: opts.IgnoreCertificate,
		RootCAs:            opts.RootCAs,
	}
	transport := &http.Transport{TLSClientConfig: tlsConfig}

	return &OIDCClient{
		httpClient: &http.Client{Transport: transport},
		logger:     logger,
		Endpoint:   strings.TrimRight(endpoint, "/"),
		Options:    opts,
	}
}
// buildOptions returns a newly allocated options value with defaults filled in.
//
// TokenScope defaults to tokenScope. Every other field defaults to its zero
// value. Settings from options, which may be nil, override those defaults; an
// empty TokenScope counts as "not set". The argument itself is never modified
// or returned.
func buildOptions(options *OIDCClientOptions) (result *OIDCClientOptions) {
	merged := &OIDCClientOptions{TokenScope: tokenScope}
	if options != nil {
		merged.IgnoreCertificate = options.IgnoreCertificate
		// A nil pool leaves the default (nil) in place, so copy unconditionally.
		merged.RootCAs = options.RootCAs
		if options.TokenScope != "" {
			merged.TokenScope = options.TokenScope
		}
	}
	return merged
}
// buildUrl returns the absolute URL for path on this client's server.
//
// It is a plain concatenation. NewOIDCClient strips trailing slashes from
// Endpoint, and the paths used in this file begin with "/".
func (client *OIDCClient) buildUrl(path string) string {
	return client.Endpoint + path
}
// Cert download helper

// certDownloadPath is the path, relative to OIDCClient.Endpoint, from which
// GetRootCerts downloads the auth server's certificates.
const certDownloadPath string = "/afd/vecs/ssl"

// lightWaveCert is one element of the JSON array served at certDownloadPath.
// Value is read from the "encoded" field and holds a PEM-encoded certificate.
type lightWaveCert struct {
	Value string `json:"encoded"`
}
// GetRootCerts downloads the certificate list published at certDownloadPath and
// returns the parsed certificates in the order the server listed them.
//
// The download is made with TLS verification disabled. NOTE(review): this is
// presumably because these certificates are what a caller needs in order to
// verify the server afterwards (e.g. via OIDCClientOptions.RootCAs); confirm.
//
// The insecure transport lives only in a throwaway http.Client. The client's
// shared http.Client is never touched: swapping its Transport, even briefly,
// would let concurrent requests through this OIDCClient skip certificate
// verification, and it would race on the field.
//
// An error is returned in these cases:
//   - the request fails;
//   - the status is not 200;
//   - the body is not a JSON array of {"encoded": "<PEM>"} objects;
//   - any entry fails to PEM-decode or to parse as an X.509 certificate.
//
// Only the first PEM block of each entry is used.
func (client *OIDCClient) GetRootCerts() (certList []*x509.Certificate, err error) {
	tr := &http.Transport{
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
	}
	// The transport is discarded after this call. Close its pooled
	// connections rather than leaving them open until they time out.
	defer tr.CloseIdleConnections()
	insecureClient := &http.Client{Transport: tr}

	// get the certs
	resp, err := insecureClient.Get(client.buildUrl(certDownloadPath))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Unexpected error retrieving auth server certs: %v %s", resp.StatusCode, resp.Status)
	}

	// parse the certs
	certsData := &[]lightWaveCert{}
	if err = json.NewDecoder(resp.Body).Decode(certsData); err != nil {
		return nil, err
	}
	certList = make([]*x509.Certificate, len(*certsData))
	for idx, cert := range *certsData {
		block, _ := pem.Decode([]byte(cert.Value))
		if block == nil {
			return nil, fmt.Errorf("Unexpected response format: %v", certsData)
		}
		decodedCert, parseErr := x509.ParseCertificate(block.Bytes)
		if parseErr != nil {
			return nil, parseErr
		}
		certList[idx] = decodedCert
	}
	return certList, nil
}
// setTransport installs tr as the RoundTripper of the client's shared
// http.Client. The change applies to every later request made through client.
func (client *OIDCClient) setTransport(tr http.RoundTripper) {
	shared := client.httpClient
	shared.Transport = tr
}
// Token request helpers

// tokenPath is the token endpoint path, relative to OIDCClient.Endpoint.
const tokenPath string = "/openidconnect/token"

// passwordGrantFormatString is the form body for the password grant. Its three
// %s verbs take the username, password and scope, in that order. fmt inserts
// them verbatim, so each value must be form-encoded by the caller.
const passwordGrantFormatString = "grant_type=password&username=%s&password=%s&scope=%s"

// refreshTokenGrantFormatString is the form body for the refresh_token grant.
// Its %s verb takes the refresh token, which is likewise inserted verbatim.
const refreshTokenGrantFormatString = "grant_type=refresh_token&refresh_token=%s"

// OIDCTokenResponse is the JSON body of a successful token request, as decoded
// by getToken.
type OIDCTokenResponse struct {
	// AccessToken is the "access_token" field.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the "expires_in" field, the access token lifetime. It is in
	// seconds per RFC 6749; this is assumed and not checked here.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the "refresh_token" field and is empty when the server
	// omits it. It can be passed to GetTokenByRefreshTokenGrant.
	RefreshToken string `json:"refresh_token,omitempty"`
	// IdToken is the "id_token" field.
	IdToken string `json:"id_token"`
	// TokenType is the "token_type" field.
	TokenType string `json:"token_type"`
}
// GetTokenByPasswordGrant requests tokens from the token endpoint using the
// OAuth 2.0 resource-owner password grant. It asks for the scope configured in
// client.Options.TokenScope.
//
// Every value placed in the application/x-www-form-urlencoded body is escaped
// with url.QueryEscape. This includes the scope: the default
// "openid offline_access" contains a space, which must be sent as "+" rather
// than as a raw space.
func (client *OIDCClient) GetTokenByPasswordGrant(username string, password string) (tokens *OIDCTokenResponse, err error) {
	body := fmt.Sprintf(passwordGrantFormatString,
		url.QueryEscape(username),
		url.QueryEscape(password),
		url.QueryEscape(client.Options.TokenScope))
	return client.getToken(body)
}
// GetTokenByRefreshTokenGrant exchanges refreshToken for a new set of tokens
// using the OAuth 2.0 refresh_token grant.
//
// The token is escaped with url.QueryEscape before it goes into the form body,
// as GetTokenByPasswordGrant does for its values. Tokens made only of letters,
// digits, '-', '_' and '.' pass through unchanged. Characters such as '+',
// '/', '=' or '&' would otherwise corrupt the body.
func (client *OIDCClient) GetTokenByRefreshTokenGrant(refreshToken string) (tokens *OIDCTokenResponse, err error) {
	body := fmt.Sprintf(refreshTokenGrantFormatString, url.QueryEscape(refreshToken))
	return client.getToken(body)
}
// getToken POSTs body, an already form-encoded string, to the token endpoint.
//
// Transport failures and non-2xx responses are returned as errors, the latter
// via checkResponse. Otherwise the JSON response is decoded into a fresh
// OIDCTokenResponse. The response body is always closed.
func (client *OIDCClient) getToken(body string) (tokens *OIDCTokenResponse, err error) {
	req, reqErr := http.NewRequest("POST", client.buildUrl(tokenPath), strings.NewReader(body))
	if reqErr != nil {
		return nil, reqErr
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

	resp, doErr := client.httpClient.Do(req)
	if doErr != nil {
		return nil, doErr
	}
	defer resp.Body.Close()

	if checkErr := client.checkResponse(resp); checkErr != nil {
		return nil, checkErr
	}

	decoded := new(OIDCTokenResponse)
	if decodeErr := json.NewDecoder(resp.Body).Decode(decoded); decodeErr != nil {
		return nil, decodeErr
	}
	return decoded, nil
}
// OIDCError is the JSON error body of a failed request. checkResponse decodes
// it from non-2xx responses.
type OIDCError struct {
	// Code holds the "error" field of the JSON body.
	Code string `json:"error"`
	// Message holds the "error_description" field. It is empty when the
	// field is absent.
	Message string `json:"error_description"`
}

// Error implements the error interface, formatting the value as
// "<Code>: <Message>".
func (e OIDCError) Error() string {
	return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// checkResponse converts a non-2xx HTTP response into an error.
//
// A 2xx response yields nil and leaves the body unread for the caller. For any
// other status the body is read in full, and the result is:
//   - an OIDCError, when the body decodes as JSON with a non-empty "error"
//     field;
//   - otherwise a generic error carrying the HTTP status and the raw body,
//     plus the read or decode error if there was one.
func (client *OIDCClient) checkResponse(response *http.Response) (err error) {
	if response.StatusCode/100 == 2 {
		return nil
	}

	respBody, readErr := ioutil.ReadAll(response.Body)
	if readErr != nil {
		return fmt.Errorf(
			"Status: %v, Body: %v [%v]", response.Status, string(respBody), readErr)
	}

	var oidcErr OIDCError
	if jsonErr := json.Unmarshal(respBody, &oidcErr); jsonErr != nil {
		return fmt.Errorf(
			"Status: %v, Body: %v [%v]", response.Status, string(respBody), jsonErr)
	}
	if oidcErr.Code == "" {
		// Valid JSON but not an OAuth error object (e.g. `{}` or `null`).
		// Report the raw response instead of an empty ": " OIDCError.
		return fmt.Errorf("Status: %v, Body: %v", response.Status, string(respBody))
	}
	return oidcErr
}