blob: 5bb7305f6a992f3a1ea46e62b59e9ffa848d8c88 [file] [log] [blame]
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package old_triple
import (
"fmt"
"strconv"
"strings"
)
import (
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
)
// RequireUnimplemented selects whether generated server interfaces require
// embedding Unimplemented<Service>Server. When true, the interface gets a
// mustEmbedUnimplemented<Service>Server method and the generated comments say
// "must"; when false they say "should". The generator dereferences it without
// a nil check, so it must be assigned before GenerateFile runs (presumably from
// a plugin flag in main — confirm with the caller).
var RequireUnimplemented *bool
// Version is the protoc-gen-go-triple version written into the header of every
// generated file. The remaining constants are the Go import paths referenced by
// the generated code.
const (
Version = "1.0.8"
contextPackage = protogen.GoImportPath("context")
// grpcPackage provides ServiceDesc, MethodDesc/StreamDesc, the stream types,
// CallOption, the unary interceptor types and CtxSetterStream.
grpcPackage = protogen.GoImportPath("github.com/dubbogo/grpc-go")
codesPackage = protogen.GoImportPath("github.com/dubbogo/grpc-go/codes")
statusPackage = protogen.GoImportPath("github.com/dubbogo/grpc-go/status")
metadataPackage = protogen.GoImportPath("github.com/dubbogo/grpc-go/metadata")
// dubbo3Package provides Dubbo3GrpcService, asserted on by server handlers.
dubbo3Package = protogen.GoImportPath("dubbo.apache.org/dubbo-go/v3/protocol/dubbo3")
// constantPackage (triple) and dubboConstantPackage (dubbo-go) share the base
// name "constant"; they are always referenced through these paths.
constantPackage = protogen.GoImportPath("github.com/dubbogo/triple/pkg/common/constant")
dubboConstantPackage = protogen.GoImportPath("dubbo.apache.org/dubbo-go/v3/common/constant")
commonPackage = protogen.GoImportPath("github.com/dubbogo/triple/pkg/common")
triplePackage = protogen.GoImportPath("github.com/dubbogo/triple/pkg/triple")
protocolPackage = protogen.GoImportPath("dubbo.apache.org/dubbo-go/v3/protocol")
invocationPackage = protogen.GoImportPath("dubbo.apache.org/dubbo-go/v3/protocol/invocation")
fmtPackage = protogen.GoImportPath("fmt")
)
// GenerateFile emits <prefix>_triple.pb.go, which holds the Triple client and
// server code for every service declared in file. When file declares no
// services it creates nothing and returns nil.
//
// The header records the plugin Version, the protoc version and the proto
// source path (flagged as deprecated when the file option says so).
func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
	if len(file.Services) == 0 {
		return nil
	}
	out := gen.NewGeneratedFile(file.GeneratedFilenamePrefix+"_triple.pb.go", file.GoImportPath)

	// Header: tool and compiler versions, then the proto source path.
	out.P("// Code generated by protoc-gen-go-triple. DO NOT EDIT.")
	out.P("// versions:")
	out.P("// - protoc-gen-go-triple v", Version)
	out.P("// - protoc ", protocVersion(gen))
	protoPath := file.Desc.Path()
	if deprecated := file.Proto.GetOptions().GetDeprecated(); deprecated {
		out.P("// ", protoPath, " is a deprecated file.")
	} else {
		out.P("// source: ", protoPath)
	}
	out.P()

	out.P("package ", file.GoPackageName)
	out.P()
	generateFileContent(gen, file, out)
	return out
}
// protocVersion formats the compiler version carried by the plugin request as
// "vMAJOR.MINOR.PATCH", appending "-SUFFIX" when protoc reports a non-empty
// suffix. It returns "(unknown)" when the request has no compiler version.
func protocVersion(gen *protogen.Plugin) string {
	ver := gen.Request.GetCompilerVersion()
	if ver == nil {
		return "(unknown)"
	}
	base := fmt.Sprintf("v%d.%d.%d", ver.GetMajor(), ver.GetMinor(), ver.GetPatch())
	if suffix := ver.GetSuffix(); suffix != "" {
		return base + "-" + suffix
	}
	return base
}
// generateFileContent writes everything after the package clause: a
// compile-time assertion on the grpc-go support version, followed by the
// Triple code for each service in file, in declaration order. It writes
// nothing when file has no services.
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
	services := file.Services
	if len(services) == 0 {
		return
	}
	g.P("// This is a compile-time assertion to ensure that this generated file")
	g.P("// is compatible with the grpc package it is being compiled against.")
	// When bumping the required SupportPackageIsVersion, update Version too.
	g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion7"))
	g.P()
	for i := range services {
		generateTripleService(gen, file, g, services[i])
	}
}
// generateTripleService writes every declaration generated for one service:
//
//   - <Service>Client (interface), its unexported TripleConn-backed
//     implementation and the New<Service>Client constructor;
//   - <Service>ClientImpl, a struct with one function-typed field per method,
//     plus its GetDubboStub and XXX_InterfaceName methods;
//   - the client method bodies and client stream helpers;
//   - <Service>Server, Unimplemented<Service>Server (with the proxy-invoker
//     accessors, XXX_ServiceDesc and XXX_InterfaceName) and Unsafe<Service>Server;
//   - Register<Service>Server, the per-method handlers and <Service>_ServiceDesc.
//
// RequireUnimplemented must be non-nil when this runs; it is dereferenced
// without a check.
func generateTripleService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
	serviceDeprecated := service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated()
	quotedFullName := strconv.Quote(string(service.Desc.FullName()))

	clientName := service.GoName + "Client"
	g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
	g.P("//")
	g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.")
	// Client interface.
	if serviceDeprecated {
		g.P("//")
		g.P(deprecationComment)
	}
	g.Annotate(clientName, service.Location)
	g.P("type ", clientName, " interface {")
	for _, method := range service.Methods {
		g.Annotate(clientName+"."+method.GoName, method.Location)
		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
			g.P(deprecationComment)
		}
		g.P(method.Comments.Leading,
			generateTripleClientSignature(g, method))
	}
	g.P("}")
	g.P()

	// Client structure: the triple connection is its only state.
	g.P("type ", unexport(clientName), " struct {")
	g.P("cc *", triplePackage.Ident("TripleConn"))
	g.P("}")
	g.P()

	// Dubbo client stub: one function-typed field per method.
	dubboSrvName := clientName + "Impl"
	g.P("type ", dubboSrvName, " struct {")
	for _, method := range service.Methods {
		// These are fields of the Impl struct, so annotate them under that type's
		// name rather than re-annotating the client interface's methods.
		g.Annotate(dubboSrvName+"."+method.GoName, method.Location)
		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
			g.P(deprecationComment)
		}
		g.P(generateClientImplSignature(g, method))
	}
	g.P("}")
	g.P()
	g.P("func (c *", dubboSrvName, ") GetDubboStub(cc *", triplePackage.Ident("TripleConn"), ") ", clientName, "{")
	g.P("return New", clientName, "(cc)")
	g.P("}")
	g.P()
	g.P("func (c *", dubboSrvName, ") XXX_InterfaceName() string{")
	g.P("return ", quotedFullName)
	g.P("}")
	g.P()

	// NewClient factory.
	if serviceDeprecated {
		g.P(deprecationComment)
	}
	g.P("func New", clientName, " (cc *", triplePackage.Ident("TripleConn"), ") ", clientName, " {")
	g.P("return &", unexport(clientName), "{cc}")
	g.P("}")
	g.P()

	// Client method implementations. Unary and streaming methods keep separate
	// counters, which are passed through as the index argument.
	var methodIndex, streamIndex int
	for _, method := range service.Methods {
		if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
			generateTripleClientMethod(gen, file, g, method, methodIndex)
			methodIndex++
		} else {
			generateTripleClientMethod(gen, file, g, method, streamIndex)
			streamIndex++
		}
	}

	requireUnimplemented := *RequireUnimplemented
	mustOrShould := "must"
	if !requireUnimplemented {
		mustOrShould = "should"
	}

	// Server interface.
	serverType := service.GoName + "Server"
	g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
	g.P("// All implementations ", mustOrShould, " embed Unimplemented", serverType)
	g.P("// for forward compatibility")
	if serviceDeprecated {
		g.P("//")
		g.P(deprecationComment)
	}
	g.Annotate(serverType, service.Location)
	g.P("type ", serverType, " interface {")
	for _, method := range service.Methods {
		g.Annotate(serverType+"."+method.GoName, method.Location)
		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
			g.P(deprecationComment)
		}
		g.P(method.Comments.Leading,
			generateTripleServerSignature(g, method))
	}
	if requireUnimplemented {
		g.P("mustEmbedUnimplemented", serverType, "()")
	}
	g.P("}")
	g.P()

	// Server Unimplemented struct for forward compatibility. proxyImpl holds the
	// Dubbo invoker that the generated unary handlers dispatch through.
	unimplType := "Unimplemented" + serverType
	g.P("// ", unimplType, " ", mustOrShould, " be embedded to have forward compatible implementations.")
	g.P("type ", unimplType, " struct {")
	g.P("proxyImpl ", protocolPackage.Ident("Invoker"))
	g.P("}")
	g.P()
	for _, method := range service.Methods {
		nilArg := ""
		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
			nilArg = "nil,"
		}
		g.P("func (", unimplType, ") ", generateTripleServerSignature(g, method), "{")
		g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
		g.P("}")
	}

	// Proxy-invoker accessors used by the triple runtime.
	g.P("func (s *", unimplType, ") XXX_SetProxyImpl(impl ", protocolPackage.Ident("Invoker"), ") {")
	g.P("s.proxyImpl = impl")
	g.P("}")
	g.P()
	g.P("func (s *", unimplType, ") XXX_GetProxyImpl() ", protocolPackage.Ident("Invoker"), " {")
	g.P("return s.proxyImpl")
	g.P("}")
	g.P()

	serviceDescVar := service.GoName + "_ServiceDesc"
	// Qualify ServiceDesc through grpcPackage rather than a hard-coded "grpc_go."
	// prefix, so the emitted reference always matches the import name that
	// protogen assigned to the grpc package in this file.
	g.P("func (s *", unimplType, ") XXX_ServiceDesc() *", grpcPackage.Ident("ServiceDesc"), " {")
	g.P("return &", serviceDescVar)
	g.P("}")
	g.P("func (s *", unimplType, ") XXX_InterfaceName() string{")
	g.P("return ", quotedFullName)
	g.P("}")
	g.P()
	if requireUnimplemented {
		g.P("func (", unimplType, ") mustEmbed", unimplType, "() {}")
	}
	g.P()

	// Unsafe Server interface to opt-out of forward compatibility.
	g.P("// Unsafe", serverType, " may be embedded to opt out of forward compatibility for this service.")
	g.P("// Use of this interface is not recommended, as added methods to ", serverType, " will")
	g.P("// result in compilation errors.")
	g.P("type Unsafe", serverType, " interface {")
	g.P("mustEmbed", unimplType, "()")
	g.P("}")

	// Server registration.
	if serviceDeprecated {
		g.P(deprecationComment)
	}
	g.P("func Register", service.GoName, "Server(s ", grpcPackage.Ident("ServiceRegistrar"), ", srv ", serverType, ") {")
	g.P("s.RegisterService(&", serviceDescVar, ", srv)")
	g.P("}")
	g.P()

	// Server handler implementations; handlerNames[i] belongs to service.Methods[i].
	handlerNames := make([]string, 0, len(service.Methods))
	for _, method := range service.Methods {
		handlerNames = append(handlerNames, generateTripleServerMethod(gen, file, g, method))
	}

	// Service descriptor: unary methods go in Methods, streaming ones in Streams.
	g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
	g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
	g.P("// and not to be introspected or modified (even as a copy)")
	g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
	g.P("ServiceName: ", quotedFullName, ",")
	g.P("HandlerType: (*", serverType, ")(nil),")
	g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
	for i, method := range service.Methods {
		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
			continue
		}
		g.P("{")
		g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
		g.P("Handler: ", handlerNames[i], ",")
		g.P("},")
	}
	g.P("},")
	g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
	for i, method := range service.Methods {
		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
			continue
		}
		g.P("{")
		g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
		g.P("Handler: ", handlerNames[i], ",")
		if method.Desc.IsStreamingServer() {
			g.P("ServerStreams: true,")
		}
		if method.Desc.IsStreamingClient() {
			g.P("ClientStreams: true,")
		}
		g.P("},")
	}
	g.P("},")
	g.P("Metadata: \"", file.Desc.Path(), "\",")
	g.P("}")
	g.P()
}
// generateClientImplSignature renders the function-typed struct field that
// <Service>ClientImpl declares for method, for example
//
//	SayHello func(ctx context.Context, in *HelloRequest) (*HelloReply, error)
//
// Client-streaming methods take no request argument. Any streaming method
// returns the generated <Service>_<Method>Client stream instead of a response
// message. Unlike the client interface, the error result here is plain error.
func generateClientImplSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
	var b strings.Builder
	b.WriteString(method.GoName)
	b.WriteString(" func(ctx ")
	b.WriteString(g.QualifiedGoIdent(contextPackage.Ident("Context")))
	if !method.Desc.IsStreamingClient() {
		b.WriteString(", in *")
		b.WriteString(g.QualifiedGoIdent(method.Input.GoIdent))
	}
	b.WriteString(") (")

	unary := !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer()
	if unary {
		b.WriteString("*" + g.QualifiedGoIdent(method.Output.GoIdent))
	} else {
		b.WriteString(method.Parent.GoName + "_" + method.GoName + "Client")
	}
	b.WriteString(", error)")
	return b.String()
}
// generateTripleClientSignature renders the <Service>Client interface method
// for method, without the trailing body. Every method accepts a context and
// variadic CallOptions. Non-client-streaming methods also take "in *<Input>".
// Unary methods return (*<Output>, common.ErrorWithAttachment). Streaming
// methods return the <Service>_<Method>Client stream and a plain error.
func generateTripleClientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
	params := "ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
	if !method.Desc.IsStreamingClient() {
		params += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
	}
	params += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption"))

	var results string
	if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
		results = method.Parent.GoName + "_" + method.GoName + "Client, error"
	} else {
		results = "*" + g.QualifiedGoIdent(method.Output.GoIdent) +
			", " + g.QualifiedGoIdent(commonPackage.Ident("ErrorWithAttachment"))
	}
	return method.GoName + "(" + params + ") (" + results + ")"
}
// generateTripleClientMethod writes the unexported client's implementation of
// method. In both cases the RPC path is built at run time as
// "/" + <interface key read from ctx> + "/<Method>".
//
// Unary methods delegate to TripleConn.Invoke. The generated unary body does
// not forward opts.
//
// Streaming methods open a stream with TripleConn.NewStream. For
// server-streaming methods they send the single request and close the send
// side. They also emit the <Service>_<Method>Client interface, its wrapper
// type and its Send / Recv / CloseAndRecv helpers.
//
// index is accepted for call-site compatibility but is not used.
func generateTripleClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
	service := method.Parent
	if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
		g.P(deprecationComment)
	}
	g.P("func (c *", unexport(service.GoName), "Client) ", generateTripleClientSignature(g, method), "{")
	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
		g.P("out := new(", method.Output.GoIdent, ")")
		// triple logic. The idents are passed straight to g.P instead of being
		// spliced into a non-constant fmt.Sprintf format. The emitted type
		// assertion is not comma-ok, so generated code panics if the key is
		// missing from ctx or is not a string.
		g.P("interfaceKey := ctx.Value(", constantPackage.Ident("InterfaceKey"), ").(string)")
		g.P(`return out, c.cc.Invoke(ctx, "/" + interfaceKey + "/`, method.GoName, `", in, out)`)
		g.P("}")
		g.P()
		return
	}
	streamType := unexport(service.GoName) + method.GoName + "Client"
	// triple logic
	g.P("interfaceKey := ctx.Value(", constantPackage.Ident("InterfaceKey"), ").(string)")
	g.P(`stream, err := c.cc.NewStream(ctx, "/" + interfaceKey + "/`, method.GoName, `", opts...)`)
	g.P("if err != nil { return nil, err }")
	g.P("x := &", streamType, "{stream}")
	if !method.Desc.IsStreamingClient() {
		// Server-streaming: the single request is sent up front.
		g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
	}
	g.P("return x, nil")
	g.P("}")
	g.P()

	genSend := method.Desc.IsStreamingClient()
	genRecv := method.Desc.IsStreamingServer()
	genCloseAndRecv := !method.Desc.IsStreamingServer()

	// Stream auxiliary types and methods.
	g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
	if genSend {
		g.P("Send(*", method.Input.GoIdent, ") error")
	}
	if genRecv {
		g.P("Recv() (*", method.Output.GoIdent, ", error)")
	}
	if genCloseAndRecv {
		g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
	}
	g.P(grpcPackage.Ident("ClientStream"))
	g.P("}")
	g.P()

	g.P("type ", streamType, " struct {")
	g.P(grpcPackage.Ident("ClientStream"))
	g.P("}")
	g.P()

	if genSend {
		g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
		g.P("return x.ClientStream.SendMsg(m)")
		g.P("}")
		g.P()
	}
	if genRecv {
		g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
		g.P("m := new(", method.Output.GoIdent, ")")
		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
		g.P("return m, nil")
		g.P("}")
		g.P()
	}
	if genCloseAndRecv {
		g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
		g.P("m := new(", method.Output.GoIdent, ")")
		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
		g.P("return m, nil")
		g.P("}")
		g.P()
	}
}
// generateTripleServerSignature renders the <Service>Server interface method
// for method:
//
//	unary:            Name(context.Context, *In) (*Out, error)
//	server streaming: Name(*In, <Service>_<Name>Server) error
//	client/bidi:      Name(<Service>_<Name>Server) error
func generateTripleServerSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
	clientStreams := method.Desc.IsStreamingClient()
	serverStreams := method.Desc.IsStreamingServer()
	if !clientStreams && !serverStreams {
		// Qualify in the same order as before (context, output, input) so import
		// name assignment is unchanged.
		ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context"))
		outType := g.QualifiedGoIdent(method.Output.GoIdent)
		inType := g.QualifiedGoIdent(method.Input.GoIdent)
		return fmt.Sprintf("%s(%s, *%s) (*%s, error)", method.GoName, ctxType, inType, outType)
	}

	stream := method.Parent.GoName + "_" + method.GoName + "Server"
	if clientStreams {
		return fmt.Sprintf("%s(%s) error", method.GoName, stream)
	}
	inType := g.QualifiedGoIdent(method.Input.GoIdent)
	return fmt.Sprintf("%s(*%s, %s) error", method.GoName, inType, stream)
}
// generateTripleServerMethod writes the grpc-style handler for method and
// returns its name, _<Service>_<Method>_Handler, for use in the service
// descriptor.
//
// Unary handlers decode the request, copy the incoming metadata into the
// invocation attachments and dispatch through the Dubbo proxy invoker
// (XXX_GetProxyImpl). The call is wrapped by the unary interceptor when one
// is supplied.
//
// Streaming handlers store the incoming metadata in the stream's context
// under the dubbo-go AttachmentKey and then call the user's <Service>Server
// method directly. They also emit the <Service>_<Method>Server interface, its
// wrapper type and its Send / SendAndClose / Recv helpers.
func generateTripleServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
	service := method.Parent
	hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)

	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
		g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
		g.P("in := new(", method.Input.GoIdent, ")")
		g.P("if err := dec(in); err != nil { return nil, err }")
		// triple logic
		g.P("base := srv.(", dubbo3Package.Ident("Dubbo3GrpcService"), ")")
		g.P("args := []interface{}{}")
		g.P("args = append(args, in)")
		g.P("md, _ := ", metadataPackage.Ident("FromIncomingContext"), "(ctx)")
		g.P("invAttachment := make(map[string]interface{}, len(md))")
		g.P("for k, v := range md {")
		g.P("invAttachment[k] = v")
		g.P("}")
		g.P(`invo := `, invocationPackage.Ident("NewRPCInvocation"), `("`, method.GoName, `", args, invAttachment)`)
		g.P("if interceptor == nil {")
		g.P("result := base.XXX_GetProxyImpl().Invoke(ctx, invo)")
		g.P("return result, result.Error()")
		g.P("}")
		g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
		g.P("Server: srv,")
		g.P(`FullMethod: ctx.Value("XXX_TRIPLE_GO_INTERFACE_NAME").(string),`)
		g.P("}")
		// NOTE(review): the generated handler ignores its req argument. invo is
		// built from the decoded request before the interceptor runs, so an
		// interceptor that replaces req has no effect on the invocation.
		g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
		g.P("result := base.XXX_GetProxyImpl().Invoke(ctx, invo)")
		g.P("return result, result.Error()")
		g.P("}")
		g.P("return interceptor(ctx, in, info, handler)")
		g.P("}")
		g.P()
		return hname
	}

	streamType := unexport(service.GoName) + method.GoName + "Server"
	g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
	// triple logic
	g.P("_, ok := srv.(", dubbo3Package.Ident("Dubbo3GrpcService"), ")")
	g.P("ctx := stream.Context()")
	g.P("md, _ := ", metadataPackage.Ident("FromIncomingContext"), "(ctx)")
	g.P("invAttachment := make(map[string]interface{}, len(md))")
	g.P("for k, v := range md {")
	g.P("invAttachment[k] = v")
	g.P("}")
	// Qualify WithValue through contextPackage instead of a hard-coded "context."
	// prefix, so the reference matches whatever name protogen gave the import.
	g.P("stream.(", grpcPackage.Ident("CtxSetterStream"), ").SetContext(", contextPackage.Ident("WithValue"), "(ctx, ", dubboConstantPackage.Ident("AttachmentKey"), ", invAttachment))")
	g.P(`invo := `, invocationPackage.Ident("NewRPCInvocation"), `("`, method.GoName, `", nil, nil)`)
	// NOTE(review): when srv is not a Dubbo3GrpcService, the generated handler
	// only prints the invocation and reports success. invo exists solely for
	// that print; dropping the print would leave an unused variable in
	// generated code.
	g.P("if !ok {")
	g.P(fmtPackage.Ident("Println"), "(invo)")
	g.P("return nil")
	g.P("}")
	if !method.Desc.IsStreamingClient() {
		// Server-streaming: read the single request before calling the user code.
		g.P("m := new(", method.Input.GoIdent, ")")
		g.P("if err := stream.RecvMsg(m); err != nil { return err }")
		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
	} else {
		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
	}
	g.P("}")
	g.P()

	genSend := method.Desc.IsStreamingServer()
	genSendAndClose := !method.Desc.IsStreamingServer()
	genRecv := method.Desc.IsStreamingClient()

	// Stream auxiliary types and methods.
	g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
	if genSend {
		g.P("Send(*", method.Output.GoIdent, ") error")
	}
	if genSendAndClose {
		g.P("SendAndClose(*", method.Output.GoIdent, ") error")
	}
	if genRecv {
		g.P("Recv() (*", method.Input.GoIdent, ", error)")
	}
	g.P(grpcPackage.Ident("ServerStream"))
	g.P("}")
	g.P()

	g.P("type ", streamType, " struct {")
	g.P(grpcPackage.Ident("ServerStream"))
	g.P("}")
	g.P()

	if genSend {
		g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
		g.P("return x.ServerStream.SendMsg(m)")
		g.P("}")
		g.P()
	}
	if genSendAndClose {
		g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
		g.P("return x.ServerStream.SendMsg(m)")
		g.P("}")
		g.P()
	}
	if genRecv {
		g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
		g.P("m := new(", method.Input.GoIdent, ")")
		g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
		g.P("return m, nil")
		g.P("}")
		g.P()
	}
	return hname
}
// deprecationComment is written above generated declarations whose proto file
// marks the corresponding service or method as deprecated.
const deprecationComment = "// Deprecated: Do not use."
// unexport returns s with its first character lower-cased, turning an exported
// Go identifier into its unexported form (for example "GreeterClient" becomes
// "greeterClient").
//
// The first rune, not the first byte, is lower-cased. This keeps identifiers
// that start with a non-ASCII letter valid UTF-8. An empty string is returned
// unchanged instead of panicking on s[:1].
func unexport(s string) string {
	for i := range s {
		if i > 0 {
			// i is the byte offset of the second rune, so s[:i] is exactly the
			// first rune.
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	// s is empty or consists of a single rune.
	return strings.ToLower(s)
}