Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 294 additions & 18 deletions gopls/internal/cache/testfuncs/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package testfuncs
import (
"go/ast"
"go/constant"
"go/token"
"go/types"
"strings"
"unicode"
Expand Down Expand Up @@ -133,6 +134,121 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast
// parameter of the enclosing test function.
var tests []gobTest
for _, stmt := range body.List {
// Handle direct t.Run calls
if expr, ok := stmt.(*ast.ExprStmt); ok {
tests = append(tests, b.findDirectSubtests(parent, param, expr, file, files, info)...)
continue
}

// Handle table-driven tests: for _, tt := range tests { t.Run(tt.name, ...) }
if rangeStmt, ok := stmt.(*ast.RangeStmt); ok {
tests = append(tests, b.findTableDrivenSubtests(parent, param, rangeStmt, file, files, info)...)
continue
}
}
return tests
}

// findDirectSubtests finds subtests from direct t.Run("name", ...) calls.
func (b *indexBuilder) findDirectSubtests(parent gobTest, param types.Object, expr *ast.ExprStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest {
var tests []gobTest

call, ok := expr.X.(*ast.CallExpr)
if !ok || len(call.Args) != 2 {
return nil
}
fun, ok := call.Fun.(*ast.SelectorExpr)
if !ok || fun.Sel.Name != "Run" {
return nil
}
recv, ok := fun.X.(*ast.Ident)
if !ok || info.ObjectOf(recv) != param {
return nil
}

sig, ok := info.TypeOf(call.Args[1]).(*types.Signature)
if !ok {
return nil
}
if _, ok := testKind(sig); !ok {
return nil // subtest has wrong signature
}

val := info.Types[call.Args[0]].Value // may be zero
if val == nil || val.Kind() != constant.String {
return nil
}

var t gobTest
t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val)))
t.Location.URI = file.URI
t.Location.Range, _ = file.NodeRange(call)
tests = append(tests, t)

fn, funcType, funcBody := findFunc(files, info, nil, call.Args[1])
if funcType == nil {
return tests
}

// Function literals don't have an associated object
if fn == nil {
tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...)
return tests
}

// Never recurse if the second argument is a top-level test function
if isTest, _ := isTestOrExample(fn); isTest {
return tests
}

// Don't recurse into functions that have already been visited
if b.visited[fn] {
return tests
}

b.visited[fn] = true
tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...)
return tests
}

// findTableDrivenSubtests finds subtests from table-driven tests.
// It handles patterns like:
//
// tests := []struct{ name string; ... }{{name: "test1"}, ...}
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) { ... })
// }
func (b *indexBuilder) findTableDrivenSubtests(parent gobTest, param types.Object, rangeStmt *ast.RangeStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest {
var tests []gobTest

// rangeStmt.Body should contain t.Run calls
if rangeStmt.Body == nil {
return nil
}

// Get the loop variable (e.g., tt in "for _, tt := range tests")
var loopVar types.Object
if rangeStmt.Value != nil {
if ident, ok := rangeStmt.Value.(*ast.Ident); ok {
loopVar = info.ObjectOf(ident)
}
}
if loopVar == nil {
// Try rangeStmt.Key for "for tt := range tests" pattern
if rangeStmt.Key != nil {
if ident, ok := rangeStmt.Key.(*ast.Ident); ok {
loopVar = info.ObjectOf(ident)
}
}
}
if loopVar == nil {
return nil
}

var testNameField *ast.Ident
// Find t.Run calls in the range body to confirm this is a table-driven test, if so then set the testNameVar
hasRun := false
for _, stmt := range rangeStmt.Body.List {
expr, ok := stmt.(*ast.ExprStmt)
if !ok {
continue
Expand All @@ -159,42 +275,202 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast
continue // subtest has wrong signature
}

val := info.Types[call.Args[0]].Value // may be zero
if val == nil || val.Kind() != constant.String {
// Check if first argument is a field access like tt.name, if so set
testNameField = b.isLoopVarFieldAccess(call.Args[0], loopVar, info)
if testNameField == nil {
continue
}

// TODO: handle expressions other than struct field selectors

hasRun = true
break
}

if !hasRun {
return nil
}

// Find the table being ranged over and extract test cases with their locations
tableEntries := b.extractTableTestCases(rangeStmt.X, files, info, file, testNameField)
if len(tableEntries) == 0 {
return nil
}

// Create a test entry for each table entry with its specific location
for _, entry := range tableEntries {
var t gobTest
t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val)))
t.Name = b.uniqueName(parent.Name, rewrite(entry.name))
t.Location.URI = file.URI
t.Location.Range, _ = file.NodeRange(call)
t.Location.Range = entry.location
tests = append(tests, t)
}

fn, typ, body := findFunc(files, info, body, call.Args[1])
if typ == nil {
continue
return tests
}

// isLoopVarFieldAccess checks if expr is a field access on the loop variable, if so returns the field identifier
// (e.g., tt.name where tt is the loop variable).
func (b *indexBuilder) isLoopVarFieldAccess(expr ast.Expr, loopVar types.Object, info *types.Info) *ast.Ident {
sel, ok := expr.(*ast.SelectorExpr)
if !ok {
return nil
}
ident, ok := sel.X.(*ast.Ident)
if !ok {
return nil
}
if info.ObjectOf(ident) != loopVar {
return nil
}
return sel.Sel
}

// tableTestCase represents a single test case in a table-driven test
type tableTestCase struct {
name string
location protocol.Range
}

// extractTableTestCases extracts test cases with their locations from a table-driven test slice.
// It handles patterns like:
// - tests := []struct{name string}{{"test1"}, {"test2"}}
// - []struct{name string}{{"test1"}, {"test2"}}
// - For identifier references, attempts to find the composite literal value
func (b *indexBuilder) extractTableTestCases(expr ast.Expr, files []*parsego.File, info *types.Info, file *parsego.File, testNameField *ast.Ident) []tableTestCase {
// Unwrap parentheses
for {
if paren, ok := expr.(*ast.ParenExpr); ok {
expr = paren.X
} else {
break
}
}

// Function literals don't have an associated object
if fn == nil {
tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...)
continue
// Handle both direct composite literals and identifiers
var comp *ast.CompositeLit
switch e := expr.(type) {
case *ast.CompositeLit:
comp = e
case *ast.Ident:
// Look for the assignment of this identifier
obj := info.ObjectOf(e)
if obj == nil {
return nil
}
// Find the composite literal from the identifier's definition
comp = b.findCompositeLiteralForIdent(e, files, info)
if comp == nil {
return nil
}
default:
return nil
}

// comp should be a slice composite literal
if comp.Type == nil {
return nil
}

// Never recurse if the second argument is a top-level test function
if isTest, _ := isTestOrExample(fn); isTest {
var cases []tableTestCase
for _, elt := range comp.Elts {
// Each element should be a struct literal
structLit, ok := elt.(*ast.CompositeLit)
if !ok {
continue
}

// Don't recurse into functions that have already been visited
if b.visited[fn] {
if len(structLit.Elts) == 0 {
continue
}

b.visited[fn] = true
tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...)
// Try keyed fields first (e.g., {name: "test1", ...})
for _, field := range structLit.Elts {
kv, ok := field.(*ast.KeyValueExpr)
if !ok {
//TODO: look for unkeyed fields
continue
}
key, ok := kv.Key.(*ast.Ident)
if !ok || key.Name != testNameField.Name {
continue
}

// Get the location of this test case (the struct literal)
rng, err := file.NodeRange(structLit)
if err != nil {
continue
}

// Extract the string value
if val := info.Types[kv.Value].Value; val != nil && val.Kind() == constant.String {
cases = append(cases, tableTestCase{
name: constant.StringVal(val),
location: rng,
})
break
}
}
}
return tests

return cases
}

// findCompositeLiteralForIdent finds the composite literal that initializes the given identifier.
// It searches through the files for variable declarations and assignments.
func (b *indexBuilder) findCompositeLiteralForIdent(ident *ast.Ident, files []*parsego.File, info *types.Info) *ast.CompositeLit {
obj := info.ObjectOf(ident)
if obj == nil {
return nil
}

// Search through all files to find the declaration
for _, file := range files {
// Walk through declarations to find variable declarations
for _, decl := range file.File.Decls {
// Check function declarations (where local variables are declared)
funcDecl, ok := decl.(*ast.FuncDecl)
if !ok || funcDecl.Body == nil {
continue
}

// Walk through statements in the function body
for _, stmt := range funcDecl.Body.List {
// Check for short variable declaration: tests := ...
if assign, ok := stmt.(*ast.AssignStmt); ok && assign.Tok == token.DEFINE {
for i, lhs := range assign.Lhs {
if lhsIdent, ok := lhs.(*ast.Ident); ok && info.ObjectOf(lhsIdent) == obj {
// Found the declaration, check if RHS is a composite literal
if i < len(assign.Rhs) {
if comp, ok := assign.Rhs[i].(*ast.CompositeLit); ok {
return comp
}
}
}
}
}

// Check for var declaration: var tests = ...
if declStmt, ok := stmt.(*ast.DeclStmt); ok {
if genDecl, ok := declStmt.Decl.(*ast.GenDecl); ok && genDecl.Tok == token.VAR {
for _, spec := range genDecl.Specs {
if valueSpec, ok := spec.(*ast.ValueSpec); ok {
for i, name := range valueSpec.Names {
if info.ObjectOf(name) == obj && i < len(valueSpec.Values) {
if comp, ok := valueSpec.Values[i].(*ast.CompositeLit); ok {
return comp
}
}
}
}
}
}
}
}
}
}

return nil
}

// findFunc finds the type and body of the given expr, which may be a function
Expand Down
Loading