blob: bc0b4445fe56895db15fc6ec9a93eecef0a41ee0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package oauth2
import (
"context"
"errors"
"github.com/go-chassis/go-chassis/core/handler"
"github.com/go-chassis/go-chassis/core/invocation"
"github.com/go-mesh/openlogging"
"net/http"
"time"
)
// errors
var (
ErrInvalidState = errors.New("invalid state")
ErrInvalidCode = errors.New("invalid code")
ErrInvalidToken = errors.New("invalid authorization")
ErrInvalidAuth = errors.New("invalid authentication")
ErrExpiredToken = errors.New("expired token")
)
const (
	// AuthName is the name this handler is registered under and the value
	// reported by Handler.Name.
	AuthName = "oauth2"
	// Random is the only non-empty OAuth2 "state" value the handler accepts.
	Random = "random"
)
// Handler is a go-chassis handler that checks OAuth2 authorization-code
// requests before they continue down the handler chain. It carries no state
// of its own.
type Handler struct{}
// Handle checks an OAuth2 authorization-code request and, if it passes (or
// is not subject to checking), forwards the invocation to the next handler.
//
// Checking applies only when the package-level auth settings are non-nil,
// their GrantType is "authorization_code", and inv.Args is an *http.Request.
// In that case:
//   - a non-empty "state" form value must equal Random;
//   - the "code" form value must be non-empty;
//   - the code is turned into an access token by getToken;
//   - if auth.Authenticate is set, it is called with that token and request.
//
// Each failure is reported to cb with status 401 and the chain is NOT
// continued. Otherwise the invocation proceeds via chain.Next, with the
// chain's response relayed to cb.
func (oa *Handler) Handle(chain *handler.Chain, inv *invocation.Invocation, cb invocation.ResponseCallBack) {
	codeGrant := auth != nil && auth.GrantType == "authorization_code"
	if codeGrant {
		if req, isHTTP := inv.Args.(*http.Request); isHTTP {
			// An empty state is tolerated; any other value must match Random.
			// NOTE(review): a fixed, publicly known state value provides no
			// CSRF protection; consider a per-request nonce.
			if st := req.FormValue("state"); st != "" && st != Random {
				WriteBackErr(ErrInvalidState, http.StatusUnauthorized, cb)
				return
			}

			authCode := req.FormValue("code")
			if len(authCode) == 0 {
				WriteBackErr(ErrInvalidCode, http.StatusUnauthorized, cb)
				return
			}

			token, tokenErr := getToken(authCode, cb)
			if tokenErr != nil {
				openlogging.Error("authorization error: " + tokenErr.Error())
				WriteBackErr(ErrInvalidToken, http.StatusUnauthorized, cb)
				return
			}

			// The authentication hook is optional.
			if verify := auth.Authenticate; verify != nil {
				if authErr := verify(token, req); authErr != nil {
					openlogging.Error("authentication error: " + authErr.Error())
					WriteBackErr(ErrInvalidAuth, http.StatusUnauthorized, cb)
					return
				}
			}
		}
	}

	chain.Next(inv, func(resp *invocation.Response) {
		cb(resp)
	})
}
// getToken exchanges an OAuth2 authorization code for an access token using
// the package-level auth.UseConfig.
//
// It returns the access token on success. It returns an error when the code
// exchange fails, or ErrExpiredToken when the provider reports an expiry
// that has already passed. When auth.UseConfig is nil no exchange is
// attempted and ("", nil) is returned.
//
// getToken never writes a response itself. Reporting failures to the client
// is left to the caller, so the response callback fires exactly once per
// request. The callback parameter is unused and is kept only so existing call
// sites keep compiling.
func getToken(code string, _ invocation.ResponseCallBack) (string, error) {
	config := auth.UseConfig
	if config == nil {
		return "", nil
	}

	token, err := config.Exchange(context.Background(), code)
	if err != nil {
		openlogging.Error("get token failed, errors: " + err.Error())
		return "", err
	}

	// A zero Expiry means the provider did not report a lifetime, so the token
	// is accepted. Only a reported expiry that has already passed is rejected.
	// (The previous code overwrote Expiry with now+30min and then compared it
	// against now, a check that could never fail.)
	if !token.Expiry.IsZero() && time.Now().After(token.Expiry) {
		return "", ErrExpiredToken
	}
	return token.AccessToken, nil
}
// Name reports the name this handler is registered under, i.e. AuthName.
func (oa *Handler) Name() string { return AuthName }
// NewOAuth2 is the constructor passed to handler.RegisterHandler. Each call
// returns a fresh *Handler.
func NewOAuth2() handler.Handler {
	h := new(Handler)
	return h
}
// init registers the OAuth2 handler constructor under AuthName. A
// registration failure is logged and otherwise ignored.
func init() {
	if err := handler.RegisterHandler(AuthName, NewOAuth2); err != nil {
		openlogging.Error("register handler error: " + err.Error())
	}
}
// WriteBackErr reports err to cb as an invocation.Response whose Status is
// set to status.
func WriteBackErr(err error, status int, cb invocation.ResponseCallBack) {
	cb(&invocation.Response{Status: status, Err: err})
}