@@ -41,6 +41,7 @@ import (
4141 "strconv"
4242 "strings"
4343 "sync"
44+ "sync/atomic"
4445 "time"
4546
4647 "github.com/github/copilot-sdk/go/internal/embeddedcli"
@@ -86,8 +87,10 @@ type Client struct {
8687 lifecycleHandlers []SessionLifecycleHandler
8788 typedLifecycleHandlers map [SessionLifecycleEventType ][]SessionLifecycleHandler
8889 lifecycleHandlersMux sync.Mutex
89- processDone chan struct {} // closed when CLI process exits
90- processError error // set before processDone is closed
90+ startStopMux sync.RWMutex // protects process and state during start/[force]stop
91+ processDone chan struct {}
92+ processErrorPtr * error
93+ osProcess atomic.Pointer [os.Process ]
9194
9295 // RPC provides typed server-scoped RPC methods.
9396 // This field is nil until the client is connected via Start().
@@ -251,6 +254,9 @@ func parseCliUrl(url string) (string, int) {
251254// }
252255// // Now ready to create sessions
253256func (c * Client ) Start (ctx context.Context ) error {
257+ c .startStopMux .Lock ()
258+ defer c .startStopMux .Unlock ()
259+
254260 if c .state == StateConnected {
255261 return nil
256262 }
@@ -260,21 +266,24 @@ func (c *Client) Start(ctx context.Context) error {
260266 // Only start CLI server process if not connecting to external server
261267 if ! c .isExternalServer {
262268 if err := c .startCLIServer (ctx ); err != nil {
269+ c .process = nil
263270 c .state = StateError
264271 return err
265272 }
266273 }
267274
268275 // Connect to the server
269276 if err := c .connectToServer (ctx ); err != nil {
277+ killErr := c .killProcess ()
270278 c .state = StateError
271- return err
279+ return errors . Join ( err , killErr )
272280 }
273281
274282 // Verify protocol version compatibility
275283 if err := c .verifyProtocolVersion (ctx ); err != nil {
284+ killErr := c .killProcess ()
276285 c .state = StateError
277- return err
286+ return errors . Join ( err , killErr )
278287 }
279288
280289 c .state = StateConnected
@@ -316,13 +325,16 @@ func (c *Client) Stop() error {
316325 c .sessions = make (map [string ]* Session )
317326 c .sessionsMux .Unlock ()
318327
328+ c .startStopMux .Lock ()
329+ defer c .startStopMux .Unlock ()
330+
319331 // Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it
320332 if c .process != nil && ! c .isExternalServer {
321- if err := c .process . Process . Kill (); err != nil {
322- errs = append (errs , fmt . Errorf ( "failed to kill CLI process: %w" , err ) )
333+ if err := c .killProcess (); err != nil {
334+ errs = append (errs , err )
323335 }
324- c .process = nil
325336 }
337+ c .process = nil
326338
327339 // Close external TCP connection if exists
328340 if c .isExternalServer && c .conn != nil {
@@ -375,16 +387,27 @@ func (c *Client) Stop() error {
375387// client.ForceStop()
376388// }
377389func (c * Client ) ForceStop () {
390+ // Kill the process without waiting for startStopMux, which Start may hold.
391+ // This unblocks any I/O Start is doing (connect, version check).
392+ if p := c .osProcess .Swap (nil ); p != nil {
393+ p .Kill ()
394+ }
395+
378396 // Clear sessions immediately without trying to destroy them
379397 c .sessionsMux .Lock ()
380398 c .sessions = make (map [string ]* Session )
381399 c .sessionsMux .Unlock ()
382400
401+ c .startStopMux .Lock ()
402+ defer c .startStopMux .Unlock ()
403+
383404 // Kill CLI process (only if we spawned it)
405+ // This is a fallback in case the process wasn't killed above (e.g. if Start hadn't set
406+ // osProcess yet), or if the process was restarted and osProcess now points to a new process.
384407 if c .process != nil && ! c .isExternalServer {
385- c .process .Process .Kill () // Ignore errors
386- c .process = nil
408+ _ = c .killProcess () // Ignore errors since we're force stopping
387409 }
410+ c .process = nil
388411
389412 // Close external TCP connection if exists
390413 if c .isExternalServer && c .conn != nil {
@@ -886,6 +909,8 @@ func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) {
886909// })
887910// }
888911func (c * Client ) State () ConnectionState {
912+ c .startStopMux .RLock ()
913+ defer c .startStopMux .RUnlock ()
889914 return c .state
890915}
891916
@@ -1096,27 +1121,11 @@ func (c *Client) startCLIServer(ctx context.Context) error {
10961121 return fmt .Errorf ("failed to start CLI server: %w" , err )
10971122 }
10981123
1099- // Monitor process exit to signal pending requests
1100- c .processDone = make (chan struct {})
1101- // Capturing a stable reference to the process for the goroutine prevents
1102- // a race: c.process can be assigned nil in [Force]Stop() while the
1103- // goroutine is starting. It's okay for this goroutine to Wait on the
1104- // process in that case because [Force]Stop() kills the process, causing
1105- // Wait to return immediately.
1106- proc := c .process
1107- go func () {
1108- waitErr := proc .Wait ()
1109- if waitErr != nil {
1110- c .processError = fmt .Errorf ("CLI process exited: %v" , waitErr )
1111- } else {
1112- c .processError = fmt .Errorf ("CLI process exited unexpectedly" )
1113- }
1114- close (c .processDone )
1115- }()
1124+ c .monitorProcess ()
11161125
11171126 // Create JSON-RPC client immediately
11181127 c .client = jsonrpc2 .NewClient (stdin , stdout )
1119- c .client .SetProcessDone (c .processDone , & c . processError )
1128+ c .client .SetProcessDone (c .processDone , c . processErrorPtr )
11201129 c .RPC = rpc .NewServerRpc (c .client )
11211130 c .setupNotificationHandler ()
11221131 c .client .Start ()
@@ -1133,22 +1142,25 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11331142 return fmt .Errorf ("failed to start CLI server: %w" , err )
11341143 }
11351144
1136- // Wait for port announcement
1145+ c .monitorProcess ()
1146+
11371147 scanner := bufio .NewScanner (stdout )
11381148 timeout := time .After (10 * time .Second )
11391149 portRegex := regexp .MustCompile (`listening on port (\d+)` )
11401150
11411151 for {
11421152 select {
11431153 case <- timeout :
1144- return fmt .Errorf ("timeout waiting for CLI server to start" )
1154+ killErr := c .killProcess ()
1155+ return errors .Join (fmt .Errorf ("timeout waiting for CLI server to start" ), killErr )
11451156 default :
11461157 if scanner .Scan () {
11471158 line := scanner .Text ()
11481159 if matches := portRegex .FindStringSubmatch (line ); len (matches ) > 1 {
11491160 port , err := strconv .Atoi (matches [1 ])
11501161 if err != nil {
1151- return fmt .Errorf ("failed to parse port: %w" , err )
1162+ killErr := c .killProcess ()
1163+ return errors .Join (fmt .Errorf ("failed to parse port: %w" , err ), killErr )
11521164 }
11531165 c .actualPort = port
11541166 return nil
@@ -1159,6 +1171,39 @@ func (c *Client) startCLIServer(ctx context.Context) error {
11591171 }
11601172}
11611173
1174+ func (c * Client ) killProcess () error {
1175+ if p := c .osProcess .Swap (nil ); p != nil {
1176+ if err := p .Kill (); err != nil {
1177+ return fmt .Errorf ("failed to kill CLI process: %w" , err )
1178+ }
1179+ }
1180+ c .process = nil
1181+ return nil
1182+ }
1183+
1184+ // monitorProcess signals when the CLI process exits and captures any exit error.
1185+ // processError is intentionally a local: each process lifecycle gets its own
1186+ // error value, so goroutines from previous processes can't overwrite the
1187+ // current one. Closing the channel synchronizes with readers, guaranteeing
1188+ // they see the final processError value.
1189+ func (c * Client ) monitorProcess () {
1190+ done := make (chan struct {})
1191+ c .processDone = done
1192+ proc := c .process
1193+ c .osProcess .Store (proc .Process )
1194+ var processError error
1195+ c .processErrorPtr = & processError
1196+ go func () {
1197+ waitErr := proc .Wait ()
1198+ if waitErr != nil {
1199+ processError = fmt .Errorf ("CLI process exited: %w" , waitErr )
1200+ } else {
1201+ processError = errors .New ("CLI process exited unexpectedly" )
1202+ }
1203+ close (done )
1204+ }()
1205+ }
1206+
11621207// connectToServer establishes a connection to the server.
11631208func (c * Client ) connectToServer (ctx context.Context ) error {
11641209 if c .useStdio {
@@ -1190,6 +1235,9 @@ func (c *Client) connectViaTcp(ctx context.Context) error {
11901235
11911236 // Create JSON-RPC client with the connection
11921237 c .client = jsonrpc2 .NewClient (conn , conn )
1238+ if c .processDone != nil {
1239+ c .client .SetProcessDone (c .processDone , c .processErrorPtr )
1240+ }
11931241 c .RPC = rpc .NewServerRpc (c .client )
11941242 c .setupNotificationHandler ()
11951243 c .client .Start ()
0 commit comments