diff --git a/cmd/templ/lspcmd/proxy/import_test.go b/cmd/templ/lspcmd/proxy/import_test.go new file mode 100644 index 000000000..c31f17573 --- /dev/null +++ b/cmd/templ/lspcmd/proxy/import_test.go @@ -0,0 +1,293 @@ +package proxy + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestFindLastImport(t *testing.T) { + tests := []struct { + name string + templContents string + packageName string + expected string + }{ + { + name: "if there are no imports, add a single line import", + templContents: `package main + +templ example() { +} +`, + packageName: "strings", + expected: `package main + +import "strings" + +templ example() { +} +`, + }, + { + name: "if there is an existing single-line imports, add one at the end", + templContents: `package main + +import "strings" + +templ example() { +} +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +templ example() { +} +`, + }, + { + name: "if there are multiple existing single-line imports, add one at the end", + templContents: `package main + +import "strings" +import "fmt" + +templ example() { +} +`, + packageName: "time", + expected: `package main + +import "strings" +import "fmt" +import "time" + +templ example() { +} +`, + }, + { + name: "if there are existing multi-line imports, add one at the end", + templContents: `package main + +import ( + "strings" +) + +templ example() { +} +`, + packageName: "fmt", + expected: `package main + +import ( + "strings" + "fmt" +) + +templ example() { +} +`, + }, + { + name: "ignore imports that happen after templates", + templContents: `package main + +import "strings" + +templ example() { +} + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +templ example() { +} + +import "other" +`, + }, + { + name: "ignore imports that happen after funcs in the file", + templContents: `package main + +import "strings" + +func example() { +} + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +func example() { +} + +import "other" +`, + }, + { + name: "ignore imports that happen after css expressions in the file", + templContents: `package main + +import "strings" + +css example() { +} + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +css example() { +} + +import "other" +`, + }, + { + name: "ignore imports that happen after script expressions in the file", + templContents: `package main + +import "strings" + +script example() { +} + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +script example() { +} + +import "other" +`, + }, + { + name: "ignore imports that happen after var expressions in the file", + templContents: `package main + +import "strings" + +var s string + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +var s string + +import "other" +`, + }, + { + name: "ignore imports that happen after const expressions in the file", + templContents: `package main + +import "strings" + +const s = "test" + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +const s = "test" + +import "other" +`, + }, + { + name: "ignore imports that happen after type expressions in the file", + templContents: `package main + +import "strings" + +type Value int + +import "other" +`, + packageName: "fmt", + expected: `package main + +import "strings" +import "fmt" + +type Value int + +import "other" +`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + lines := strings.Split(test.templContents, "\n") + imp := addImport(lines, fmt.Sprintf("%q", test.packageName)) + textWithoutNewline := strings.TrimSuffix(imp.Text, "\n") + actualLines := append(lines[:imp.LineIndex], append([]string{textWithoutNewline}, lines[imp.LineIndex:]...)...) + actual := strings.Join(actualLines, "\n") + if diff := cmp.Diff(test.expected, actual); diff != "" { + t.Error(diff) + } + }) + } +} + +func TestGetPackageFromItemDetail(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + input: `"fmt"`, + expected: `"fmt"`, + }, + { + input: `func(state fmt.State, verb rune) string (from "fmt")`, + expected: `"fmt"`, + }, + { + input: `non matching`, + expected: `non matching`, + }, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + actual := getPackageFromItemDetail(test.input) + if test.expected != actual { + t.Errorf("expected %q, got %q", test.expected, actual) + } + }) + } +} diff --git a/cmd/templ/lspcmd/proxy/server.go b/cmd/templ/lspcmd/proxy/server.go index 2dba884d5..68b4c6766 100644 --- a/cmd/templ/lspcmd/proxy/server.go +++ b/cmd/templ/lspcmd/proxy/server.go @@ -3,6 +3,7 @@ package proxy import ( "context" "fmt" + "regexp" "strings" "github.com/a-h/parse" @@ -338,11 +339,78 @@ func (p *Server) Completion(ctx context.Context, params *lsp.CompletionParams) ( if item.TextEdit != nil { item.TextEdit.Range = p.convertGoRangeToTemplRange(templURI, item.TextEdit.Range) } + if len(item.AdditionalTextEdits) > 0 { + doc, ok := p.TemplSource.Get(string(templURI)) + if !ok { + continue + } + pkg := getPackageFromItemDetail(item.Detail) + imp := addImport(doc.Lines, pkg) + item.AdditionalTextEdits = []lsp.TextEdit{ + { + Range: lsp.Range{ + Start: lsp.Position{Line: uint32(imp.LineIndex), Character: 0}, + End: lsp.Position{Line: uint32(imp.LineIndex), Character: 0}, + }, + NewText: imp.Text, + }, + } + } result.Items[i] = item } return } +var completionWithImport = regexp.MustCompile(`^.*\(from\s(".+")\)$`) + +func getPackageFromItemDetail(pkg string) string { + if m := completionWithImport.FindStringSubmatch(pkg); len(m) == 2 { + return m[1] + } + return pkg +} + +type importInsert struct { + LineIndex int + Text string +} + +var nonImportKeywordRegexp = regexp.MustCompile(`^(?:templ|func|css|script|var|const|type)\s`) + +func addImport(lines []string, pkg string) (result importInsert) { + var isInMultiLineImport bool + lastSingleLineImportIndex := -1 + for lineIndex, line := range lines { + if strings.HasPrefix(line, "import (") { + isInMultiLineImport = true + continue + } + if strings.HasPrefix(line, "import \"") { + lastSingleLineImportIndex = lineIndex + continue + } + if isInMultiLineImport && strings.HasPrefix(line, ")") { + return importInsert{ + LineIndex: lineIndex, + Text: fmt.Sprintf("\t%s\n", pkg), + } + } + // Only add import statements before templates, functions, css, and script templates. + if nonImportKeywordRegexp.MatchString(line) { + break + } + } + var suffix string + if lastSingleLineImportIndex == -1 { + lastSingleLineImportIndex = 1 + suffix = "\n" + } + return importInsert{ + LineIndex: lastSingleLineImportIndex + 1, + Text: fmt.Sprintf("import %s\n%s", pkg, suffix), + } +} + func (p *Server) CompletionResolve(ctx context.Context, params *lsp.CompletionItem) (result *lsp.CompletionItem, err error) { p.Log.Info("client -> server: CompletionResolve") defer p.Log.Info("client -> server: CompletionResolve end")