blob: 861e8ac5c41e6500cb2bb22693bd747b9ae71917 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package flight
import (
"context"
"encoding/base64"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// grpcAuthHeader is the incoming-metadata key from which the
// ServerAuthHandler interceptors read the token passed to IsValid.
const grpcAuthHeader = "auth-token-bin"

// Names used by the basic-auth / bearer-token interceptors.
const (
	// basicAuthHeader is the metadata key the credentials are read from and
	// the trailer key under which the issued bearer token is sent back.
	basicAuthHeader = "authorization"
	// basicAuthPrefix is the scheme word expected on the Handshake call.
	basicAuthPrefix = "Basic"
	// bearerTokenPrefix is the scheme word expected on every other call.
	bearerTokenPrefix = "Bearer"
)
// AuthConn is the view of the gRPC handshake stream that is handed to an
// authentication handler. It exchanges raw payload bytes and errors only, so
// implementers never touch the handshake request/response protos defined in
// flight.proto.
type AuthConn interface {
	// Read returns the payload of the next handshake message received from
	// the peer, or the error encountered while receiving it.
	Read() (payload []byte, err error)
	// Send delivers payload to the peer as a handshake message.
	Send(payload []byte) error
}
// serverAuthConn is the server-side AuthConn. It adapts the generated
// handshake stream so that only payload bytes cross the AuthConn boundary.
type serverAuthConn struct {
	stream FlightService_HandshakeServer
}

// Read receives the next handshake request from the stream and returns its
// payload. A receive error is returned unchanged with a nil payload.
func (a *serverAuthConn) Read() ([]byte, error) {
	req, err := a.stream.Recv()
	if err == nil {
		return req.Payload, nil
	}
	return nil, err
}

// Send wraps b in a HandshakeResponse and writes it to the stream.
func (a *serverAuthConn) Send(b []byte) error {
	resp := &HandshakeResponse{Payload: b}
	return a.stream.Send(resp)
}
// ServerAuthHandler is the hook through which a server supplies its own
// handshake and token validation.
//
// Authenticate performs the handshake exchange over the supplied AuthConn.
//
// IsValid is called by the auth interceptors with the first value of the
// "auth-token-bin" request metadata key, or with the empty string when the
// key is absent. A non-nil error rejects the call; otherwise the returned
// value is stored in the request context and can be fetched with
// AuthFromContext.
type ServerAuthHandler interface {
	Authenticate(conn AuthConn) error
	IsValid(token string) (identity interface{}, err error)
}
// authCtxKey is the unexported, zero-size context key under which the auth
// interceptors store the identity returned by IsValid; AuthFromContext reads
// it back.
type authCtxKey struct{}
// authWrappedStream decorates a grpc.ServerStream so that its Context method
// reports a replacement context. The stream interceptors use it to hand
// handlers a context that carries the authenticated identity.
type authWrappedStream struct {
	// ServerStream is the wrapped stream; every method except Context is
	// promoted from it unchanged.
	grpc.ServerStream
	// ctx is the context reported by Context.
	ctx context.Context
}

// Context returns the stored context instead of the embedded stream's own.
func (a *authWrappedStream) Context() context.Context {
	return a.ctx
}
// AuthFromContext returns the value that `IsValid` produced for the request
// that ctx belongs to, letting handlers look up who is making the call. It
// returns nil when ctx carries no such value.
func AuthFromContext(ctx context.Context) interface{} {
	identity := ctx.Value(authCtxKey{})
	return identity
}
// createServerAuthUnaryInterceptor builds the unary interceptor that applies
// auth to unary RPCs.
//
// When auth is nil the interceptor simply forwards to the handler. Otherwise
// it passes the first "auth-token-bin" metadata value (the empty string when
// there is none) to auth.IsValid. A validation error is reported to the
// client as codes.PermissionDenied; note that the stream counterpart reports
// codes.Unauthenticated for the same failure. On success the handler runs
// with a context from which AuthFromContext returns the validated identity.
func createServerAuthUnaryInterceptor(auth ServerAuthHandler) grpc.UnaryServerInterceptor {
	if auth == nil {
		passthrough := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			return handler(ctx, req)
		}
		return passthrough
	}

	return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		token := ""
		if md, found := metadata.FromIncomingContext(ctx); found {
			if tokens := md.Get(grpcAuthHeader); len(tokens) != 0 {
				token = tokens[0]
			}
		}

		identity, err := auth.IsValid(token)
		if err != nil {
			return nil, status.Errorf(codes.PermissionDenied, "auth-error: %s", err)
		}

		authedCtx := context.WithValue(ctx, authCtxKey{}, identity)
		return handler(authedCtx, req)
	}
}
// createServerAuthStreamInterceptor builds the stream interceptor that
// applies auth to streaming RPCs.
//
// When auth is nil the interceptor simply forwards to the handler. Methods
// whose full name ends in "/Handshake" are also forwarded untouched, since
// that is where the token is obtained. For every other stream the first
// "auth-token-bin" metadata value (the empty string when there is none) is
// passed to auth.IsValid. A validation error is reported to the client as
// codes.Unauthenticated. On success the handler receives a wrapped stream
// whose context yields the validated identity through AuthFromContext.
func createServerAuthStreamInterceptor(auth ServerAuthHandler) grpc.StreamServerInterceptor {
	if auth == nil {
		passthrough := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			return handler(srv, stream)
		}
		return passthrough
	}

	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		isHandshake := strings.HasSuffix(info.FullMethod, "/Handshake")
		if isHandshake {
			return handler(srv, stream)
		}

		token := ""
		if md, found := metadata.FromIncomingContext(stream.Context()); found {
			if tokens := md.Get(grpcAuthHeader); len(tokens) != 0 {
				token = tokens[0]
			}
		}

		identity, err := auth.IsValid(token)
		if err != nil {
			return status.Errorf(codes.Unauthenticated, "auth-error: %s", err)
		}

		wrapped := &authWrappedStream{
			ServerStream: stream,
			ctx:          context.WithValue(stream.Context(), authCtxKey{}, identity),
		}
		return handler(srv, wrapped)
	}
}
// handshake runs the configured auth handler's Authenticate over the
// handshake stream, exposing the stream to it as an AuthConn. When no auth
// handler is configured it returns nil without touching the stream.
func (s *server) handshake(stream FlightService_HandshakeServer) error {
	if s.authHandler != nil {
		conn := &serverAuthConn{stream: stream}
		return s.authHandler.Authenticate(conn)
	}
	return nil
}
// BasicAuthValidator is supplied to CreateServerBearerTokenAuthInterceptors
// to back basic-auth login followed by bearer-token authorization.
//
// Validate is called during the Handshake stream with the username and
// password taken from the "Basic" authorization header. The token it returns
// is sent back to the client in the "authorization" trailer, prefixed with
// "Bearer ". An error is returned to the client as-is.
//
// IsValid is called on other calls with the bearer token taken from the
// authorization header. The value it returns is stored in the request
// context and can be fetched with AuthFromContext. An error is returned to
// the client as-is.
type BasicAuthValidator interface {
	Validate(username string, password string) (token string, err error)
	IsValid(bearerToken string) (identity interface{}, err error)
}
// createServerBearerTokenUnaryInterceptor builds the unary interceptor for
// bearer-token auth.
//
// It looks at the first "authorization" metadata value. When that value has
// the form "Bearer <token>", <token> is passed to validator.IsValid; in every
// other case (no metadata, no header, a different scheme, or a bare "Bearer"
// with nothing after it) IsValid is called with the empty string so the
// validator decides whether unauthenticated access is acceptable. An error
// from IsValid is returned to the client unchanged. On success the handler
// runs with a context from which AuthFromContext returns the identity.
func createServerBearerTokenUnaryInterceptor(validator BasicAuthValidator) grpc.UnaryServerInterceptor {
	// Match the scheme together with its separating space. Matching only the
	// scheme word and then slicing one byte past it panics on a header that
	// is exactly "Bearer", and would accept any byte as the separator.
	const bearerScheme = bearerTokenPrefix + " "

	return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		var auth string
		if md, ok := metadata.FromIncomingContext(ctx); ok {
			vals := md.Get(basicAuthHeader)
			if len(vals) > 0 && strings.HasPrefix(vals[0], bearerScheme) {
				auth = strings.TrimPrefix(vals[0], bearerScheme)
			}
		}

		identity, err := validator.IsValid(auth)
		if err != nil {
			return nil, err
		}
		return handler(context.WithValue(ctx, authCtxKey{}, identity), req)
	}
}
// createServerBearerTokenStreamInterceptor builds the stream interceptor for
// basic-auth login and bearer-token authorization.
//
// The first "authorization" metadata value is split on spaces into a scheme
// and its credentials. A call without that header is rejected with
// codes.Unauthenticated.
//
// On methods whose full name ends in "/Handshake" the scheme must be "Basic".
// The credentials are base64-decoded (with or without "=" padding), split at
// the first ':' into username and password, and passed to
// validator.Validate. The returned token is sent to the client in the
// "authorization" trailer as "Bearer <token>" before the handler runs.
//
// On every other method the scheme must be "Bearer". The token is passed to
// validator.IsValid and the handler receives a wrapped stream whose context
// yields the resulting identity through AuthFromContext.
//
// Malformed headers are rejected with codes.Unauthenticated; errors from the
// validator are returned to the client unchanged.
func createServerBearerTokenStreamInterceptor(validator BasicAuthValidator) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		var auth []string
		if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
			auth = md.Get(basicAuthHeader)
			if len(auth) > 0 {
				auth = strings.Split(auth[0], " ")
			}
		}
		if len(auth) == 0 {
			return status.Error(codes.Unauthenticated, "must authenticate first")
		}

		if strings.HasSuffix(info.FullMethod, "/Handshake") {
			if auth[0] != basicAuthPrefix {
				return status.Errorf(codes.Unauthenticated, "only Basic Auth implemented")
			}
			// The header is client-controlled: a bare "Basic" has no second
			// element, so indexing auth[1] unchecked would panic.
			if len(auth) < 2 {
				return status.Error(codes.Unauthenticated, "missing basic auth credentials")
			}

			val, err := base64.RawStdEncoding.DecodeString(auth[1])
			if err != nil {
				// Unpadded base64 is tried first to keep existing clients
				// working; standard padded base64 is accepted as well. When
				// both fail, the original decoding error is reported.
				padded, perr := base64.StdEncoding.DecodeString(auth[1])
				if perr != nil {
					return status.Errorf(codes.Unauthenticated, "invalid basic auth encoding: %s", err)
				}
				val = padded
			}

			creds := strings.SplitN(string(val), ":", 2)
			// Without a ':' there is no password element to index.
			if len(creds) != 2 {
				return status.Error(codes.Unauthenticated, "invalid basic auth credentials: expected username:password")
			}

			token, err := validator.Validate(creds[0], creds[1])
			if err != nil {
				return err
			}

			stream.SetTrailer(metadata.New(map[string]string{basicAuthHeader: strings.Join([]string{bearerTokenPrefix, token}, " ")}))
			return handler(srv, stream)
		}

		if auth[0] != bearerTokenPrefix {
			return status.Errorf(codes.Unauthenticated, "Only bearer token auth implemented")
		}
		// A bare "Bearer" carries no token; reject it instead of panicking.
		if len(auth) < 2 {
			return status.Error(codes.Unauthenticated, "missing bearer token")
		}

		identity, err := validator.IsValid(auth[1])
		if err != nil {
			return err
		}

		return handler(srv, &authWrappedStream{ServerStream: stream, ctx: context.WithValue(stream.Context(), authCtxKey{}, identity)})
	}
}
// CreateServerBearerTokenAuthInterceptors returns the unary and stream server
// interceptors that implement basic-auth login and bearer-token authorization
// on top of validator.
//
// It panics if validator is nil.
func CreateServerBearerTokenAuthInterceptors(validator BasicAuthValidator) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
	if validator == nil {
		panic("validator cannot be nil")
	}

	unary := createServerBearerTokenUnaryInterceptor(validator)
	stream := createServerBearerTokenStreamInterceptor(validator)
	return unary, stream
}