Skip to content

Commit afbc653

Browse files
committed
initial commit
0 parents  commit afbc653

3 files changed

Lines changed: 228 additions & 0 deletions

File tree

LICENSE

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Copyright (c) db7
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy
4+
of this software and associated documentation files (the "Software"), to deal
5+
in the Software without restriction, including without limitation the rights
6+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
copies of the Software, and to permit persons to whom the Software is
8+
furnished to do so, subject to the following conditions:
9+
10+
The above copyright notice and this permission notice shall be included in all
11+
copies or substantial portions of the Software.
12+
13+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19+
SOFTWARE.
20+

barrier.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Package barrier provides a data structure to synchronize a group of
2+
// goroutines, blocking them until all of the arrive on the barrier. Once the
3+
// last goroutine arrive an optional callback is executed in isolation.
4+
package barrier
5+
6+
import (
7+
"fmt"
8+
"sync"
9+
)
10+
11+
// ErrBarrierAborted is returned by Await() if Abort() was called.
12+
var ErrBarrierAborted = fmt.Errorf("Barrier aborted")
13+
14+
// ErrBarrierMisused is returned by Await() if more than n concurrent Await()
15+
// calls are detected.
16+
var ErrBarrierMisused = fmt.Errorf("Barrier misused: more than n concurrent Await() calls")
17+
18+
// Callback is called by the last goroutine entering the barrier.
19+
type Callback func() error
20+
21+
// Barrier synchronizes a group of goroutines and optinally executes a callback
22+
// in isolation.
23+
type Barrier struct {
24+
sync.Mutex
25+
n int64
26+
count int64
27+
done chan bool
28+
abort chan bool
29+
}
30+
31+
// New returns a new Barrier which expects n goroutines to synchronize.
32+
func New(n int) *Barrier {
33+
return &Barrier{
34+
n: int64(n),
35+
count: int64(n),
36+
done: make(chan bool),
37+
abort: make(chan bool),
38+
}
39+
}
40+
41+
// Abort marks the barrier as aborted and signal all waiting goroutines.
42+
// The barrier cannot be reset once aborted.
43+
func (b *Barrier) Abort() {
44+
close(b.abort)
45+
}
46+
47+
// Await synchronizes n goroutines and executes in isolation the callback of
48+
// the last goroutine calling Await. Await returns any error the callback
49+
// returns to one goroutine; if Abort() is called, ErrBarrierAborted is
50+
// returned. The number of goroutines call Await should always match the value
51+
// n passed in the barrier's initialization.
52+
func (b *Barrier) Await(cb Callback) error {
53+
if b.aborted() {
54+
return ErrBarrierAborted
55+
}
56+
// keep copy of current state
57+
b.Lock()
58+
b.count--
59+
count := b.count
60+
done := b.done
61+
b.Unlock()
62+
63+
// more than n goroutines called Await
64+
if count < 0 {
65+
b.Abort()
66+
return ErrBarrierMisused
67+
}
68+
69+
// wait for others and callback execution
70+
if count > 0 {
71+
return b.wait(done)
72+
}
73+
74+
// if count == 0 execute callback if last goroutine
75+
var err error
76+
if cb != nil {
77+
err = cb()
78+
}
79+
b.reset()
80+
return err
81+
}
82+
83+
// aborted checks whether Barrier is aborted
84+
func (b *Barrier) aborted() bool {
85+
select {
86+
case <-b.abort:
87+
return true
88+
default:
89+
return false
90+
}
91+
}
92+
93+
// wait waits for execution of callback or abort()
94+
func (b *Barrier) wait(done chan bool) error {
95+
select {
96+
case <-done:
97+
return nil
98+
case <-b.abort:
99+
return ErrBarrierAborted
100+
}
101+
}
102+
103+
// reset resets the barrier for another round
104+
func (b *Barrier) reset() {
105+
b.Lock()
106+
close(b.done)
107+
b.done = make(chan bool)
108+
b.count = b.n
109+
b.Unlock()
110+
}

barrier_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package barrier_test
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
"sync/atomic"
7+
"testing"
8+
9+
"github.com/db7/barrier"
10+
"github.com/facebookgo/ensure"
11+
)
12+
13+
func TestBarrier_manyrounds(t *testing.T) {
14+
var count int64
15+
rounds := 100
16+
n := 10 // number of goroutines
17+
b := barrier.New(n)
18+
19+
for i := 0; i < rounds; i++ {
20+
var wg sync.WaitGroup
21+
wg.Add(n)
22+
for j := 0; j < n; j++ {
23+
go func() {
24+
b.Await(func() error {
25+
atomic.AddInt64(&count, 1)
26+
return nil
27+
})
28+
wg.Done()
29+
}()
30+
}
31+
wg.Wait()
32+
}
33+
ensure.True(t, atomic.LoadInt64(&count) == int64(rounds))
34+
}
35+
36+
func TestBarrier_abortBeforeLast(t *testing.T) {
37+
n := 10 // number of goroutines
38+
b := barrier.New(n)
39+
40+
// one round before abort
41+
var wg sync.WaitGroup
42+
wg.Add(n)
43+
for j := 0; j < n; j++ {
44+
go func() {
45+
defer wg.Done()
46+
err := b.Await(nil)
47+
ensure.Nil(t, err)
48+
}()
49+
}
50+
wg.Wait()
51+
52+
// one round when abort is called
53+
wg.Add(n)
54+
for j := 0; j < n-1; j++ {
55+
go func() {
56+
defer wg.Done()
57+
err := b.Await(nil)
58+
ensure.True(t, err == barrier.ErrBarrierAborted)
59+
}()
60+
}
61+
// last goroutine
62+
b.Abort()
63+
go func() {
64+
defer wg.Done()
65+
err := b.Await(nil)
66+
ensure.True(t, err == barrier.ErrBarrierAborted)
67+
}()
68+
wg.Wait()
69+
70+
// one last round where all goroutines should fail (aborted)
71+
wg.Add(n)
72+
for j := 0; j < n; j++ {
73+
go func() {
74+
defer wg.Done()
75+
err := b.Await(nil)
76+
ensure.True(t, err == barrier.ErrBarrierAborted)
77+
}()
78+
}
79+
wg.Wait()
80+
}
81+
82+
func ExampleBarrier_simple() {
83+
n := 40 // number of goroutines
84+
b := barrier.New(n)
85+
86+
var wg sync.WaitGroup
87+
wg.Add(n)
88+
for i := 0; i < n; i++ {
89+
go func(k int) {
90+
b.Await(func() error {
91+
fmt.Println(k, "is the last goroutine")
92+
return nil
93+
})
94+
wg.Done()
95+
}(i)
96+
}
97+
wg.Wait()
98+
}

0 commit comments

Comments
 (0)