blob: 84e21c3c13c32ab9cdecf12fe81d3479128df831 [file] [log] [blame]
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package rewrite
import (
"fmt"
"reflect"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/dstutil"
"github.com/apache/skywalking-go/tools/go-agent/tools"
)
var (
GlobalOperatorRealSetMethodName = VarPrefix + "OperatorSetOperator"
GlobalOperatorRealGetMethodName = VarPrefix + "OperatorGetOperator"
GlobalOperatorTypeName = TypePrefix + "OperatorOperator"
)
// Func rewrites a function declaration taken from the instrumented source.
//
// A function without a receiver is renamed to
// StaticMethodPrefix + currentPackageTitle + <original name>, unless its name
// already starts with GenerateMethodPrefix (generated adapter methods keep
// their names). For a method with exactly one receiver field, the receiver
// name (if any) is registered in the var mapping under its own name, and the
// receiver type is passed through enhanceTypeNameWhenRewrite, recording any
// resulting type mapping.
//
// Afterwards the parameter and result lists are enhanced, followed by every
// statement of the body. A declaration without a body (allowed by the Go
// spec, e.g. for functions implemented in assembly) has no statements to
// rewrite. The cursor parameter is currently unused.
func (c *Context) Func(funcDecl *dst.FuncDecl, cursor *dstutil.Cursor) {
	if funcDecl.Recv == nil {
		// only the static method needs rewrite;
		// if the method name is generated, then ignore to enhance (for adapter)
		if !strings.HasPrefix(funcDecl.Name.Name, GenerateMethodPrefix) {
			funcDecl.Name = dst.NewIdent(fmt.Sprintf("%s%s%s", StaticMethodPrefix, c.currentPackageTitle, funcDecl.Name.Name))
		}
	} else if len(funcDecl.Recv.List) == 1 {
		// if contains the receiver, then enhance the receiver type
		field := funcDecl.Recv.List[0]
		if len(field.Names) > 0 {
			receiverName := field.Names[0].Name
			c.rewriteMapping.addVarMapping(receiverName, receiverName)
		}
		if k, v := c.enhanceTypeNameWhenRewrite(field.Type, field, -1); k != "" {
			c.rewriteMapping.addTypeMapping(k, v)
		}
	}

	// enhance method parameter and return value
	c.enhanceFuncParameter(funcDecl.Type.Params)
	c.enhanceFuncParameter(funcDecl.Type.Results)

	// a body-less declaration (assembly / linkname stub) has nothing more to
	// rewrite; dereferencing its nil Body would panic
	if funcDecl.Body == nil {
		return
	}
	// enhance the method body
	for _, stmt := range funcDecl.Body.List {
		c.enhanceFuncStmt(stmt)
	}
}
// enhanceFuncStmt walks one statement of a function body with dstutil.Apply
// and rewrites what it finds:
//   - assignment left-hand sides, range keys/values and value specs register
//     variable mappings;
//   - expressions (calls, returns, type assertions, conditions, right-hand
//     sides, ...) go through enhanceTypeNameWhenRewrite or
//     rewriteVarIfExistingMapping;
//   - nested function types enhance their parameter/result lists.
//
// If statements and block statements open a block-stack frame before they are
// processed and close it afterwards, so the mappings registered for the
// temporary variables inside them are scoped to that statement.
//
// nolint
func (c *Context) enhanceFuncStmt(stmt dst.Stmt) {
	// for the variables created in the sub statement, ex: if, func(), the temporary variable count should be recorded
	subCallTypes := []reflect.Type{
		reflect.TypeOf(&dst.IfStmt{}),
		reflect.TypeOf(&dst.BlockStmt{}),
	}
	isSubCall := func(node dst.Node) bool {
		nodeType := reflect.TypeOf(node)
		for _, t := range subCallTypes {
			if nodeType == t {
				return true
			}
		}
		return false
	}

	// visit rewrites a single node; its result tells dstutil.Apply whether the
	// children of the node still need to be traversed
	visit := func(node dst.Node) bool {
		switch n := node.(type) {
		case *dst.BlockStmt:
			for _, tmp := range n.List {
				c.enhanceFuncStmt(tmp)
			}
		case *dst.AssignStmt:
			for _, l := range n.Lhs {
				if k, v := c.enhanceVarNameWhenRewrite(l); k != "" {
					c.rewriteMapping.addVarMapping(k, v)
				}
			}
			for i, r := range n.Rhs {
				if k, v := c.enhanceTypeNameWhenRewrite(r, nil, i); k != "" {
					c.rewriteMapping.addTypeMapping(k, v)
				}
			}
		case *dst.BinaryExpr:
			c.rewriteVarIfExistingMapping(n.X, n)
			c.rewriteVarIfExistingMapping(n.Y, n)
		case *dst.CallExpr:
			c.enhanceTypeNameWhenRewrite(n.Fun, n, -1)
			for inx, arg := range n.Args {
				c.enhanceTypeNameWhenRewrite(arg, n, inx)
			}
		case *dst.ReturnStmt:
			for inx, arg := range n.Results {
				c.enhanceTypeNameWhenRewrite(arg, n, inx)
			}
		case *dst.FuncType:
			c.enhanceFuncParameter(n.Params)
			c.enhanceFuncParameter(n.Results)
		case *dst.ExprStmt:
			c.enhanceTypeNameWhenRewrite(n.X, n, -1)
		case *dst.TypeAssertExpr:
			c.enhanceTypeNameWhenRewrite(n.X, n, -1)
			c.enhanceTypeNameWhenRewrite(n.Type, n, -1)
		case *dst.IfStmt:
			// Init may be nil; walking a nil statement is a no-op
			c.enhanceFuncStmt(n.Init)
			c.enhanceTypeNameWhenRewrite(n.Cond, n, -1)
			if n.Body != nil {
				for _, stmt := range n.Body.List {
					c.enhanceFuncStmt(stmt)
				}
			}
			if n.Else != nil {
				c.enhanceFuncStmt(n.Else)
			}
		case *dst.RangeStmt:
			c.enhanceTypeNameWhenRewrite(n.X, n, -1)
			if k, v := c.enhanceVarNameWhenRewrite(n.Key); k != "" {
				c.rewriteMapping.addVarMapping(k, v)
			}
			if k, v := c.enhanceVarNameWhenRewrite(n.Value); k != "" {
				c.rewriteMapping.addVarMapping(k, v)
			}
			if n.Body != nil {
				for _, stmt := range n.Body.List {
					c.enhanceFuncStmt(stmt)
				}
			}
		case *dst.ValueSpec:
			c.Var(n, false)
		default:
			return true
		}
		return false
	}

	dstutil.Apply(stmt, func(cursor *dstutil.Cursor) bool {
		node := cursor.Node()
		scoped := isSubCall(node)
		if scoped {
			c.rewriteMapping.pushBlockStack()
		}
		descend := visit(node)
		if scoped && !descend {
			// dstutil.Apply skips the post function for a node whose pre
			// function returned false, so the frame pushed above must be
			// popped here or it would never be removed
			c.rewriteMapping.popBlockStack()
		}
		return descend
	}, func(cursor *dstutil.Cursor) bool {
		// all templates variables should be removed; this only runs for nodes
		// whose children were traversed (pre returned true)
		if isSubCall(cursor.Node()) {
			c.rewriteMapping.popBlockStack()
		}
		return true
	})
}
// rewriteVarIfExistingMapping rewrites variable references inside exp, where
// parent is the node that directly contains exp.
//
//   - An identifier that has a var mapping is renamed to the mapped name.
//   - A selector whose receiver is an identifier naming an entry in
//     c.packageImport has its package reference removed via
//     tools.RemovePackageRef; any other selector recurses into its receiver.
//   - Composite literals, unary/star expressions, index expressions and calls
//     have their parts forwarded to enhanceTypeNameWhenRewrite or to this
//     function.
//
// The result is true only when an identifier was renamed or a package
// reference was removed (directly or through a selector's receiver chain).
func (c *Context) rewriteVarIfExistingMapping(exp, parent dst.Expr) bool {
	switch expr := exp.(type) {
	case *dst.Ident:
		renamed := c.rewriteMapping.findVarMappingName(expr.Name)
		if renamed == "" {
			return false
		}
		expr.Name = renamed
		return true
	case *dst.SelectorExpr:
		if pkgIdent, isIdent := expr.X.(*dst.Ident); isIdent && c.packageImport[pkgIdent.Name] != nil {
			tools.RemovePackageRef(parent, expr)
			return true
		}
		return c.rewriteVarIfExistingMapping(expr.X, expr)
	case *dst.CompositeLit:
		c.enhanceTypeNameWhenRewrite(expr.Type, expr, -1)
		// only keyed elements are visited, e.g. the values of "&xxx{k: v}"
		for _, element := range expr.Elts {
			pair, isPair := element.(*dst.KeyValueExpr)
			if !isPair {
				continue
			}
			c.rewriteVarIfExistingMapping(pair.Value, element)
		}
	case *dst.UnaryExpr:
		c.enhanceTypeNameWhenRewrite(expr.X, expr, -1)
	case *dst.IndexExpr:
		c.rewriteVarIfExistingMapping(expr.Index, expr)
		c.rewriteVarIfExistingMapping(expr.X, expr)
	case *dst.CallExpr:
		c.enhanceTypeNameWhenRewrite(expr.Fun, expr, -1)
		for _, argument := range expr.Args {
			c.rewriteVarIfExistingMapping(argument, expr)
		}
	case *dst.StarExpr:
		c.enhanceTypeNameWhenRewrite(expr.X, expr, -1)
	}
	return false
}
// enhanceFuncParameter processes a parameter or result list. Every declared
// name is registered in the var mapping under its own name (kept unchanged for
// debugging), and each field type is passed through enhanceTypeNameWhenRewrite,
// recording the returned type mapping when one is produced. A nil list is
// ignored.
func (c *Context) enhanceFuncParameter(fields *dst.FieldList) {
	if fields == nil {
		return
	}
	for _, f := range fields.List {
		// ranging over an empty/nil Names slice is a no-op, so unnamed
		// parameters simply skip this loop
		for _, ident := range f.Names {
			c.rewriteMapping.addVarMapping(ident.Name, ident.Name)
		}
		key, value := c.enhanceTypeNameWhenRewrite(f.Type, f, -1)
		if key == "" {
			continue
		}
		c.rewriteMapping.addTypeMapping(key, value)
	}
}