Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/check-codegen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Build
run: |
go install ./internal/cmd/gtrace
go install ./internal/cmd/gstack
go install go.uber.org/mock/[email protected]

- name: Clean and re-generate *_gtrace.go files
Expand All @@ -40,5 +41,10 @@ jobs:
go generate ./trace
go generate ./...

- name: Re-generate stack.FunctionID calls
run: |
gstack .
rm ./gstack

- name: Check repository diff
run: bash ./.github/scripts/check-work-copy-equals-to-committed.sh "code-generation not equal with committed"
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
* Added query pool metrics
* Fixed logic of query session pool
* Changed initialization of internal driver clients to lazy
* Disabled the logic of background grpc-connection parking
* Disabled the logic of background grpc-connection parking
* Added codegarator for filling FunctionID with value from call stack

## v3.58.2
* Added `trace.Query.OnSessionBegin` event
Expand Down
223 changes: 223 additions & 0 deletions internal/cmd/gstack/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"os"
"path/filepath"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/cmd/gstack/utils"
)

func usage() {
fmt.Fprintf(os.Stderr, "usage: codegenerate [path]\n")
flag.PrintDefaults()
}

func getCallExpressionsFromExpr(expr ast.Expr) (listOfCalls []*ast.CallExpr) {
switch expr := expr.(type) {
case *ast.SelectorExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.IndexExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.StarExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
case *ast.BinaryExpr:
listOfCalls = getCallExpressionsFromExpr(expr.X)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Y)...)
case *ast.CallExpr:
listOfCalls = append(listOfCalls, expr)
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Fun)...)
for _, arg := range expr.Args {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(arg)...)
}
case *ast.CompositeLit:
for _, elt := range expr.Elts {
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(elt)...)
}
case *ast.UnaryExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.X)...)
case *ast.KeyValueExpr:
listOfCalls = append(listOfCalls, getCallExpressionsFromExpr(expr.Value)...)
case *ast.FuncLit:
listOfCalls = append(listOfCalls, getListOfCallExpressionsFromBlockStmt(expr.Body)...)
}

return listOfCalls
}

func getExprFromDeclStmt(statement *ast.DeclStmt) (listOfExpressions []ast.Expr) {
decl, ok := statement.Decl.(*ast.GenDecl)
if !ok {
return listOfExpressions
}
for _, spec := range decl.Specs {
if spec, ok := spec.(*ast.ValueSpec); ok {
for _, expr := range spec.Values {
listOfExpressions = append(listOfExpressions, expr)
}
}
}
return listOfExpressions
}

func getCallExpressionsFromStmt(statement ast.Stmt) (listOfCallExpressions []*ast.CallExpr) {
var body *ast.BlockStmt
var listOfExpressions []ast.Expr
switch statement.(type) {
case *ast.IfStmt:
body = statement.(*ast.IfStmt).Body
case *ast.SwitchStmt:
body = statement.(*ast.SwitchStmt).Body
case *ast.TypeSwitchStmt:
body = statement.(*ast.TypeSwitchStmt).Body
case *ast.SelectStmt:
body = statement.(*ast.SelectStmt).Body
case *ast.ForStmt:
body = statement.(*ast.ForStmt).Body
case *ast.RangeStmt:
body = statement.(*ast.RangeStmt).Body
case *ast.DeclStmt:
listOfExpressions = append(listOfExpressions, getExprFromDeclStmt(statement.(*ast.DeclStmt))...)
for _, expr := range listOfExpressions {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr)...)
}
case *ast.CommClause:
stmts := statement.(*ast.CommClause).Body
for _, stmt := range stmts {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(stmt)...)
}
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(statement.(*ast.ExprStmt).X)...)
}
if body != nil {
listOfCallExpressions = append(
listOfCallExpressions,
getListOfCallExpressionsFromBlockStmt(body)...,
)
}

return listOfCallExpressions
}

func getListOfCallExpressionsFromBlockStmt(block *ast.BlockStmt) (listOfCallExpressions []*ast.CallExpr) {
for _, statement := range block.List {
switch expr := statement.(type) {
case *ast.ExprStmt:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(expr.X)...)
case *ast.ReturnStmt:
for _, result := range expr.Results {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(result)...)
}
case *ast.AssignStmt:
for _, rh := range expr.Rhs {
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromExpr(rh)...)
}
default:
listOfCallExpressions = append(listOfCallExpressions, getCallExpressionsFromStmt(statement)...)
}
}

return listOfCallExpressions
}

func format(src []byte, path string, fset *token.FileSet, file ast.File) ([]byte, error) {
var listOfArgs []utils.FunctionIDArg
for _, f := range file.Decls {
var listOfCalls []*ast.CallExpr
fn, ok := f.(*ast.FuncDecl)
if !ok {
continue
}
listOfCalls = getListOfCallExpressionsFromBlockStmt(fn.Body)
for _, call := range listOfCalls {
if function, ok := call.Fun.(*ast.SelectorExpr); ok && function.Sel.Name == "FunctionID" {
pack, ok := function.X.(*ast.Ident)
if !ok {
continue
}
if pack.Name == "stack" && len(call.Args) == 1 {
listOfArgs = append(listOfArgs, utils.FunctionIDArg{
FuncDecl: fn,
ArgPos: call.Args[0].Pos(),
ArgEnd: call.Args[0].End(),
})
}
}
}
}
if len(listOfArgs) != 0 {
fixed, err := utils.FixSource(fset, path, src, listOfArgs)
if err != nil {
return nil, err
}

return fixed, nil
}

return src, nil
}

func main() {
flag.Usage = usage
flag.Parse()
args := flag.Args()

if len(args) != 1 {
flag.Usage()

return
}
_, err := os.Stat(args[0])
if err != nil {
panic(err)
}

fileSystem := os.DirFS(args[0])

err = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error {
fset := token.NewFileSet()
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if filepath.Ext(path) == ".go" {
info, err := os.Stat(path)
if err != nil {
return err
}
src, err := utils.ReadFile(path, info)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, path, nil, 0)
if err != nil {
return err
}
formatted, err := format(src, path, fset, *file)
if err != nil {
return err
}
if !bytes.Equal(src, formatted) {
err = utils.WriteFile(path, formatted, info.Mode().Perm())
if err != nil {
return err
}
}

return nil
}

return nil
})
if err != nil {
panic(err)
}
}
134 changes: 134 additions & 0 deletions internal/cmd/gstack/utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package utils

import (
"fmt"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/version"
"go/ast"
"go/parser"
"go/token"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
)

type FunctionIDArg struct {
FuncDecl *ast.FuncDecl
ArgPos token.Pos
ArgEnd token.Pos
}

func ReadFile(filename string, info fs.FileInfo) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer func(f *os.File) {
err := f.Close()
if err != nil {
}
}(f)
size := int(info.Size())
src := make([]byte, size)
n, err := io.ReadFull(f, src)
if err != nil {
return nil, err
}
if n < size {
return nil, fmt.Errorf("error: size of %s changed during reading (from %d to %d bytes)", filename, size, n)
} else if n > size {
return nil, fmt.Errorf("error: size of %s changed during reading (from %d to >=%d bytes)", filename, size, len(src))
}

return src, nil
}

func FixSource(fset *token.FileSet, path string, src []byte, listOfArgs []FunctionIDArg) ([]byte, error) {
var fixed []byte
var previousArgEnd int
for _, arg := range listOfArgs {
argPosOffset := fset.Position(arg.ArgPos).Offset
argEndOffset := fset.Position(arg.ArgEnd).Offset
argument, err := makeCall(fset, path, arg)
if err != nil {
return nil, err
}
fixed = append(fixed, src[previousArgEnd:argPosOffset]...)
fixed = append(fixed, fmt.Sprintf("\"%s\"", argument)...)
previousArgEnd = argEndOffset
}
fixed = append(fixed, src[previousArgEnd:]...)

return fixed, nil
}

func WriteFile(filename string, formatted []byte, perm fs.FileMode) error {
fout, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, perm)
if err != nil {
return err
}

defer fout.Close()

_, err = fout.Write(formatted)
if err != nil {
return err
}

return nil
}

func makeCall(fset *token.FileSet, path string, arg FunctionIDArg) (string, error) {
basePath := filepath.Join("github.com/ydb-platform/", version.Prefix, version.Major, "")
packageName, err := getPackageName(fset, arg)
if err != nil {
return "", err
}
filePath := filepath.Dir(filepath.Dir(path))
funcName, err := getFuncName(arg.FuncDecl)
if err != nil {
return "", err
}
return strings.Join([]string{filepath.Join(basePath, filePath, packageName), funcName}, "."), nil
}

func getFuncName(funcDecl *ast.FuncDecl) (string, error) {
if funcDecl.Recv != nil {
recvType := funcDecl.Recv.List[0].Type
prefix, err := getIdentNameFromExpr(recvType)
if err != nil {
return "", err
}
return strings.Join([]string{prefix, funcDecl.Name.Name}, "."), nil
}
return funcDecl.Name.Name, nil
}

func getIdentNameFromExpr(expr ast.Expr) (string, error) {
switch expr := expr.(type) {
case *ast.Ident:
return expr.Name, nil
case *ast.StarExpr:
prefix, err := getIdentNameFromExpr(expr.X)
if err != nil {
return "", err
}
return "(*" + prefix + ")", nil
case *ast.IndexExpr:
return getIdentNameFromExpr(expr.X)
case *ast.IndexListExpr:
return getIdentNameFromExpr(expr.X)
default:
return "", fmt.Errorf("error during getting ident from expr")
}
}

func getPackageName(fset *token.FileSet, arg FunctionIDArg) (string, error) {
file := fset.File(arg.ArgPos)
parsedFile, err := parser.ParseFile(fset, file.Name(), nil, parser.PackageClauseOnly)
if err != nil {
return "", fmt.Errorf("error during get package name function")
}
return parsedFile.Name.Name, nil
}
Loading