diff --git a/cmd/templ/generatecmd/symlink/symlink_test.go b/cmd/templ/generatecmd/symlink/symlink_test.go new file mode 100644 index 000000000..eac5bc8bd --- /dev/null +++ b/cmd/templ/generatecmd/symlink/symlink_test.go @@ -0,0 +1,52 @@ +package symlink + +import ( + "context" + "io" + "log/slog" + "os" + "path" + "testing" + + "github.com/a-h/templ/cmd/templ/generatecmd" + "github.com/a-h/templ/cmd/templ/testproject" +) + +func TestSymlink(t *testing.T) { + log := slog.New(slog.NewJSONHandler(io.Discard, nil)) + t.Run("can generate if root is symlink", func(t *testing.T) { + // templ generate -f templates.templ + dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject") + if err != nil { + t.Fatalf("failed to create test project: %v", err) + } + defer os.RemoveAll(dir) + + symlinkPath := dir + "-symlink" + err = os.Symlink(dir, symlinkPath) + if err != nil { + t.Fatalf("failed to create dir symlink: %v", err) + } + defer os.Remove(symlinkPath) + + // Delete the templates_templ.go file to ensure it is generated. + err = os.Remove(path.Join(symlinkPath, "templates_templ.go")) + if err != nil { + t.Fatalf("failed to remove templates_templ.go: %v", err) + } + + // Run the generate command. + err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{ + Path: symlinkPath, + }) + if err != nil { + t.Fatalf("failed to run generate command: %v", err) + } + + // Check the templates_templ.go file was created. + _, err = os.Stat(path.Join(symlinkPath, "templates_templ.go")) + if err != nil { + t.Fatalf("templates_templ.go was not created: %v", err) + } + }) +} diff --git a/cmd/templ/generatecmd/watcher/watch.go b/cmd/templ/generatecmd/watcher/watch.go index 57d725f2a..f97149c8d 100644 --- a/cmd/templ/generatecmd/watcher/watch.go +++ b/cmd/templ/generatecmd/watcher/watch.go @@ -2,6 +2,7 @@ package watcher import ( "context" + "io/fs" "os" "path" "path/filepath" @@ -36,18 +37,24 @@ func Recursive( // WalkFiles walks the file tree rooted at path, sending a Create event for each // file it encounters. func WalkFiles(ctx context.Context, path string, out chan fsnotify.Event) (err error) { - return filepath.WalkDir(path, func(path string, info os.DirEntry, err error) error { + rootPath := path + fileSystem := os.DirFS(rootPath) + return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error { if err != nil { return nil } - if info.IsDir() && shouldSkipDir(path) { + absPath, err := filepath.Abs(filepath.Join(rootPath, path)) + if err != nil { + return nil + } + if info.IsDir() && shouldSkipDir(absPath) { return filepath.SkipDir } - if !shouldIncludeFile(path) { + if !shouldIncludeFile(absPath) { return nil } out <- fsnotify.Event{ - Name: path, + Name: absPath, Op: fsnotify.Create, } return nil