Skip to content

Commit b79133d

Browse files
committed
prevent races between Start and [Force]Stop
1 parent f77786f commit b79133d

File tree

2 files changed

+120
-30
lines changed

2 files changed

+120
-30
lines changed

go/client.go

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import (
4141
"strconv"
4242
"strings"
4343
"sync"
44+
"sync/atomic"
4445
"time"
4546

4647
"github.com/github/copilot-sdk/go/internal/embeddedcli"
@@ -86,8 +87,10 @@ type Client struct {
8687
lifecycleHandlers []SessionLifecycleHandler
8788
typedLifecycleHandlers map[SessionLifecycleEventType][]SessionLifecycleHandler
8889
lifecycleHandlersMux sync.Mutex
89-
processDone chan struct{} // closed when CLI process exits
90-
processError error // set before processDone is closed
90+
startStopMux sync.RWMutex // protects process and state during start/[force]stop
91+
processDone chan struct{}
92+
processErrorPtr *error
93+
osProcess atomic.Pointer[os.Process]
9194

9295
// RPC provides typed server-scoped RPC methods.
9396
// This field is nil until the client is connected via Start().
@@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) {
251254
// }
252255
// // Now ready to create sessions
253256
func (c *Client) Start(ctx context.Context) error {
257+
c.startStopMux.Lock()
258+
defer c.startStopMux.Unlock()
259+
254260
if c.state == StateConnected {
255261
return nil
256262
}
@@ -260,21 +266,24 @@ func (c *Client) Start(ctx context.Context) error {
260266
// Only start CLI server process if not connecting to external server
261267
if !c.isExternalServer {
262268
if err := c.startCLIServer(ctx); err != nil {
269+
c.process = nil
263270
c.state = StateError
264271
return err
265272
}
266273
}
267274

268275
// Connect to the server
269276
if err := c.connectToServer(ctx); err != nil {
277+
killErr := c.killProcess()
270278
c.state = StateError
271-
return err
279+
return errors.Join(err, killErr)
272280
}
273281

274282
// Verify protocol version compatibility
275283
if err := c.verifyProtocolVersion(ctx); err != nil {
284+
killErr := c.killProcess()
276285
c.state = StateError
277-
return err
286+
return errors.Join(err, killErr)
278287
}
279288

280289
c.state = StateConnected
@@ -316,13 +325,16 @@ func (c *Client) Stop() error {
316325
c.sessions = make(map[string]*Session)
317326
c.sessionsMux.Unlock()
318327

328+
c.startStopMux.Lock()
329+
defer c.startStopMux.Unlock()
330+
319331
// Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it
320332
if c.process != nil && !c.isExternalServer {
321-
if err := c.process.Process.Kill(); err != nil {
322-
errs = append(errs, fmt.Errorf("failed to kill CLI process: %w", err))
333+
if err := c.killProcess(); err != nil {
334+
errs = append(errs, err)
323335
}
324-
c.process = nil
325336
}
337+
c.process = nil
326338

327339
// Close external TCP connection if exists
328340
if c.isExternalServer && c.conn != nil {
@@ -375,16 +387,27 @@ func (c *Client) Stop() error {
375387
// client.ForceStop()
376388
// }
377389
func (c *Client) ForceStop() {
390+
// Kill the process without waiting for startStopMux, which Start may hold.
391+
// This unblocks any I/O Start is doing (connect, version check).
392+
if p := c.osProcess.Swap(nil); p != nil {
393+
p.Kill()
394+
}
395+
378396
// Clear sessions immediately without trying to destroy them
379397
c.sessionsMux.Lock()
380398
c.sessions = make(map[string]*Session)
381399
c.sessionsMux.Unlock()
382400

401+
c.startStopMux.Lock()
402+
defer c.startStopMux.Unlock()
403+
383404
// Kill CLI process (only if we spawned it)
405+
// This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set
406+
// osProcess yet), or if the process was restarted and osProcess now points to a new process.
384407
if c.process != nil && !c.isExternalServer {
385-
c.process.Process.Kill() // Ignore errors
386-
c.process = nil
408+
_ = c.killProcess() // Ignore errors since we're force stopping
387409
}
410+
c.process = nil
388411

389412
// Close external TCP connection if exists
390413
if c.isExternalServer && c.conn != nil {
@@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) {
886909
// })
887910
// }
888911
func (c *Client) State() ConnectionState {
912+
c.startStopMux.RLock()
913+
defer c.startStopMux.RUnlock()
889914
return c.state
890915
}
891916

@@ -1096,27 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
10961121
return fmt.Errorf("failed to start CLI server: %w", err)
10971122
}
10981123

1099-
// Monitor process exit to signal pending requests
1100-
c.processDone = make(chan struct{})
1101-
// Capturing a stable reference to the process for the goroutine prevents
1102-
// a race: c.process can be assigned nil in [Force]Stop() while the
1103-
// goroutine is starting. It's okay for this goroutine to Wait on the
1104-
// process in that case because [Force]Stop() kills the process, causing
1105-
// Wait to return immediately.
1106-
proc := c.process
1107-
go func() {
1108-
waitErr := proc.Wait()
1109-
if waitErr != nil {
1110-
c.processError = fmt.Errorf("CLI process exited: %v", waitErr)
1111-
} else {
1112-
c.processError = fmt.Errorf("CLI process exited unexpectedly")
1113-
}
1114-
close(c.processDone)
1115-
}()
1124+
c.monitorProcess()
11161125

11171126
// Create JSON-RPC client immediately
11181127
c.client = jsonrpc2.NewClient(stdin, stdout)
1119-
c.client.SetProcessDone(c.processDone, &c.processError)
1128+
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
11201129
c.RPC = rpc.NewServerRpc(c.client)
11211130
c.setupNotificationHandler()
11221131
c.client.Start()
@@ -1133,22 +1142,25 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11331142
return fmt.Errorf("failed to start CLI server: %w", err)
11341143
}
11351144

1136-
// Wait for port announcement
1145+
c.monitorProcess()
1146+
11371147
scanner := bufio.NewScanner(stdout)
11381148
timeout := time.After(10 * time.Second)
11391149
portRegex := regexp.MustCompile(`listening on port (\d+)`)
11401150

11411151
for {
11421152
select {
11431153
case <-timeout:
1144-
return fmt.Errorf("timeout waiting for CLI server to start")
1154+
killErr := c.killProcess()
1155+
return errors.Join(fmt.Errorf("timeout waiting for CLI server to start"), killErr)
11451156
default:
11461157
if scanner.Scan() {
11471158
line := scanner.Text()
11481159
if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 {
11491160
port, err := strconv.Atoi(matches[1])
11501161
if err != nil {
1151-
return fmt.Errorf("failed to parse port: %w", err)
1162+
killErr := c.killProcess()
1163+
return errors.Join(fmt.Errorf("failed to parse port: %w", err), killErr)
11521164
}
11531165
c.actualPort = port
11541166
return nil
@@ -1159,6 +1171,39 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11591171
}
11601172
}
11611173

1174+
func (c *Client) killProcess() error {
1175+
if p := c.osProcess.Swap(nil); p != nil {
1176+
if err := p.Kill(); err != nil {
1177+
return fmt.Errorf("failed to kill CLI process: %w", err)
1178+
}
1179+
}
1180+
c.process = nil
1181+
return nil
1182+
}
1183+
1184+
// monitorProcess signals when the CLI process exits and captures any exit error.
1185+
// processError is intentionally a local: each process lifecycle gets its own
1186+
// error value, so goroutines from previous processes can't overwrite the
1187+
// current one. Closing the channel synchronizes with readers, guaranteeing
1188+
// they see the final processError value.
1189+
func (c *Client) monitorProcess() {
1190+
done := make(chan struct{})
1191+
c.processDone = done
1192+
proc := c.process
1193+
c.osProcess.Store(proc.Process)
1194+
var processError error
1195+
c.processErrorPtr = &processError
1196+
go func() {
1197+
waitErr := proc.Wait()
1198+
if waitErr != nil {
1199+
processError = fmt.Errorf("CLI process exited: %w", waitErr)
1200+
} else {
1201+
processError = errors.New("CLI process exited unexpectedly")
1202+
}
1203+
close(done)
1204+
}()
1205+
}
1206+
11621207
// connectToServer establishes a connection to the server.
11631208
func (c *Client) connectToServer(ctx context.Context) error {
11641209
if c.useStdio {
@@ -1190,6 +1235,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error {
11901235

11911236
// Create JSON-RPC client with the connection
11921237
c.client = jsonrpc2.NewClient(conn, conn)
1238+
if c.processDone != nil {
1239+
c.client.SetProcessDone(c.processDone, c.processErrorPtr)
1240+
}
11931241
c.RPC = rpc.NewServerRpc(c.client)
11941242
c.setupNotificationHandler()
11951243
c.client.Start()

go/client_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"path/filepath"
77
"reflect"
88
"regexp"
9+
"sync"
910
"testing"
1011
)
1112

@@ -486,3 +487,44 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) {
486487
}
487488
})
488489
}
490+
491+
func TestClient_StartStopRace(t *testing.T) {
492+
cliPath := findCLIPathForTest()
493+
if cliPath == "" {
494+
t.Skip("CLI not found")
495+
}
496+
client := NewClient(&ClientOptions{CLIPath: cliPath})
497+
defer client.ForceStop()
498+
errChan := make(chan error)
499+
wg := sync.WaitGroup{}
500+
for range 10 {
501+
wg.Add(3)
502+
go func() {
503+
defer wg.Done()
504+
if err := client.Start(t.Context()); err != nil {
505+
select {
506+
case errChan <- err:
507+
default:
508+
}
509+
}
510+
}()
511+
go func() {
512+
defer wg.Done()
513+
if err := client.Stop(); err != nil {
514+
select {
515+
case errChan <- err:
516+
default:
517+
}
518+
}
519+
}()
520+
go func() {
521+
defer wg.Done()
522+
client.ForceStop()
523+
}()
524+
}
525+
wg.Wait()
526+
close(errChan)
527+
if err := <-errChan; err != nil {
528+
t.Fatal(err)
529+
}
530+
}

0 commit comments

Comments
 (0)