@@ -104,23 +104,26 @@ func (r *runner) Run(ctx context.Context) error {
104
104
105
105
errc := make (chan error , len (r .funcs ))
106
106
// this cancel func cancels all subroutines
107
- ctx , cancel := context .WithCancelCause (ctx )
107
+ subctx , cancel := context .WithCancelCause (ctx )
108
108
109
109
var waitCount int32
110
110
111
111
for i , f := range r .funcs {
112
112
atomic .AddInt32 (& waitCount , 1 )
113
113
go func (fn func (context.Context ) error , idx int ) {
114
- err := fn (ctx )
114
+ err := fn (subctx )
115
115
if err != nil && ! errors .Is (err , context .Canceled ) {
116
116
if r .funcNames [idx ] != "" {
117
117
slog .Info (fmt .Sprintf ("subroutine %s error: %+v" , r .funcNames [idx ], err ))
118
118
} else {
119
119
slog .Info (fmt .Sprintf ("subroutine error: %+v" , err ))
120
120
}
121
121
}
122
- errc <- err
122
+ // Order matters for the following two statements.
123
+ // We must decrement before writing to the channel so that waitCount is
124
+ // accurate when we read the remaining waitCount below after reading errC.
123
125
atomic .AddInt32 (& waitCount , - 1 )
126
+ errc <- err
124
127
}(f , i )
125
128
}
126
129
@@ -140,19 +143,19 @@ loop:
140
143
select {
141
144
case sig := <- sigc :
142
145
slog .Error ("stopping on signal" , "signal" , sig )
143
- case <- ctx .Done ():
144
- err = ctx .Err ()
146
+ case <- subctx .Done ():
147
+ err = subctx .Err ()
145
148
if ! errors .Is (err , context .Canceled ) {
146
149
slog .Error ("error on context done" , "err" , err )
147
150
}
148
151
case err = <- errc :
149
152
if err != nil {
150
153
slog .Warn ("await: stopping on error returned" , "err" , err )
151
154
} else {
152
- if ! r .proceedOnNil {
153
- slog .Info ("await: stopping because a subroutine finished" )
154
- } else {
155
+ if r .proceedOnNil && atomic .LoadInt32 (& waitCount ) > 0 {
155
156
goto loop
157
+ } else {
158
+ slog .Debug ("await: stopping on subroutine(s) complete" )
156
159
}
157
160
}
158
161
}
0 commit comments