Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(lsp): file incorrectly updated when importing modules, fixes #135 #136

Merged
merged 7 commits into from
Sep 5, 2023
293 changes: 293 additions & 0 deletions cmd/templ/lspcmd/proxy/import_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
package proxy

import (
"fmt"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestFindLastImport(t *testing.T) {
tests := []struct {
name string
templContents string
packageName string
expected string
}{
{
name: "if there are no imports, add a single line import",
templContents: `package main

templ example() {
}
`,
packageName: "strings",
expected: `package main

import "strings"

templ example() {
}
`,
},
{
name: "if there is an existing single-line imports, add one at the end",
templContents: `package main

import "strings"

templ example() {
}
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

templ example() {
}
`,
},
{
name: "if there are multiple existing single-line imports, add one at the end",
templContents: `package main

import "strings"
import "fmt"

templ example() {
}
`,
packageName: "time",
expected: `package main

import "strings"
import "fmt"
import "time"

templ example() {
}
`,
},
{
name: "if there are existing multi-line imports, add one at the end",
templContents: `package main

import (
"strings"
)

templ example() {
}
`,
packageName: "fmt",
expected: `package main

import (
"strings"
"fmt"
)

templ example() {
}
`,
},
{
name: "ignore imports that happen after templates",
templContents: `package main

import "strings"

templ example() {
}

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

templ example() {
}

import "other"
`,
},
{
name: "ignore imports that happen after funcs in the file",
templContents: `package main

import "strings"

func example() {
}

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

func example() {
}

import "other"
`,
},
{
name: "ignore imports that happen after css expressions in the file",
templContents: `package main

import "strings"

css example() {
}

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

css example() {
}

import "other"
`,
},
{
name: "ignore imports that happen after script expressions in the file",
templContents: `package main

import "strings"

script example() {
}

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

script example() {
}

import "other"
`,
},
{
name: "ignore imports that happen after var expressions in the file",
templContents: `package main

import "strings"

var s string

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

var s string

import "other"
`,
},
{
name: "ignore imports that happen after const expressions in the file",
templContents: `package main

import "strings"

const s = "test"

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

const s = "test"

import "other"
`,
},
{
name: "ignore imports that happen after type expressions in the file",
templContents: `package main

import "strings"

type Value int

import "other"
`,
packageName: "fmt",
expected: `package main

import "strings"
import "fmt"

type Value int

import "other"
`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
lines := strings.Split(test.templContents, "\n")
imp := addImport(lines, fmt.Sprintf("%q", test.packageName))
textWithoutNewline := strings.TrimSuffix(imp.Text, "\n")
actualLines := append(lines[:imp.LineIndex], append([]string{textWithoutNewline}, lines[imp.LineIndex:]...)...)
actual := strings.Join(actualLines, "\n")
if diff := cmp.Diff(test.expected, actual); diff != "" {
t.Error(diff)
}
})
}
}

func TestGetPackageFromItemDetail(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: `"fmt"`,
expected: `"fmt"`,
},
{
input: `func(state fmt.State, verb rune) string (from "fmt")`,
expected: `"fmt"`,
},
{
input: `non matching`,
expected: `non matching`,
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
actual := getPackageFromItemDetail(test.input)
if test.expected != actual {
t.Errorf("expected %q, got %q", test.expected, actual)
}
})
}
}
68 changes: 68 additions & 0 deletions cmd/templ/lspcmd/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proxy
import (
"context"
"fmt"
"regexp"
"strings"

"github.com/a-h/parse"
Expand Down Expand Up @@ -338,11 +339,78 @@ func (p *Server) Completion(ctx context.Context, params *lsp.CompletionParams) (
if item.TextEdit != nil {
item.TextEdit.Range = p.convertGoRangeToTemplRange(templURI, item.TextEdit.Range)
}
if len(item.AdditionalTextEdits) > 0 {
doc, ok := p.TemplSource.Get(string(templURI))
if !ok {
continue
}
pkg := getPackageFromItemDetail(item.Detail)
imp := addImport(doc.Lines, pkg)
item.AdditionalTextEdits = []lsp.TextEdit{
{
Range: lsp.Range{
Start: lsp.Position{Line: uint32(imp.LineIndex), Character: 0},
End: lsp.Position{Line: uint32(imp.LineIndex), Character: 0},
},
NewText: imp.Text,
},
}
}
result.Items[i] = item
}
return
}

var completionWithImport = regexp.MustCompile(`^.*\(from\s(".+")\)$`)

func getPackageFromItemDetail(pkg string) string {
if m := completionWithImport.FindStringSubmatch(pkg); len(m) == 2 {
return m[1]
}
return pkg
}

type importInsert struct {
LineIndex int
Text string
}

var nonImportKeywordRegexp = regexp.MustCompile(`^(?:templ|func|css|script|var|const|type)\s`)

func addImport(lines []string, pkg string) (result importInsert) {
var isInMultiLineImport bool
lastSingleLineImportIndex := -1
for lineIndex, line := range lines {
if strings.HasPrefix(line, "import (") {
isInMultiLineImport = true
continue
}
if strings.HasPrefix(line, "import \"") {
lastSingleLineImportIndex = lineIndex
continue
}
if isInMultiLineImport && strings.HasPrefix(line, ")") {
return importInsert{
LineIndex: lineIndex,
Text: fmt.Sprintf("\t%s\n", pkg),
}
}
// Only add import statements before templates, functions, css, and script templates.
if nonImportKeywordRegexp.MatchString(line) {
break
}
}
var suffix string
if lastSingleLineImportIndex == -1 {
lastSingleLineImportIndex = 1
suffix = "\n"
}
return importInsert{
LineIndex: lastSingleLineImportIndex + 1,
Text: fmt.Sprintf("import %s\n%s", pkg, suffix),
}
}

func (p *Server) CompletionResolve(ctx context.Context, params *lsp.CompletionItem) (result *lsp.CompletionItem, err error) {
p.Log.Info("client -> server: CompletionResolve")
defer p.Log.Info("client -> server: CompletionResolve end")
Expand Down