diff --git a/gopls/internal/cache/testfuncs/tests.go b/gopls/internal/cache/testfuncs/tests.go index e0e3ce1beca..5cf579f2717 100644 --- a/gopls/internal/cache/testfuncs/tests.go +++ b/gopls/internal/cache/testfuncs/tests.go @@ -7,6 +7,7 @@ package testfuncs import ( "go/ast" "go/constant" + "go/token" "go/types" "strings" "unicode" @@ -133,6 +134,121 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast // parameter of the enclosing test function. var tests []gobTest for _, stmt := range body.List { + // Handle direct t.Run calls + if expr, ok := stmt.(*ast.ExprStmt); ok { + tests = append(tests, b.findDirectSubtests(parent, param, expr, file, files, info)...) + continue + } + + // Handle table-driven tests: for _, tt := range tests { t.Run(tt.name, ...) } + if rangeStmt, ok := stmt.(*ast.RangeStmt); ok { + tests = append(tests, b.findTableDrivenSubtests(parent, param, rangeStmt, file, files, info)...) + continue + } + } + return tests +} + +// findDirectSubtests finds subtests from direct t.Run("name", ...) calls. +func (b *indexBuilder) findDirectSubtests(parent gobTest, param types.Object, expr *ast.ExprStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest { + var tests []gobTest + + call, ok := expr.X.(*ast.CallExpr) + if !ok || len(call.Args) != 2 { + return nil + } + fun, ok := call.Fun.(*ast.SelectorExpr) + if !ok || fun.Sel.Name != "Run" { + return nil + } + recv, ok := fun.X.(*ast.Ident) + if !ok || info.ObjectOf(recv) != param { + return nil + } + + sig, ok := info.TypeOf(call.Args[1]).(*types.Signature) + if !ok { + return nil + } + if _, ok := testKind(sig); !ok { + return nil // subtest has wrong signature + } + + val := info.Types[call.Args[0]].Value // may be zero + if val == nil || val.Kind() != constant.String { + return nil + } + + var t gobTest + t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val))) + t.Location.URI = file.URI + t.Location.Range, _ = file.NodeRange(call) + tests = append(tests, t) + + fn, funcType, funcBody := findFunc(files, info, nil, call.Args[1]) + if funcType == nil { + return tests + } + + // Function literals don't have an associated object + if fn == nil { + tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...) + return tests + } + + // Never recurse if the second argument is a top-level test function + if isTest, _ := isTestOrExample(fn); isTest { + return tests + } + + // Don't recurse into functions that have already been visited + if b.visited[fn] { + return tests + } + + b.visited[fn] = true + tests = append(tests, b.findSubtests(t, funcType, funcBody, file, files, info)...) + return tests +} + +// findTableDrivenSubtests finds subtests from table-driven tests. +// It handles patterns like: +// +// tests := []struct{ name string; ... }{{name: "test1"}, ...} +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { ... }) +// } +func (b *indexBuilder) findTableDrivenSubtests(parent gobTest, param types.Object, rangeStmt *ast.RangeStmt, file *parsego.File, files []*parsego.File, info *types.Info) []gobTest { + var tests []gobTest + + // rangeStmt.Body should contain t.Run calls + if rangeStmt.Body == nil { + return nil + } + + // Get the loop variable (e.g., tt in "for _, tt := range tests") + var loopVar types.Object + if rangeStmt.Value != nil { + if ident, ok := rangeStmt.Value.(*ast.Ident); ok { + loopVar = info.ObjectOf(ident) + } + } + if loopVar == nil { + // Try rangeStmt.Key for "for tt := range tests" pattern + if rangeStmt.Key != nil { + if ident, ok := rangeStmt.Key.(*ast.Ident); ok { + loopVar = info.ObjectOf(ident) + } + } + } + if loopVar == nil { + return nil + } + + var testNameField *ast.Ident + // Find t.Run calls in the range body to confirm this is a table-driven test, if so then set the testNameVar + hasRun := false + for _, stmt := range rangeStmt.Body.List { expr, ok := stmt.(*ast.ExprStmt) if !ok { continue @@ -159,42 +275,202 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast continue // subtest has wrong signature } - val := info.Types[call.Args[0]].Value // may be zero - if val == nil || val.Kind() != constant.String { + // Check if first argument is a field access like tt.name, if so set + testNameField = b.isLoopVarFieldAccess(call.Args[0], loopVar, info) + if testNameField == nil { continue } + // TODO: handle expressions other than struct field selectors + + hasRun = true + break + } + + if !hasRun { + return nil + } + + // Find the table being ranged over and extract test cases with their locations + tableEntries := b.extractTableTestCases(rangeStmt.X, files, info, file, testNameField) + if len(tableEntries) == 0 { + return nil + } + + // Create a test entry for each table entry with its specific location + for _, entry := range tableEntries { var t gobTest - t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val))) + t.Name = b.uniqueName(parent.Name, rewrite(entry.name)) t.Location.URI = file.URI - t.Location.Range, _ = file.NodeRange(call) + t.Location.Range = entry.location tests = append(tests, t) + } - fn, typ, body := findFunc(files, info, body, call.Args[1]) - if typ == nil { - continue + return tests +} + +// isLoopVarFieldAccess checks if expr is a field access on the loop variable, if so returns the field identifier +// (e.g., tt.name where tt is the loop variable). +func (b *indexBuilder) isLoopVarFieldAccess(expr ast.Expr, loopVar types.Object, info *types.Info) *ast.Ident { + sel, ok := expr.(*ast.SelectorExpr) + if !ok { + return nil + } + ident, ok := sel.X.(*ast.Ident) + if !ok { + return nil + } + if info.ObjectOf(ident) != loopVar { + return nil + } + return sel.Sel +} + +// tableTestCase represents a single test case in a table-driven test +type tableTestCase struct { + name string + location protocol.Range +} + +// extractTableTestCases extracts test cases with their locations from a table-driven test slice. +// It handles patterns like: +// - tests := []struct{name string}{{"test1"}, {"test2"}} +// - []struct{name string}{{"test1"}, {"test2"}} +// - For identifier references, attempts to find the composite literal value +func (b *indexBuilder) extractTableTestCases(expr ast.Expr, files []*parsego.File, info *types.Info, file *parsego.File, testNameField *ast.Ident) []tableTestCase { + // Unwrap parentheses + for { + if paren, ok := expr.(*ast.ParenExpr); ok { + expr = paren.X + } else { + break } + } - // Function literals don't have an associated object - if fn == nil { - tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...) - continue + // Handle both direct composite literals and identifiers + var comp *ast.CompositeLit + switch e := expr.(type) { + case *ast.CompositeLit: + comp = e + case *ast.Ident: + // Look for the assignment of this identifier + obj := info.ObjectOf(e) + if obj == nil { + return nil + } + // Find the composite literal from the identifier's definition + comp = b.findCompositeLiteralForIdent(e, files, info) + if comp == nil { + return nil } + default: + return nil + } + + // comp should be a slice composite literal + if comp.Type == nil { + return nil + } - // Never recurse if the second argument is a top-level test function - if isTest, _ := isTestOrExample(fn); isTest { + var cases []tableTestCase + for _, elt := range comp.Elts { + // Each element should be a struct literal + structLit, ok := elt.(*ast.CompositeLit) + if !ok { continue } - // Don't recurse into functions that have already been visited - if b.visited[fn] { + if len(structLit.Elts) == 0 { continue } - b.visited[fn] = true - tests = append(tests, b.findSubtests(t, typ, body, file, files, info)...) + // Try keyed fields first (e.g., {name: "test1", ...}) + for _, field := range structLit.Elts { + kv, ok := field.(*ast.KeyValueExpr) + if !ok { + //TODO: look for unkeyed fields + continue + } + key, ok := kv.Key.(*ast.Ident) + if !ok || key.Name != testNameField.Name { + continue + } + + // Get the location of this test case (the struct literal) + rng, err := file.NodeRange(structLit) + if err != nil { + continue + } + + // Extract the string value + if val := info.Types[kv.Value].Value; val != nil && val.Kind() == constant.String { + cases = append(cases, tableTestCase{ + name: constant.StringVal(val), + location: rng, + }) + break + } + } } - return tests + + return cases +} + +// findCompositeLiteralForIdent finds the composite literal that initializes the given identifier. +// It searches through the files for variable declarations and assignments. +func (b *indexBuilder) findCompositeLiteralForIdent(ident *ast.Ident, files []*parsego.File, info *types.Info) *ast.CompositeLit { + obj := info.ObjectOf(ident) + if obj == nil { + return nil + } + + // Search through all files to find the declaration + for _, file := range files { + // Walk through declarations to find variable declarations + for _, decl := range file.File.Decls { + // Check function declarations (where local variables are declared) + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil { + continue + } + + // Walk through statements in the function body + for _, stmt := range funcDecl.Body.List { + // Check for short variable declaration: tests := ... + if assign, ok := stmt.(*ast.AssignStmt); ok && assign.Tok == token.DEFINE { + for i, lhs := range assign.Lhs { + if lhsIdent, ok := lhs.(*ast.Ident); ok && info.ObjectOf(lhsIdent) == obj { + // Found the declaration, check if RHS is a composite literal + if i < len(assign.Rhs) { + if comp, ok := assign.Rhs[i].(*ast.CompositeLit); ok { + return comp + } + } + } + } + } + + // Check for var declaration: var tests = ... + if declStmt, ok := stmt.(*ast.DeclStmt); ok { + if genDecl, ok := declStmt.Decl.(*ast.GenDecl); ok && genDecl.Tok == token.VAR { + for _, spec := range genDecl.Specs { + if valueSpec, ok := spec.(*ast.ValueSpec); ok { + for i, name := range valueSpec.Names { + if info.ObjectOf(name) == obj && i < len(valueSpec.Values) { + if comp, ok := valueSpec.Values[i].(*ast.CompositeLit); ok { + return comp + } + } + } + } + } + } + } + } + } + } + + return nil } // findFunc finds the type and body of the given expr, which may be a function diff --git a/gopls/internal/golang/code_lens.go b/gopls/internal/golang/code_lens.go index b04724e0cbc..203f57b9b6a 100644 --- a/gopls/internal/golang/code_lens.go +++ b/gopls/internal/golang/code_lens.go @@ -55,6 +55,13 @@ func runTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Hand codeLens = append(codeLens, protocol.CodeLens{Range: rng, Command: cmd}) } + // Add code lenses for subtests (including table-driven subtests) + subtestLenses, err := subtestCodeLenses(ctx, snapshot, pkg, puri) + if err != nil { + return nil, err + } + codeLens = append(codeLens, subtestLenses...) + for _, fn := range benchFuncs { cmd := command.NewRunTestsCommand("run benchmark", command.RunTestsArgs{ URI: puri, @@ -154,6 +161,46 @@ func matchTestFunc(fn *ast.FuncDecl, info *types.Info, nameRe *regexp.Regexp, pa return namedObj.Id() == paramID } +// subtestCodeLenses returns code lenses for subtests, including table-driven subtests. +func subtestCodeLenses(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, uri protocol.DocumentURI) ([]protocol.CodeLens, error) { + // Get test index which includes subtests + indexes, err := snapshot.Tests(ctx, pkg.Metadata().ID) + if err != nil { + return nil, err + } + if len(indexes) == 0 { + return nil, nil + } + + var codeLens []protocol.CodeLens + for _, idx := range indexes { + if idx == nil { + continue + } + for _, result := range idx.All() { + // Only show code lenses for subtests in the current file + if result.Location.URI != uri { + continue + } + + // Skip top-level tests (they already have code lenses) + if !strings.Contains(result.Name, "/") { + continue + } + + // Create a code lens for this subtest + cmd := command.NewRunTestsCommand("run subtest", command.RunTestsArgs{ + URI: uri, + Tests: []string{result.Name}, + }) + rng := protocol.Range{Start: result.Location.Range.Start, End: result.Location.Range.Start} + codeLens = append(codeLens, protocol.CodeLens{Range: rng, Command: cmd}) + } + } + + return codeLens, nil +} + func goGenerateCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]protocol.CodeLens, error) { pgf, err := snapshot.ParseGo(ctx, fh, parsego.Full) if err != nil {