blob: 0909895111f11acb00b74bbf660a80ca642a29db [file] [log] [blame]
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package tools
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/printer"
	"go/token"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"

	"github.com/dave/dst"
	"github.com/dave/dst/decorator"
	"github.com/dave/dst/dstutil"
)
// packageImportExp matches a (trimmed) source line that consists solely of a
// double-quoted import path, optionally preceded by an alias, for example
// `"fmt"` or `foo "example.com/foo"`. findFirstNoImportLocation uses it to
// skip lines that still look like part of an import block.
var packageImportExp = regexp.MustCompile(`^(\S+\s+)?"(.+)"$`)
// ChangePackageImportPath rewrites the import paths found under file.
//
// pkgChanges maps an original import path (unquoted) to the path that should
// replace it. Import specs whose path is not a key of pkgChanges are left
// untouched. Each import spec is rewritten at most once, so a mapping chain
// such as a -> b, b -> c turns "a" into "b" regardless of map iteration order.
func ChangePackageImportPath(file dst.Node, pkgChanges map[string]string) {
	dstutil.Apply(file, func(cursor *dstutil.Cursor) bool {
		spec, ok := cursor.Node().(*dst.ImportSpec)
		if !ok || spec.Path == nil {
			return true
		}
		for originalPkg, targetPkg := range pkgChanges {
			if spec.Path.Value != fmt.Sprintf("%q", originalPkg) {
				continue
			}
			spec.Path.Value = fmt.Sprintf("%q", targetPkg)
			// Stop after the first rewrite: continuing would compare the
			// freshly written path against the remaining keys and, because
			// map iteration order is random, could apply a second mapping on
			// some runs only.
			break
		}
		return true
	}, nil)
}
// DeletePackageImports removes the given import paths from file and strips
// the package qualifier from references to the removed packages.
//
// Every import spec whose path equals one of imports is deleted. The name the
// package was referenced by (its alias, or the last element of the path) is
// remembered, and selector expressions using that name that are visited
// afterwards are handed to RemovePackageRef so that `pkg.Name` becomes `Name`.
// If at least one import was deleted, import declarations left without any
// spec are removed via RemoveImportDefineIfNoPackage.
func DeletePackageImports(file dst.Node, imports ...string) {
	containsDeletedImport := false
	// deletedPackages maps the reference name of a removed package to its
	// import path.
	deletedPackages := make(map[string]string)

	// deletedSelector returns expr as a selector when it has the form
	// `pkg.Name` with pkg being a removed package, and nil otherwise.
	deletedSelector := func(expr dst.Expr) *dst.SelectorExpr {
		sel, ok := expr.(*dst.SelectorExpr)
		if !ok {
			return nil
		}
		pkgRef, ok := sel.X.(*dst.Ident)
		if !ok {
			return nil
		}
		if _, deleted := deletedPackages[pkgRef.Name]; !deleted {
			return nil
		}
		return sel
	}

	dstutil.Apply(file, func(cursor *dstutil.Cursor) bool {
		switch n := cursor.Node().(type) {
		case *dst.ImportSpec:
			for _, pkg := range imports {
				if n.Path.Value != fmt.Sprintf("%q", pkg) {
					continue
				}
				containsDeletedImport = true
				cursor.Delete()
				if n.Name != nil {
					deletedPackages[n.Name.Name] = pkg
				} else {
					deletedPackages[filepath.Base(pkg)] = pkg
				}
				// The spec is gone; a second Delete (when imports lists the
				// same path twice) would remove an unrelated sibling.
				break
			}
			return false
		case *dst.SelectorExpr:
			if sel := deletedSelector(n); sel != nil {
				RemovePackageRef(cursor.Parent(), sel, -1)
			}
		case *dst.CaseClause:
			// Case expressions are rewritten here, by index, before the
			// traversal descends into them. Items that are not a reference to
			// a removed package are skipped individually so the remaining
			// items of the clause are still processed.
			for i, item := range n.List {
				if sel := deletedSelector(item); sel != nil {
					RemovePackageRef(n, sel, i)
				}
			}
		}
		return true
	}, nil)

	if containsDeletedImport {
		RemoveImportDefineIfNoPackage(file)
	}
}
// RemovePackageRef replaces the qualified reference current (`pkg.Name`)
// inside parent with the bare identifier `Name`.
//
// Only the slot of parent that actually holds current is rewritten, so the
// other children of parent (remaining call arguments, other assignment
// operands, composite literal elements, ...) are preserved. Parent node kinds
// other than the ones listed in the switch are left unchanged, as is a parent
// that does not contain current.
//
// inx is only used when parent is a *dst.CaseClause: a valid index selects
// the entry of the clause list to replace; any other value (for example -1)
// makes the function look current up in the list instead.
func RemovePackageRef(parent dst.Node, current *dst.SelectorExpr, inx int) {
	if current == nil || current.Sel == nil {
		return
	}
	unqualified := dst.NewIdent(current.Sel.Name)

	// swap overwrites *slot with the bare identifier when it holds current.
	swap := func(slot *dst.Expr) bool {
		if *slot != dst.Expr(current) {
			return false
		}
		*slot = unqualified
		return true
	}
	// swapIn applies swap to the first entry of list that holds current.
	swapIn := func(list []dst.Expr) bool {
		for i := range list {
			if swap(&list[i]) {
				return true
			}
		}
		return false
	}

	switch p := parent.(type) {
	case *dst.Field:
		swap(&p.Type)
	case *dst.Ellipsis:
		swap(&p.Elt)
	case *dst.StarExpr:
		swap(&p.X)
	case *dst.TypeAssertExpr:
		if !swap(&p.Type) {
			swap(&p.X)
		}
	case *dst.CompositeLit:
		if !swap(&p.Type) {
			swapIn(p.Elts)
		}
	case *dst.ArrayType:
		if !swap(&p.Elt) {
			swap(&p.Len)
		}
	case *dst.CallExpr:
		if !swap(&p.Fun) {
			swapIn(p.Args)
		}
	case *dst.KeyValueExpr:
		if !swap(&p.Value) {
			swap(&p.Key)
		}
	case *dst.AssignStmt:
		if !swapIn(p.Rhs) {
			swapIn(p.Lhs)
		}
	case *dst.CaseClause:
		if inx >= 0 && inx < len(p.List) {
			p.List[inx] = unqualified
		} else {
			swapIn(p.List)
		}
	}
}
// RemoveImportDefineIfNoPackage walks file and deletes every import
// declaration that no longer contains any import spec.
func RemoveImportDefineIfNoPackage(file dst.Node) {
	dropEmptyImportDecl := func(c *dstutil.Cursor) bool {
		gen, isGen := c.Node().(*dst.GenDecl)
		if !isGen || gen.Tok != token.IMPORT || len(gen.Specs) != 0 {
			return true
		}
		// An import declaration without specs: remove it and do not descend.
		c.Delete()
		return false
	}
	dstutil.Apply(file, dropEmptyImportDecl, nil)
}
// DebugInfo describes the original source location that a rewritten file
// reports through a "//line" directive.
type DebugInfo struct {
// FilePath is the source file path written into the "//line" directive.
FilePath string
// Line is the line number written into the "//line" directive.
Line int
// CheckOldLine selects where the directive is placed when the file is
// written: when true it is inserted in front of the first non-import line of
// the printed file, when false it is written as the very first line.
CheckOldLine bool
}
// BuildDSTDebugInfo creates the DebugInfo for the source file at srcPath.
//
// Without a parsed file the result points at line 1 and disables the old-line
// check. Otherwise the file is restored to an AST, the source is read from
// srcPath, and Line is set to the first line after the import section as
// reported by findFirstNoImportLocation. Any error from restoring, reading or
// locating is returned unchanged.
func BuildDSTDebugInfo(srcPath string, file *dst.File) (*DebugInfo, error) {
	if file == nil {
		return &DebugInfo{FilePath: srcPath, Line: 1, CheckOldLine: false}, nil
	}

	fileSet, restored, err := decorator.RestoreFile(file)
	if err != nil {
		return nil, err
	}
	source, err := os.ReadFile(srcPath)
	if err != nil {
		return nil, err
	}
	firstCodeLine, err := findFirstNoImportLocation(fileSet, restored, bytes.NewBuffer(source))
	if err != nil {
		return nil, err
	}
	return &DebugInfo{FilePath: srcPath, Line: firstCodeLine, CheckOldLine: true}, nil
}
// WriteDSTFile renders file (see GenerateDSTFileContent) and writes the
// result to path, creating the file or truncating an existing one.
//
// The content is generated before the target is touched, so a failure while
// rendering leaves an existing file at path intact instead of truncated.
// Write and close errors are both reported.
func WriteDSTFile(path string, file *dst.File, debug *DebugInfo) error {
	content, err := GenerateDSTFileContent(file, debug)
	if err != nil {
		return err
	}
	// 0o666 (before umask) matches the permissions os.Create would use.
	return os.WriteFile(path, []byte(content), 0o666)
}
// GenerateDSTFileContent restores file to a go/ast file and returns its
// printed source. When debug is non-nil the source is produced by
// writeDSTFileWithDebug, which adds a "//line" directive; otherwise it is
// printed as is.
func GenerateDSTFileContent(file *dst.File, debug *DebugInfo) (string, error) {
	fileSet, restored, err := decorator.RestoreFile(file)
	if err != nil {
		return "", err
	}

	var out bytes.Buffer
	if debug == nil {
		err = printer.Fprint(&out, fileSet, restored)
	} else {
		err = writeDSTFileWithDebug(fileSet, restored, debug, &out)
	}
	if err != nil {
		return "", err
	}
	return out.String(), nil
}
// funcIdentityUnsafeChars matches the characters of a package path that
// BuildFuncIdentity replaces with "_". It is compiled once instead of on
// every call.
var funcIdentityUnsafeChars = regexp.MustCompile(`[/.\-@]`)

// BuildFuncIdentity builds an identity string for the function or method
// declared by node in package pkgPath.
//
// The result has the form "<pkgPath>_<receiver><name>", where every '/', '.',
// '-' and '@' of pkgPath is replaced by '_' and <receiver> is empty for plain
// functions. For methods, only receivers of the form *T (pointer to a plain
// identifier) are supported; for any other receiver, including an empty
// receiver list, the empty string is returned.
func BuildFuncIdentity(pkgPath string, node *dst.FuncDecl) string {
	var receiver string
	if node.Recv != nil {
		if len(node.Recv.List) == 0 {
			// Malformed declaration such as `func () name()`: nothing to key on.
			return ""
		}
		pointer, ok := node.Recv.List[0].Type.(*dst.StarExpr)
		if !ok {
			return ""
		}
		ident, ok := pointer.X.(*dst.Ident)
		if !ok {
			return ""
		}
		receiver = ident.Name
	}
	return fmt.Sprintf("%s_%s%s",
		funcIdentityUnsafeChars.ReplaceAllString(pkgPath, "_"), receiver, node.Name)
}
// ImportAnalyzer records the imports declared by analyzed files and collects
// the ones that are referenced by inspected field types.
type ImportAnalyzer struct {
// imports maps a file path to the imports of that file, keyed by the name
// the package is referenced by (the alias, or the last path element).
imports map[string]map[string]*dst.ImportSpec
// usedImports holds the import specs found to be referenced, keyed by the
// package reference name.
usedImports map[string]*dst.ImportSpec
}
// CreateImportAnalyzer returns an ImportAnalyzer whose lookup tables are
// allocated and empty.
func CreateImportAnalyzer() *ImportAnalyzer {
	analyzer := new(ImportAnalyzer)
	analyzer.imports = map[string]map[string]*dst.ImportSpec{}
	analyzer.usedImports = map[string]*dst.ImportSpec{}
	return analyzer
}
// AnalyzeFileImports indexes every import spec found under f and stores the
// index for filePath, replacing any index recorded for that path before.
// Each spec is keyed by its alias when it has one, otherwise by the last
// element of its import path, with surrounding double quotes trimmed.
func (i *ImportAnalyzer) AnalyzeFileImports(filePath string, f dst.Node) {
	fileImports := make(map[string]*dst.ImportSpec)
	i.imports[filePath] = fileImports

	record := func(c *dstutil.Cursor) bool {
		spec, isImport := c.Node().(*dst.ImportSpec)
		if !isImport {
			return true
		}
		refName := filepath.Base(spec.Path.Value)
		if alias := spec.Name; alias != nil {
			refName = alias.Name
		}
		fileImports[strings.Trim(refName, "\"")] = spec
		// Nothing below an import spec needs to be visited.
		return false
	}
	dstutil.Apply(f, record, nil)
}
// AnalyzeNeedsImports inspects the type of every field in fields and records
// the imports of filePath those types refer to. A nil or empty list is a
// no-op.
func (i *ImportAnalyzer) AnalyzeNeedsImports(filePath string, fields *dst.FieldList) {
	if fields == nil {
		return
	}
	for idx := range fields.List {
		i.analyzeFieldImport(filePath, fields.List[idx].Type)
	}
}
// analyzeFieldImport records the import used by the type expression exp.
//
// Variadic, array/slice and pointer wrappers are peeled off until a qualified
// name `pkg.Type` is reached. If pkg is an import recorded for filePath, its
// spec is stored in usedImports under the package reference name. All other
// expression kinds are ignored.
func (i *ImportAnalyzer) analyzeFieldImport(filePath string, exp dst.Expr) {
	for {
		switch typ := exp.(type) {
		case *dst.Ellipsis:
			exp = typ.Elt
		case *dst.ArrayType:
			exp = typ.Elt
		case *dst.StarExpr:
			exp = typ.X
		case *dst.SelectorExpr:
			pkgIdent, isIdent := typ.X.(*dst.Ident)
			if !isIdent {
				return
			}
			// Indexing a missing (nil) per-file map simply yields nil.
			if spec := i.imports[filePath][pkgIdent.Name]; spec != nil {
				i.usedImports[pkgIdent.Name] = spec
			}
			return
		default:
			return
		}
	}
}
// AppendUsedImports appends a clone of every collected used import to decl,
// skipping imports whose path is already present in decl. Declarations that
// are not import declarations are left untouched.
//
// The imports are appended in the sorted order of their reference names so
// that the generated declaration is identical from run to run.
func (i *ImportAnalyzer) AppendUsedImports(decl *dst.GenDecl) {
	if decl.Tok != token.IMPORT {
		return
	}

	// present holds the (quoted) paths already part of decl.
	present := make(map[string]struct{}, len(decl.Specs)+len(i.usedImports))
	for _, existing := range decl.Specs {
		if spec, ok := existing.(*dst.ImportSpec); ok && spec.Path != nil {
			present[spec.Path.Value] = struct{}{}
		}
	}

	names := make([]string, 0, len(i.usedImports))
	for name := range i.usedImports {
		names = append(names, name)
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Strings(names)

	for _, name := range names {
		spec := i.usedImports[name]
		if _, dup := present[spec.Path.Value]; dup {
			continue
		}
		present[spec.Path.Value] = struct{}{}
		decl.Specs = append(decl.Specs, dst.Clone(spec).(*dst.ImportSpec))
	}
}
// writeDSTFileWithDebug prints file to output together with a "//line"
// directive that points at debug.FilePath.
//
// When debug.CheckOldLine is false the directive "//line <path>:<debug.Line>"
// is written first, followed by the printed file. Otherwise the file is
// printed to a buffer, the first line after its import section is located
// with findFirstNoImportLocation, and the directive is inserted directly in
// front of that line.
//
// It returns any print, lookup or write error, and an error when the end of
// the printed file is reached without the directive having been inserted.
func writeDSTFileWithDebug(fset *token.FileSet, file *ast.File, debug *DebugInfo, output io.Writer) error {
	if !debug.CheckOldLine {
		if _, err := fmt.Fprintf(output, "//line %s:%d\n", debug.FilePath, debug.Line); err != nil {
			return err
		}
		return printer.Fprint(output, fset, file)
	}

	var buffer bytes.Buffer
	if err := printer.Fprint(&buffer, fset, file); err != nil {
		return err
	}
	newPosition, err := findFirstNoImportLocation(fset, file, bytes.NewBuffer(buffer.Bytes()))
	if err != nil {
		return err
	}
	changeInfo := &dstFilePathChangeInfo{
		oldDebugPath: debug.FilePath,
		oldDebugLine: debug.Line,
		newDebugLine: newPosition,
	}
	directive := fmt.Sprintf("//line %s:%d\n", changeInfo.oldDebugPath, changeInfo.oldDebugLine)

	inserted := false
	for lineCount := 1; ; lineCount++ {
		line, readErr := buffer.ReadBytes('\n')
		// A final line without trailing newline arrives together with io.EOF
		// and still has to be written.
		if len(line) > 0 {
			if lineCount == changeInfo.newDebugLine {
				if _, err := io.WriteString(output, directive); err != nil {
					return err
				}
				inserted = true
			}
			if _, err := output.Write(line); err != nil {
				return err
			}
		}
		if readErr != nil {
			if readErr == io.EOF && !inserted {
				return fmt.Errorf("rewrite file line number failure: %v", readErr)
			}
			return nil
		}
	}
}
// dstFilePathChangeInfo collects the values writeDSTFileWithDebug works with
// when it places a "//line" directive in the printed file.
type dstFilePathChangeInfo struct {
// oldDebugPath is the source file path, taken from DebugInfo.FilePath.
oldDebugPath string
// oldDebugLine is the line number reported by the directive.
oldDebugLine int
// newDebugLine is the 1-based line of the printed output in front of which
// the directive is inserted.
newDebugLine int
}
// findFirstNoImportLocation returns the 1-based number of the first line of
// fileContent that follows the import section of file.
//
// The search starts at the line on which the last leading import declaration
// ends or, when the file has no imports, at the line of its first
// declaration. From there, blank lines, a lone ")", lines starting with
// "import " and lines that look like an import spec are skipped. A file
// without any declaration yields line 1.
//
// The last line of fileContent is considered even when it lacks a trailing
// newline. If no matching line exists, the read error (io.EOF) is returned.
func findFirstNoImportLocation(fset *token.FileSet, file *ast.File, fileContent *bytes.Buffer) (int, error) {
	// Find the end of the last import declaration at the top of the file.
	var pos token.Pos
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			break
		}
		pos = genDecl.End()
	}

	importEndLine := fset.Position(pos).Line
	if pos == token.NoPos {
		if len(file.Decls) == 0 {
			return 1, nil
		}
		importEndLine = fset.Position(file.Decls[0].Pos()).Line
	}

	lineNumber := 0
	for {
		line, err := fileContent.ReadBytes('\n')
		if err != nil && len(line) == 0 {
			return 0, err
		}
		lineNumber++
		if lineNumber >= importEndLine {
			trimmed := strings.TrimSpace(string(line))
			importRemainder := trimmed == "" || trimmed == ")" ||
				strings.HasPrefix(trimmed, "import ") ||
				packageImportExp.MatchString(trimmed)
			if !importRemainder {
				return lineNumber, nil
			}
		}
		if err != nil {
			// That was the final, unterminated line and it did not qualify.
			return 0, err
		}
	}
}