| // Copyright (c) 2017 VMware, Inc. All Rights Reserved. |
| // |
| // This product is licensed to you under the Apache License, Version 2.0 (the "License"). |
| // You may not use this product except in compliance with the License. |
| // |
| // This product may include a number of subcomponents with separate copyright notices and |
| // license terms. Your use of these subcomponents is subject to the terms and conditions |
| // of the subcomponent's license, as noted in the LICENSE file. |
| |
| // +build windows |
| |
| package lightwave |
| |
| import ( |
| "encoding/base64" |
| "fmt" |
| "github.com/vmware/photon-controller-go-sdk/SSPI" |
| "math/rand" |
| "net" |
| "net/url" |
| "strings" |
| "time" |
| ) |
| |
// gssTicketGrantFormatString is the request-body template for the GSS ticket grant type.
// Its three %s verbs are filled, in order, with the gss_ticket, context_id and scope values.
| const gssTicketGrantFormatString = "grant_type=urn:vmware:grant_type:gss_ticket&gss_ticket=%s&context_id=%s&scope=%s" |
| |
| // GetTokensFromWindowsLogInContext gets tokens based on Windows logged in context |
| // Here is how it works: |
| // 1. Get the SPN (Service Principal Name) in the format host/FQDN of lightwave. This is needed for SSPI/Kerberos protocol |
| // 2. Call Windows API AcquireCredentialsHandle() using SSPI library. This will give the current users credential handle |
| // 3. Using this handle call Windows API AcquireCredentialsHandle(). This will give you byte[] |
| // 4. Encode this byte[] and send it to OIDC server over HTTP (using POST) |
| // 5. OIDC server can send either of the following |
| // - Access tokens. In this case return access tokens to client |
| // - Error in the format: invalid_grant: gss_continue_needed:'context id':'token from server' |
| // 6. In case you get error, parse it and get the token from server |
| // 7. Feed this token to step 3 and repeat steps till you get the access tokens from server |
| func (client *OIDCClient) GetTokensFromWindowsLogInContext() (tokens *OIDCTokenResponse, err error) { |
| spn, err := client.buildSPN() |
| if err != nil { |
| return nil, err |
| } |
| |
| auth, _ := SSPI.GetAuth("", "", spn, "") |
| |
| userContext, err := auth.InitialBytes() |
| if err != nil { |
| return nil, err |
| } |
| |
| // In case of multiple req/res between client and server (as explained in above comment), |
| // server needs to maintain the mapping of context id -> token |
| // So we need to generate random string as a context id |
| // If we use same context id for all the requests, results can be erroneous |
| contextId := client.generateRandomString() |
| body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope) |
| tokens, err = client.getToken(body) |
| |
| for { |
| if err == nil { |
| break |
| } |
| |
| // In case of error the response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server' |
| gssToken := client.validateAndExtractGSSResponse(err, contextId) |
| if gssToken == "" { |
| return nil, err |
| } |
| |
| data, err := base64.StdEncoding.DecodeString(gssToken) |
| if err != nil { |
| return nil, err |
| } |
| |
| userContext, err := auth.NextBytes(data) |
| body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope) |
| tokens, err = client.getToken(body) |
| } |
| |
| return tokens, err |
| } |
| |
| // Gets the SPN (Service Principal Name) in the format host/FQDN of lightwave |
| func (client *OIDCClient) buildSPN() (spn string, err error) { |
| u, err := url.Parse(client.Endpoint) |
| if err != nil { |
| return "", err |
| } |
| |
| host, _, err := net.SplitHostPort(u.Host) |
| if err != nil { |
| return "", err |
| } |
| |
| addr, err := net.LookupAddr(host) |
| if err != nil { |
| return "", err |
| } |
| |
| var s = strings.TrimSuffix(addr[0], ".") |
| return "host/" + s, nil |
| } |
| |
| // validateAndExtractGSSResponse parse the error from server and returns token from server |
| // In case of error from the server, response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server' |
| // So, we check for the above format in error and then return the token from server |
| // If error is not in above format, we return empty string |
| func (client *OIDCClient) validateAndExtractGSSResponse(err error, contextId string) string { |
| parts := strings.Split(err.Error(), ":") |
| if !(len(parts) == 4 && strings.TrimSpace(parts[1]) == "gss_continue_needed" && parts[2] == contextId) { |
| return "" |
| } else { |
| return parts[3] |
| } |
| } |
| |
| func (client *OIDCClient) generateRandomString() string { |
| const length = 10 |
| const asciiA = 65 |
| const asciiZ = 90 |
| rand.Seed(time.Now().UTC().UnixNano()) |
| bytes := make([]byte, length) |
| for i := 0; i < length; i++ { |
| bytes[i] = byte(randInt(asciiA, asciiZ)) |
| } |
| return string(bytes) |
| } |
| |
// randInt returns a pseudo-random integer drawn from the shared math/rand generator in the
// half-open range [min, max): min is a possible result, max is not.
// It panics when max <= min, because rand.Intn panics for a non-positive argument.
func randInt(min int, max int) int {
	span := max - min
	offset := rand.Intn(span)
	return offset + min
}