Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Treehouse is a Go CLI tool that manages a pool of git worktrees for parallel AI
- `internal/config/` — config file loading (`treehouse.toml`)
- `internal/pool/` — pool manager (acquire, release, list, destroy) + state file
- `internal/git/` — git operations (shells out to `git` binary)
- `internal/process/` — in-use detection via process cwd scanning
- `internal/process/` — in-use detection and lingering process termination for worktrees
- `internal/shell/` — subshell spawning
- `internal/ui/` — Y/n confirmation prompts

Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ $ treehouse # get a worktree and drop into a subshell
# Run your AI agent, make changes, do whatever you need.

$ exit # exit the subshell when you're done
🌳 Terminated lingering processes: opencode (pid 12345)
🌳 Worktree returned to pool.
```

Expand Down Expand Up @@ -121,7 +122,9 @@ Treehouse manages a **pool of git worktrees** per repository, stored under `~/.t
exit subshell
Reset worktree & return to pool
Terminate lingering worktree
processes, reset worktree,
& return to pool
(ready for next agent)
```

Expand All @@ -136,7 +139,7 @@ Treehouse manages a **pool of git worktrees** per repository, stored under `~/.t
| `treehouse` | Get a worktree and open a subshell (alias for `get`) |
| `treehouse get` | Acquire a worktree from the pool |
| `treehouse status` | Show pool status (highlights your current worktree) |
| `treehouse return [path]` | Return a worktree to the pool |
| `treehouse return [path]` | Terminate lingering worktree processes and return it to the pool |
| `treehouse destroy [path]` | Remove a worktree from the pool |
| `treehouse init` | Create a default `treehouse.toml` config file |
| `treehouse update` | Update treehouse to the latest version |
Expand Down
32 changes: 31 additions & 1 deletion cmd/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ func setupTestRepo(t *testing.T) (repoDir, homeDir string) {
// HOME (or USERPROFILE on Windows) is set to homeDir so pool state is isolated.
func runTreehouse(t *testing.T, repoDir, homeDir string, extraEnv []string, args ...string) (stdout, stderr string, exitCode int) {
t.Helper()
return runTreehouseFromDir(t, repoDir, repoDir, homeDir, extraEnv, args...)
}

func runTreehouseFromDir(t *testing.T, repoDir, workDir, homeDir string, extraEnv []string, args ...string) (stdout, stderr string, exitCode int) {
t.Helper()

cmd := exec.Command(treehouseBin, args...)
cmd.Dir = repoDir
cmd.Dir = workDir
cmd.Env = buildEnv(homeDir, extraEnv...)

var outBuf, errBuf bytes.Buffer
Expand Down Expand Up @@ -312,6 +317,31 @@ func TestGetReusesWorktree(t *testing.T) {
}
}

func TestReturnFromInsideWorktreeDoesNotTerminateCaller(t *testing.T) {
repoDir, homeDir := setupTestRepo(t)
env := []string{"SHELL=" + exitShellBin}

_, getErr, code := runTreehouse(t, repoDir, homeDir, env, "get")
if code != 0 {
t.Fatalf("get failed (code %d): %s", code, getErr)
}
wtPath := extractWorktreePath(getErr, homeDir)
if wtPath == "" {
t.Fatal("could not extract worktree path")
}

_, returnErr, code := runTreehouseFromDir(t, repoDir, wtPath, homeDir, nil, "return", "--force")
if code != 0 {
t.Fatalf("return from inside worktree failed (code %d): %s", code, returnErr)
}
if !strings.Contains(returnErr, "Worktree returned to pool") {
t.Fatalf("expected return confirmation, got: %s", returnErr)
}
if strings.Contains(returnErr, "Terminated lingering processes") && strings.Contains(returnErr, "treehouse") {
t.Fatalf("return should not terminate its own process chain: %s", returnErr)
}
}

func TestDestroySpecific(t *testing.T) {
repoDir, homeDir := setupTestRepo(t)
env := []string{"SHELL=" + exitShellBin}
Expand Down
24 changes: 24 additions & 0 deletions cmd/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/spf13/cobra"

"github.com/kunchenguid/treehouse/internal/config"
"github.com/kunchenguid/treehouse/internal/git"
"github.com/kunchenguid/treehouse/internal/pool"
"github.com/kunchenguid/treehouse/internal/process"
"github.com/kunchenguid/treehouse/internal/shell"
"github.com/kunchenguid/treehouse/internal/ui"
)
Expand Down Expand Up @@ -68,6 +71,8 @@ func getRunE(cmd *cobra.Command, args []string) error {
}
}

killLingeringProcesses(wtPath)

if err := pool.Release(poolDir, wtPath); err != nil {
fmt.Fprintf(os.Stderr, "🌳 Warning: failed to clean worktree: %v\n", err)
} else {
Expand All @@ -76,3 +81,22 @@ func getRunE(cmd *cobra.Command, args []string) error {

return nil
}

// killLingeringProcesses terminates any process whose cwd is within the given
// worktree. Called before returning a worktree to the pool so detached tools
// (e.g. opencode servers that ignore SIGHUP) don't keep holding the worktree.
func killLingeringProcesses(wtPath string) {
killed, err := process.TerminateWorktreeProcesses(wtPath, 2*time.Second)
if err != nil {
fmt.Fprintf(os.Stderr, "🌳 Warning: failed to scan for lingering processes: %v\n", err)
return
}
if len(killed) == 0 {
return
}
names := make([]string, len(killed))
for i, p := range killed {
names[i] = p.String()
}
fmt.Fprintf(os.Stderr, "🌳 Terminated lingering processes: %s\n", strings.Join(names, ", "))
}
4 changes: 3 additions & 1 deletion cmd/return_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var returnForce bool

var returnCmd = &cobra.Command{
Use: "return [path]",
Short: "Return a worktree to the pool",
Short: "Terminate lingering processes and return a worktree",
RunE: func(cmd *cobra.Command, args []string) error {
wtPath, err := resolveWorktreePath(args)
if err != nil {
Expand Down Expand Up @@ -55,6 +55,8 @@ var returnCmd = &cobra.Command{
}
}

killLingeringProcesses(wtPath)

if err := pool.Release(poolDir, wtPath); err != nil {
return fmt.Errorf("failed to return worktree: %w", err)
}
Expand Down
13 changes: 13 additions & 0 deletions internal/process/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) {
if err != nil {
return nil, err
}
absWorktree = resolvePath(absWorktree)

var result []ProcessInfo

Expand All @@ -48,6 +49,7 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) {
if err != nil {
continue
}
absCwd = resolvePath(absCwd)

rel, err := filepath.Rel(absWorktree, absCwd)
if err != nil {
Expand All @@ -65,3 +67,14 @@ func FindProcessesInWorktree(worktreePath string) ([]ProcessInfo, error) {

return result, nil
}

// resolvePath returns the symlink-resolved path, or the input if resolution
// fails (e.g. path doesn't exist). This lets us match process cwds (which
// gopsutil returns canonicalized, e.g. /private/var/... on macOS) against
// caller-supplied worktree paths that may still contain symlinks.
func resolvePath(p string) string {
if resolved, err := filepath.EvalSymlinks(p); err == nil {
return resolved
}
return p
}
52 changes: 52 additions & 0 deletions internal/process/detect_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build !windows

package process

import (
"os"
"os/exec"
"path/filepath"
"testing"
"time"
)

// FindProcessesInWorktree should match a process whose cwd resolves to the
// same real path as the worktree, even when the caller passes a symlinked
// worktree path. This also covers macOS /tmp -> /private/tmp.
func TestFindProcessesInWorktree_ResolvesSymlinks(t *testing.T) {
realDir := t.TempDir()

linkDir := filepath.Join(t.TempDir(), "link")
if err := os.Symlink(realDir, linkDir); err != nil {
t.Fatalf("symlink: %v", err)
}

cmd := exec.Command("sleep", "60")
cmd.Dir = realDir
if err := cmd.Start(); err != nil {
t.Skipf("cannot start sleep: %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_ = cmd.Wait()
})

time.Sleep(200 * time.Millisecond)

procs, err := FindProcessesInWorktree(linkDir)
if err != nil {
t.Fatalf("FindProcessesInWorktree: %v", err)
}

var found bool
for _, p := range procs {
if int(p.PID) == cmd.Process.Pid {
found = true
break
}
}
if !found {
t.Fatalf("expected to find pid %d via symlinked path %q, got %v",
cmd.Process.Pid, linkDir, procs)
}
}
74 changes: 74 additions & 0 deletions internal/process/terminate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package process

import (
"os"
"time"

gopsutilprocess "github.com/shirou/gopsutil/v4/process"
)

// TerminateWorktreeProcesses finds every process whose cwd is within the given
// worktree path and terminates them.
//
// On unix it sends SIGTERM, waits up to gracePeriod for processes to exit,
// then SIGKILLs any survivors. On windows it uses TerminateProcess.
//
// Returns the list of processes that were targeted. Errors only if the initial
// scan fails; individual kill failures (e.g. process already gone) are
// swallowed.
func TerminateWorktreeProcesses(worktreePath string, gracePeriod time.Duration) ([]ProcessInfo, error) {
procs, err := FindProcessesInWorktree(worktreePath)
if err != nil {
return nil, err
}
procs = filterProtectedProcesses(procs, int32(os.Getpid()), parentPID)
if len(procs) == 0 {
return nil, nil
}

pids := make([]int32, len(procs))
for i, p := range procs {
pids[i] = p.PID
}

terminate(pids, gracePeriod)
return procs, nil
}

func filterProtectedProcesses(procs []ProcessInfo, currentPID int32, lookupParent func(int32) (int32, error)) []ProcessInfo {
protected := map[int32]struct{}{
currentPID: {},
}

for pid := currentPID; pid > 0; {
parent, err := lookupParent(pid)
if err != nil {
return nil
}
if parent <= 0 {
break
}
if _, seen := protected[parent]; seen {
break
}
protected[parent] = struct{}{}
pid = parent
}

filtered := procs[:0]
for _, proc := range procs {
if _, skip := protected[proc.PID]; skip {
continue
}
filtered = append(filtered, proc)
}
return filtered
}

func parentPID(pid int32) (int32, error) {
proc, err := gopsutilprocess.NewProcess(pid)
if err != nil {
return 0, err
}
return proc.Ppid()
}
56 changes: 56 additions & 0 deletions internal/process/terminate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package process

import (
"errors"
"testing"
)

func TestFilterProtectedProcesses_SkipsCurrentProcessAndAncestors(t *testing.T) {
procs := []ProcessInfo{
{PID: 100, Name: "shell"},
{PID: 200, Name: "treehouse"},
{PID: 300, Name: "server"},
}

filtered := filterProtectedProcesses(procs, 200, func(pid int32) (int32, error) {
switch pid {
case 200:
return 100, nil
case 100:
return 1, nil
case 1:
return 0, nil
default:
return 0, errors.New("unknown pid")
}
})

if len(filtered) != 1 {
t.Fatalf("expected 1 process after filtering, got %d", len(filtered))
}
if filtered[0].PID != 300 {
t.Fatalf("expected pid 300 to remain, got %d", filtered[0].PID)
}
if filtered[0].Name != "server" {
t.Fatalf("expected server to remain, got %q", filtered[0].Name)
}
}

func TestFilterProtectedProcesses_SkipsTerminationWhenParentLookupFails(t *testing.T) {
procs := []ProcessInfo{
{PID: 100, Name: "shell"},
{PID: 200, Name: "treehouse"},
{PID: 300, Name: "server"},
}

filtered := filterProtectedProcesses(procs, 200, func(pid int32) (int32, error) {
if pid == 200 {
return 0, errors.New("cannot inspect parent")
}
return 0, nil
})

if len(filtered) != 0 {
t.Fatalf("expected no processes after filtering, got %+v", filtered)
}
}
Loading
Loading