blob: 56bc264a942f29c14de0c6a33edc3f3a945ff591 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package grpcproxy
import (
"context"
"errors"
"fmt"
"io"
stdHttp "net/http"
"strings"
"sync"
"time"
)
import (
"github.com/golang/protobuf/jsonpb" //nolint
"github.com/golang/protobuf/proto" //nolint
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
perrors "github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
import (
"github.com/apache/dubbo-go-pixiu/pixiu/pkg/common/constant"
"github.com/apache/dubbo-go-pixiu/pixiu/pkg/common/extension/filter"
ct "github.com/apache/dubbo-go-pixiu/pixiu/pkg/context"
"github.com/apache/dubbo-go-pixiu/pixiu/pkg/context/http"
"github.com/apache/dubbo-go-pixiu/pixiu/pkg/logger"
"github.com/apache/dubbo-go-pixiu/pixiu/pkg/server"
)
const (
// Kind is the kind of Fallback.
Kind = constant.HTTPGrpcProxyFilter
loggerHeader = "[grpc-proxy]"
// DescriptorSourceKey current ds
DescriptorSourceKey = "DescriptorSource"
// GrpcClientConnKey the grpc-client-conn by the coroutine local
GrpcClientConnKey = "GrpcClientConn"
)
func init() {
filter.RegisterHttpFilter(&Plugin{})
}
const (
NONE = "none"
AUTO = "auto"
LOCAL = "local"
REMOTE = "remote"
)
type (
DescriptorSourceStrategy string
// Plugin is grpc filter plugin.
Plugin struct {
}
// FilterFactory is grpc filter instance
FilterFactory struct {
cfg *Config
// grpc descriptor source factory
descriptor *Descriptor
// hold grpc.ClientConns, key format: cluster name + "." + endpoint
pools map[string]*sync.Pool
extReg *dynamic.ExtensionRegistry
registered map[string]bool
}
Filter struct {
cfg *Config
// grpc descriptor source factory
descriptor *Descriptor
// hold grpc.ClientConns, key format: cluster name + "." + endpoint
pools map[string]*sync.Pool
extReg *dynamic.ExtensionRegistry
registered map[string]bool
}
// Config describe the config of AccessFilter
Config struct {
DescriptorSourceStrategy DescriptorSourceStrategy `yaml:"descriptor_source_strategy" json:"descriptor_source_strategy" default:"auto"`
Path string `yaml:"path" json:"path"`
Rules []*Rule `yaml:"rules" json:"rules"` //nolint
Timeout time.Duration `yaml:"timeout" json:"timeout"` //nolint
}
Rule struct {
Selector string `yaml:"selector" json:"selector"`
Match Match `yaml:"match" json:"match"`
}
Match struct {
Method string `yaml:"method" json:"method"` //nolint
}
)
func (c DescriptorSourceStrategy) String() string {
return string(c)
}
func (c DescriptorSourceStrategy) Val() DescriptorSourceStrategy {
switch c {
case NONE:
return NONE
case AUTO:
return AUTO
case LOCAL:
return LOCAL
case REMOTE:
return REMOTE
}
return ""
}
func (p *Plugin) Kind() string {
return Kind
}
func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
return &FilterFactory{cfg: &Config{DescriptorSourceStrategy: AUTO}, descriptor: &Descriptor{}}, nil
}
func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain filter.FilterChain) error {
f := &Filter{cfg: factory.cfg, descriptor: factory.descriptor, pools: factory.pools, extReg: factory.extReg, registered: factory.registered}
chain.AppendDecodeFilters(f)
return nil
}
// getServiceAndMethod first return value is package.service, second one is method name
func getServiceAndMethod(path string) (string, string) {
pos := strings.LastIndex(path, "/")
if pos < 0 {
return "", ""
}
mth := path[pos+1:]
prefix := strings.TrimSuffix(path, "/"+mth)
pos = strings.LastIndex(prefix, "/")
if pos < 0 {
return "", ""
}
svc := prefix[pos+1:]
return svc, mth
}
// Decode use the default http to grpc transcoding strategy https://cloud.google.com/endpoints/docs/grpc/transcoding
func (f *Filter) Decode(c *http.HttpContext) filter.FilterStatus {
svc, mth := getServiceAndMethod(c.GetUrl())
var clientConn *grpc.ClientConn
var err error
re := c.GetRouteEntry()
logger.Debugf("%s client choose endpoint from cluster :%v", loggerHeader, re.Cluster)
e := server.GetClusterManager().PickEndpoint(re.Cluster, c)
if e == nil {
logger.Errorf("%s err {cluster not exists}", loggerHeader)
c.SendLocalReply(stdHttp.StatusServiceUnavailable, []byte("cluster not exists"))
return filter.Stop
}
// timeout for Dial and Invoke
ctx, cancel := context.WithTimeout(c.Ctx, c.Timeout)
defer cancel()
ep := e.Address.GetAddress()
p, ok := f.pools[strings.Join([]string{re.Cluster, ep}, ".")]
if !ok {
p = &sync.Pool{}
}
clientConn, ok = p.Get().(*grpc.ClientConn)
if !ok || clientConn == nil {
// TODO(Kenway): Support Credential and TLS
clientConn, err = grpc.DialContext(ctx, ep, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil || clientConn == nil {
logger.Errorf("%s err {failed to connect to grpc service provider}", loggerHeader)
c.SendLocalReply(stdHttp.StatusServiceUnavailable, []byte((fmt.Sprintf("%s", err))))
return filter.Stop
}
}
// get DescriptorSource, contain file and reflection
source, err := f.descriptor.getDescriptorSource(context.WithValue(ctx, ct.ContextKey(GrpcClientConnKey), clientConn), f.cfg)
if err != nil {
logger.Errorf("%s err %s : %s ", loggerHeader, "get desc source fail", err)
c.SendLocalReply(stdHttp.StatusInternalServerError, []byte("service not config proto file or the server not support reflection API"))
return filter.Stop
}
//put DescriptorSource concurrent, del if no need
ctx = context.WithValue(ctx, ct.ContextKey(DescriptorSourceKey), source)
dscp, err := source.FindSymbol(svc)
if err != nil {
logger.Errorf("%s err {%s}", loggerHeader, "request path invalid")
c.SendLocalReply(stdHttp.StatusBadRequest, []byte("method not allow"))
return filter.Stop
}
svcDesc, ok := dscp.(*desc.ServiceDescriptor)
if !ok {
logger.Errorf("%s err {service not expose, %s}", loggerHeader, svc)
c.SendLocalReply(stdHttp.StatusBadRequest, []byte(fmt.Sprintf("service not expose, %s", svc)))
return filter.Stop
}
mthDesc := svcDesc.FindMethodByName(mth)
err = f.registerExtension(source, mthDesc)
if err != nil {
logger.Errorf("%s err {%s}", loggerHeader, "register extension failed")
c.SendLocalReply(stdHttp.StatusInternalServerError, []byte(fmt.Sprintf("%s", err)))
return filter.Stop
}
msgFac := dynamic.NewMessageFactoryWithExtensionRegistry(f.extReg)
grpcReq := msgFac.NewMessage(mthDesc.GetInputType())
err = jsonToProtoMsg(c.Request.Body, grpcReq)
if err != nil && !errors.Is(err, io.EOF) {
logger.Errorf("%s err {failed to convert json to proto msg, %s}", loggerHeader, err.Error())
c.SendLocalReply(stdHttp.StatusInternalServerError, []byte(fmt.Sprintf("%s", err)))
return filter.Stop
}
stub := grpcdynamic.NewStubWithMessageFactory(clientConn, msgFac)
// metadata in grpc has the same feature in http
md := mapHeaderToMetadata(c.AllHeaders())
ctx = metadata.NewOutgoingContext(ctx, md)
md = metadata.MD{}
t := metadata.MD{}
resp, err := Invoke(ctx, stub, mthDesc, grpcReq, grpc.Header(&md), grpc.Trailer(&t))
// judge err is server side error or not
if st, ok := status.FromError(err); !ok || isServerError(st) {
if isServerTimeout(st) {
logger.Errorf("%s err {failed to invoke grpc service provider because timeout, err:%s}", loggerHeader, err.Error())
c.SendLocalReply(stdHttp.StatusGatewayTimeout, []byte(fmt.Sprintf("%s", err)))
return filter.Stop
}
logger.Errorf("%s err {failed to invoke grpc service provider, %s}", loggerHeader, err.Error())
c.SendLocalReply(stdHttp.StatusServiceUnavailable, []byte(fmt.Sprintf("%s", err)))
return filter.Stop
}
res, err := protoMsgToJson(resp)
if err != nil {
logger.Errorf("%s err {failed to convert proto msg to json, %s}", loggerHeader, err.Error())
c.SendLocalReply(stdHttp.StatusInternalServerError, []byte(fmt.Sprintf("%s", err)))
return filter.Stop
}
h := mapMetadataToHeader(md)
th := mapMetadataToHeader(t)
// let response filter handle resp
c.SourceResp = &stdHttp.Response{
StatusCode: stdHttp.StatusOK,
Header: h,
Body: io.NopCloser(strings.NewReader(res)),
Trailer: th,
Request: c.Request,
}
p.Put(clientConn)
return filter.Continue
}
func (f *Filter) registerExtension(source DescriptorSource, mthDesc *desc.MethodDescriptor) error {
err := RegisterExtension(source, f.extReg, mthDesc.GetInputType(), f.registered)
if err != nil {
return perrors.New("register extension failed")
}
err = RegisterExtension(source, f.extReg, mthDesc.GetOutputType(), f.registered)
if err != nil {
return perrors.New("register extension failed")
}
return nil
}
func RegisterExtension(source DescriptorSource, extReg *dynamic.ExtensionRegistry, msgDesc *desc.MessageDescriptor, registered map[string]bool) error {
msgType := msgDesc.GetFullyQualifiedName()
if _, ok := registered[msgType]; ok {
return nil
}
if len(msgDesc.GetExtensionRanges()) > 0 {
fds, err := source.AllExtensionsForType(msgType)
if err != nil {
return fmt.Errorf("failed to find msg type {%s} in file source", msgType)
}
err = extReg.AddExtension(fds...)
if err != nil {
return fmt.Errorf("failed to register extensions of msgType {%s}, err is {%s}", msgType, err.Error())
}
}
for _, fd := range msgDesc.GetFields() {
if fd.GetMessageType() != nil {
err := RegisterExtension(source, extReg, fd.GetMessageType(), registered)
if err != nil {
return err
}
}
}
return nil
}
func mapHeaderToMetadata(header stdHttp.Header) metadata.MD {
md := metadata.MD{}
for key, val := range header {
md.Append(key, val...)
}
return md
}
func mapMetadataToHeader(md metadata.MD) stdHttp.Header {
h := stdHttp.Header{}
for key, val := range md {
for _, v := range val {
h.Add(key, v)
}
}
return h
}
func jsonToProtoMsg(reader io.Reader, msg proto.Message) error {
body, err := io.ReadAll(reader)
if err != nil {
return err
}
return jsonpb.UnmarshalString(string(body), msg)
}
func protoMsgToJson(msg proto.Message) (string, error) {
m := jsonpb.Marshaler{}
return m.MarshalToString(msg)
}
func isServerError(st *status.Status) bool {
return st.Code() == codes.DeadlineExceeded || st.Code() == codes.ResourceExhausted || st.Code() == codes.Internal ||
st.Code() == codes.Unavailable
}
func isServerTimeout(st *status.Status) bool {
return st.Code() == codes.DeadlineExceeded || st.Code() == codes.Canceled
}
func (factory *FilterFactory) Config() interface{} {
return factory.cfg
}
func (factory *FilterFactory) Apply() error {
err := configCheck(factory.cfg)
if err != nil {
return err
}
factory.descriptor.initDescriptorSource(factory.cfg)
return nil
}
func configCheck(cfg *Config) error {
if len(cfg.DescriptorSourceStrategy.Val()) == 0 {
return perrors.Errorf("grpc descriptor source config `descriptor_source_strategy` is `%s`, maybe set it `%s`", cfg.DescriptorSourceStrategy.String(), AUTO)
}
return nil
}