blob: 9b6d8b7bd884d7e4c87384ce41a52f883b56a191 [file] [log] [blame]
// Copyright (c) 2017 VMware, Inc. All Rights Reserved.
//
// This product is licensed to you under the Apache License, Version 2.0 (the "License").
// You may not use this product except in compliance with the License.
//
// This product may include a number of subcomponents with separate copyright notices and
// license terms. Your use of these subcomponents is subject to the terms and conditions
// of the subcomponent's license, as noted in the LICENSE file.
// +build windows
package lightwave
import (
"encoding/base64"
"fmt"
"github.com/vmware/photon-controller-go-sdk/SSPI"
"math/rand"
"net"
"net/url"
"strings"
"time"
)
const gssTicketGrantFormatString = "grant_type=urn:vmware:grant_type:gss_ticket&gss_ticket=%s&context_id=%s&scope=%s"
// GetTokensFromWindowsLogInContext gets tokens based on Windows logged in context
// Here is how it works:
// 1. Get the SPN (Service Principal Name) in the format host/FQDN of lightwave. This is needed for SSPI/Kerberos protocol
// 2. Call Windows API AcquireCredentialsHandle() using SSPI library. This will give the current users credential handle
// 3. Using this handle call Windows API AcquireCredentialsHandle(). This will give you byte[]
// 4. Encode this byte[] and send it to OIDC server over HTTP (using POST)
// 5. OIDC server can send either of the following
// - Access tokens. In this case return access tokens to client
// - Error in the format: invalid_grant: gss_continue_needed:'context id':'token from server'
// 6. In case you get error, parse it and get the token from server
// 7. Feed this token to step 3 and repeat steps till you get the access tokens from server
func (client *OIDCClient) GetTokensFromWindowsLogInContext() (tokens *OIDCTokenResponse, err error) {
spn, err := client.buildSPN()
if err != nil {
return nil, err
}
auth, _ := SSPI.GetAuth("", "", spn, "")
userContext, err := auth.InitialBytes()
if err != nil {
return nil, err
}
// In case of multiple req/res between client and server (as explained in above comment),
// server needs to maintain the mapping of context id -> token
// So we need to generate random string as a context id
// If we use same context id for all the requests, results can be erroneous
contextId := client.generateRandomString()
body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
tokens, err = client.getToken(body)
for {
if err == nil {
break
}
// In case of error the response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
gssToken := client.validateAndExtractGSSResponse(err, contextId)
if gssToken == "" {
return nil, err
}
data, err := base64.StdEncoding.DecodeString(gssToken)
if err != nil {
return nil, err
}
userContext, err := auth.NextBytes(data)
body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
tokens, err = client.getToken(body)
}
return tokens, err
}
// Gets the SPN (Service Principal Name) in the format host/FQDN of lightwave
func (client *OIDCClient) buildSPN() (spn string, err error) {
u, err := url.Parse(client.Endpoint)
if err != nil {
return "", err
}
host, _, err := net.SplitHostPort(u.Host)
if err != nil {
return "", err
}
addr, err := net.LookupAddr(host)
if err != nil {
return "", err
}
var s = strings.TrimSuffix(addr[0], ".")
return "host/" + s, nil
}
// validateAndExtractGSSResponse parse the error from server and returns token from server
// In case of error from the server, response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
// So, we check for the above format in error and then return the token from server
// If error is not in above format, we return empty string
func (client *OIDCClient) validateAndExtractGSSResponse(err error, contextId string) string {
parts := strings.Split(err.Error(), ":")
if !(len(parts) == 4 && strings.TrimSpace(parts[1]) == "gss_continue_needed" && parts[2] == contextId) {
return ""
} else {
return parts[3]
}
}
func (client *OIDCClient) generateRandomString() string {
const length = 10
const asciiA = 65
const asciiZ = 90
rand.Seed(time.Now().UTC().UnixNano())
bytes := make([]byte, length)
for i := 0; i < length; i++ {
bytes[i] = byte(randInt(asciiA, asciiZ))
}
return string(bytes)
}
func randInt(min int, max int) int {
return min + rand.Intn(max-min)
}