diff --git a/AGENTS.md b/AGENTS.md index de26f97..9577fdd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,7 +11,7 @@ Treehouse is a Go CLI tool that manages a pool of git worktrees for parallel AI - `internal/config/` — config file loading (`treehouse.toml`) - `internal/pool/` — pool manager (acquire, release, list, destroy) + state file - `internal/git/` — git operations (shells out to `git` binary) -- `internal/process/` — in-use detection via process cwd scanning +- `internal/process/` — in-use detection and lingering process termination for worktrees - `internal/shell/` — subshell spawning - `internal/ui/` — Y/n confirmation prompts diff --git a/README.md b/README.md index d46f92f..4ccdfc8 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ $ treehouse # get a worktree and drop into a subshell # Run your AI agent, make changes, do whatever you need. $ exit # exit the subshell when you're done +🌳 Terminated lingering processes: opencode (pid 12345) 🌳 Worktree returned to pool. ``` @@ -121,7 +122,9 @@ Treehouse manages a **pool of git worktrees** per repository, stored under `~/.t exit subshell │ ▼ - Reset worktree & return to pool + Terminate lingering worktree + processes, reset worktree, + & return to pool (ready for next agent) ``` @@ -136,7 +139,7 @@ Treehouse manages a **pool of git worktrees** per repository, stored under `~/.t | `treehouse` | Get a worktree and open a subshell (alias for `get`) | | `treehouse get` | Acquire a worktree from the pool | | `treehouse status` | Show pool status (highlights your current worktree) | -| `treehouse return [path]` | Return a worktree to the pool | +| `treehouse return [path]` | Terminate lingering worktree processes and return it to the pool | | `treehouse destroy [path]` | Remove a worktree from the pool | | `treehouse init` | Create a default `treehouse.toml` config file | | `treehouse update` | Update treehouse to the latest version | diff --git a/cmd/e2e_test.go b/cmd/e2e_test.go index 53f3b46..03cf9fb 100644 --- a/cmd/e2e_test.go +++ b/cmd/e2e_test.go @@ -106,9 +106,14 @@ func setupTestRepo(t *testing.T) (repoDir, homeDir string) { // HOME (or USERPROFILE on Windows) is set to homeDir so pool state is isolated. func runTreehouse(t *testing.T, repoDir, homeDir string, extraEnv []string, args ...string) (stdout, stderr string, exitCode int) { t.Helper() + return runTreehouseFromDir(t, repoDir, repoDir, homeDir, extraEnv, args...) +} + +func runTreehouseFromDir(t *testing.T, repoDir, workDir, homeDir string, extraEnv []string, args ...string) (stdout, stderr string, exitCode int) { + t.Helper() cmd := exec.Command(treehouseBin, args...) - cmd.Dir = repoDir + cmd.Dir = workDir cmd.Env = buildEnv(homeDir, extraEnv...) var outBuf, errBuf bytes.Buffer @@ -312,6 +317,31 @@ func TestGetReusesWorktree(t *testing.T) { } } +func TestReturnFromInsideWorktreeDoesNotTerminateCaller(t *testing.T) { + repoDir, homeDir := setupTestRepo(t) + env := []string{"SHELL=" + exitShellBin} + + _, getErr, code := runTreehouse(t, repoDir, homeDir, env, "get") + if code != 0 { + t.Fatalf("get failed (code %d): %s", code, getErr) + } + wtPath := extractWorktreePath(getErr, homeDir) + if wtPath == "" { + t.Fatal("could not extract worktree path") + } + + _, returnErr, code := runTreehouseFromDir(t, repoDir, wtPath, homeDir, nil, "return", "--force") + if code != 0 { + t.Fatalf("return from inside worktree failed (code %d): %s", code, returnErr) + } + if !strings.Contains(returnErr, "Worktree returned to pool") { + t.Fatalf("expected return confirmation, got: %s", returnErr) + } + if strings.Contains(returnErr, "Terminated lingering processes") && strings.Contains(returnErr, "treehouse") { + t.Fatalf("return should not terminate its own process chain: %s", returnErr) + } +} + func TestDestroySpecific(t *testing.T) { repoDir, homeDir := setupTestRepo(t) env := []string{"SHELL=" + exitShellBin} diff --git a/cmd/get.go b/cmd/get.go index 7ff0b3d..b4bfab1 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -4,12 +4,15 @@ import ( "fmt" "os" "path/filepath" + "strings" + "time" "github.com/spf13/cobra" "github.com/kunchenguid/treehouse/internal/config" "github.com/kunchenguid/treehouse/internal/git" "github.com/kunchenguid/treehouse/internal/pool" + "github.com/kunchenguid/treehouse/internal/process" "github.com/kunchenguid/treehouse/internal/shell" "github.com/kunchenguid/treehouse/internal/ui" ) @@ -68,6 +71,8 @@ func getRunE(cmd *cobra.Command, args []string) error { } } + killLingeringProcesses(wtPath) + if err := pool.Release(poolDir, wtPath); err != nil { fmt.Fprintf(os.Stderr, "🌳 Warning: failed to clean worktree: %v\n", err) } else { @@ -76,3 +81,22 @@ func getRunE(cmd *cobra.Command, args []string) error { return nil } + +// killLingeringProcesses terminates any process whose cwd is within the given +// worktree. Called before returning a worktree to the pool so detached tools +// (e.g. opencode servers that ignore SIGHUP) don't keep holding the worktree. +func killLingeringProcesses(wtPath string) { + killed, err := process.TerminateWorktreeProcesses(wtPath, 2*time.Second) + if err != nil { + fmt.Fprintf(os.Stderr, "🌳 Warning: failed to scan for lingering processes: %v\n", err) + return + } + if len(killed) == 0 { + return + } + names := make([]string, len(killed)) + for i, p := range killed { + names[i] = p.String() + } + fmt.Fprintf(os.Stderr, "🌳 Terminated lingering processes: %s\n", strings.Join(names, ", ")) +} diff --git a/cmd/return_cmd.go b/cmd/return_cmd.go index 5beb51d..910bef7 100644 --- a/cmd/return_cmd.go +++ b/cmd/return_cmd.go @@ -17,7 +17,7 @@ var returnForce bool var returnCmd = &cobra.Command{ Use: "return [path]", - Short: "Return a worktree to the pool", + Short: "Terminate lingering processes and return a worktree", RunE: func(cmd *cobra.Command, args []string) error { wtPath, err := resolveWorktreePath(args) if err != nil { @@ -55,6 +55,8 @@ var returnCmd = &cobra.Command{ } } + killLingeringProcesses(wtPath) + if err := pool.Release(poolDir, wtPath); err != nil { return fmt.Errorf("failed to return worktree: %w", err) } diff --git a/internal/process/detect.go b/internal/process/detect.go index 15cb660..cabebe3 100644 --- a/internal/process/detect.go +++ b/internal/process/detect.go @@ -35,6 +35,7 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) { if err != nil { return nil, err } + absWorktree = resolvePath(absWorktree) var result []ProcessInfo @@ -48,6 +49,7 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) { if err != nil { continue } + absCwd = resolvePath(absCwd) rel, err := filepath.Rel(absWorktree, absCwd) if err != nil { @@ -65,3 +67,14 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) { return result, nil } + +// resolvePath returns the symlink-resolved path, or the input if resolution +// fails (e.g. path doesn't exist). This lets us match process cwds (which +// gopsutil returns canonicalized, e.g. /private/var/... on macOS) against +// caller-supplied worktree paths that may still contain symlinks. +func resolvePath(p string) string { + if resolved, err := filepath.EvalSymlinks(p); err == nil { + return resolved + } + return p +} diff --git a/internal/process/detect_unix_test.go b/internal/process/detect_unix_test.go new file mode 100644 index 0000000..a12660c --- /dev/null +++ b/internal/process/detect_unix_test.go @@ -0,0 +1,52 @@ +//go:build !windows + +package process + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + "time" +) + +// FindProcessesInWorktree should match a process whose cwd resolves to the +// same real path as the worktree, even when the caller passes a symlinked +// worktree path. This also covers macOS /tmp -> /private/tmp. +func TestFindProcessesInWorktree_ResolvesSymlinks(t *testing.T) { + realDir := t.TempDir() + + linkDir := filepath.Join(t.TempDir(), "link") + if err := os.Symlink(realDir, linkDir); err != nil { + t.Fatalf("symlink: %v", err) + } + + cmd := exec.Command("sleep", "60") + cmd.Dir = realDir + if err := cmd.Start(); err != nil { + t.Skipf("cannot start sleep: %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + time.Sleep(200 * time.Millisecond) + + procs, err := FindProcessesInWorktree(linkDir) + if err != nil { + t.Fatalf("FindProcessesInWorktree: %v", err) + } + + var found bool + for _, p := range procs { + if int(p.PID) == cmd.Process.Pid { + found = true + break + } + } + if !found { + t.Fatalf("expected to find pid %d via symlinked path %q, got %v", + cmd.Process.Pid, linkDir, procs) + } +} diff --git a/internal/process/terminate.go b/internal/process/terminate.go new file mode 100644 index 0000000..7318528 --- /dev/null +++ b/internal/process/terminate.go @@ -0,0 +1,74 @@ +package process + +import ( + "os" + "time" + + gopsutilprocess "github.com/shirou/gopsutil/v4/process" +) + +// TerminateWorktreeProcesses finds every process whose cwd is within the given +// worktree path and terminates them. +// +// On unix it sends SIGTERM, waits up to gracePeriod for processes to exit, +// then SIGKILLs any survivors. On windows it uses TerminateProcess. +// +// Returns the list of processes that were targeted. Errors only if the initial +// scan fails; individual kill failures (e.g. process already gone) are +// swallowed. +func TerminateWorktreeProcesses(worktreePath string, gracePeriod time.Duration) ([]ProcessInfo, error) { + procs, err := FindProcessesInWorktree(worktreePath) + if err != nil { + return nil, err + } + procs = filterProtectedProcesses(procs, int32(os.Getpid()), parentPID) + if len(procs) == 0 { + return nil, nil + } + + pids := make([]int32, len(procs)) + for i, p := range procs { + pids[i] = p.PID + } + + terminate(pids, gracePeriod) + return procs, nil +} + +func filterProtectedProcesses(procs []ProcessInfo, currentPID int32, lookupParent func(int32) (int32, error)) []ProcessInfo { + protected := map[int32]struct{}{ + currentPID: {}, + } + + for pid := currentPID; pid > 0; { + parent, err := lookupParent(pid) + if err != nil { + return nil + } + if parent <= 0 { + break + } + if _, seen := protected[parent]; seen { + break + } + protected[parent] = struct{}{} + pid = parent + } + + filtered := procs[:0] + for _, proc := range procs { + if _, skip := protected[proc.PID]; skip { + continue + } + filtered = append(filtered, proc) + } + return filtered +} + +func parentPID(pid int32) (int32, error) { + proc, err := gopsutilprocess.NewProcess(pid) + if err != nil { + return 0, err + } + return proc.Ppid() +} diff --git a/internal/process/terminate_test.go b/internal/process/terminate_test.go new file mode 100644 index 0000000..fa7155d --- /dev/null +++ b/internal/process/terminate_test.go @@ -0,0 +1,56 @@ +package process + +import ( + "errors" + "testing" +) + +func TestFilterProtectedProcesses_SkipsCurrentProcessAndAncestors(t *testing.T) { + procs := []ProcessInfo{ + {PID: 100, Name: "shell"}, + {PID: 200, Name: "treehouse"}, + {PID: 300, Name: "server"}, + } + + filtered := filterProtectedProcesses(procs, 200, func(pid int32) (int32, error) { + switch pid { + case 200: + return 100, nil + case 100: + return 1, nil + case 1: + return 0, nil + default: + return 0, errors.New("unknown pid") + } + }) + + if len(filtered) != 1 { + t.Fatalf("expected 1 process after filtering, got %d", len(filtered)) + } + if filtered[0].PID != 300 { + t.Fatalf("expected pid 300 to remain, got %d", filtered[0].PID) + } + if filtered[0].Name != "server" { + t.Fatalf("expected server to remain, got %q", filtered[0].Name) + } +} + +func TestFilterProtectedProcesses_SkipsTerminationWhenParentLookupFails(t *testing.T) { + procs := []ProcessInfo{ + {PID: 100, Name: "shell"}, + {PID: 200, Name: "treehouse"}, + {PID: 300, Name: "server"}, + } + + filtered := filterProtectedProcesses(procs, 200, func(pid int32) (int32, error) { + if pid == 200 { + return 0, errors.New("cannot inspect parent") + } + return 0, nil + }) + + if len(filtered) != 0 { + t.Fatalf("expected no processes after filtering, got %+v", filtered) + } +} diff --git a/internal/process/terminate_unix.go b/internal/process/terminate_unix.go new file mode 100644 index 0000000..b86de74 --- /dev/null +++ b/internal/process/terminate_unix.go @@ -0,0 +1,42 @@ +//go:build !windows + +package process + +import ( + "syscall" + "time" +) + +func terminate(pids []int32, gracePeriod time.Duration) { + for _, pid := range pids { + _ = syscall.Kill(int(pid), syscall.SIGTERM) + } + + deadline := time.Now().Add(gracePeriod) + for time.Now().Before(deadline) { + if !anyAlive(pids) { + return + } + time.Sleep(50 * time.Millisecond) + } + + for _, pid := range pids { + if isAlive(pid) { + _ = syscall.Kill(int(pid), syscall.SIGKILL) + } + } +} + +// isAlive uses signal 0 which validates process existence without signaling it. +func isAlive(pid int32) bool { + return syscall.Kill(int(pid), 0) == nil +} + +func anyAlive(pids []int32) bool { + for _, pid := range pids { + if isAlive(pid) { + return true + } + } + return false +} diff --git a/internal/process/terminate_unix_test.go b/internal/process/terminate_unix_test.go new file mode 100644 index 0000000..888f7c6 --- /dev/null +++ b/internal/process/terminate_unix_test.go @@ -0,0 +1,114 @@ +//go:build !windows + +package process + +import ( + "os/exec" + "syscall" + "testing" + "time" +) + +func TestTerminateWorktreeProcesses_NoProcesses(t *testing.T) { + dir := t.TempDir() + procs, err := TerminateWorktreeProcesses(dir, 1*time.Second) + if err != nil { + t.Fatalf("TerminateWorktreeProcesses: %v", err) + } + if len(procs) != 0 { + t.Errorf("expected 0 processes, got %d", len(procs)) + } +} + +func TestTerminateWorktreeProcesses_KillsProcessInWorktree(t *testing.T) { + dir := t.TempDir() + + cmd := exec.Command("sleep", "60") + cmd.Dir = dir + if err := cmd.Start(); err != nil { + t.Skipf("cannot start sleep: %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + // Give the OS a moment to record the process cwd. + time.Sleep(200 * time.Millisecond) + + procs, err := TerminateWorktreeProcesses(dir, 2*time.Second) + if err != nil { + t.Fatalf("TerminateWorktreeProcesses: %v", err) + } + if len(procs) == 0 { + t.Fatal("expected at least one process, got none") + } + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("process was not terminated within 5s") + } +} + +func TestTerminateWorktreeProcesses_EscalatesToKill(t *testing.T) { + dir := t.TempDir() + + // Ignore SIGTERM; only SIGKILL should end this process. + cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 60") + cmd.Dir = dir + if err := cmd.Start(); err != nil { + t.Skipf("cannot start sh: %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + time.Sleep(200 * time.Millisecond) + + start := time.Now() + procs, err := TerminateWorktreeProcesses(dir, 500*time.Millisecond) + if err != nil { + t.Fatalf("TerminateWorktreeProcesses: %v", err) + } + if len(procs) == 0 { + t.Fatal("expected processes to target, got none") + } + if elapsed := time.Since(start); elapsed < 400*time.Millisecond { + t.Errorf("expected grace period to elapse before SIGKILL, only took %s", elapsed) + } + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("process was not killed after escalation") + } +} + +func TestTerminateWorktreeProcesses_AlreadyDeadPIDIsNoop(t *testing.T) { + // Kill our own fake PID that won't exist; Terminate should not error. + // We simulate by starting a process, killing it, then calling Terminate + // on the same cwd - FindProcessesInWorktree will return nothing, but we + // still cover the "dead pids" case by exercising the scan-empty path. + dir := t.TempDir() + cmd := exec.Command("sleep", "60") + cmd.Dir = dir + if err := cmd.Start(); err != nil { + t.Skipf("cannot start sleep: %v", err) + } + _ = cmd.Process.Signal(syscall.SIGKILL) + _ = cmd.Wait() + + procs, err := TerminateWorktreeProcesses(dir, 500*time.Millisecond) + if err != nil { + t.Fatalf("TerminateWorktreeProcesses: %v", err) + } + if len(procs) != 0 { + t.Errorf("expected 0 processes after kill, got %d", len(procs)) + } +} diff --git a/internal/process/terminate_windows.go b/internal/process/terminate_windows.go new file mode 100644 index 0000000..155e497 --- /dev/null +++ b/internal/process/terminate_windows.go @@ -0,0 +1,22 @@ +//go:build windows + +package process + +import ( + "time" + + "golang.org/x/sys/windows" +) + +func terminate(pids []int32, _ time.Duration) { + // Windows has no graceful SIGTERM equivalent for arbitrary processes; + // TerminateProcess is the standard way to end a process from outside it. + for _, pid := range pids { + h, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid)) + if err != nil { + continue + } + _ = windows.TerminateProcess(h, 1) + _ = windows.CloseHandle(h) + } +}