Skip to content

Commit

Permalink
v3: Added context to func to improve graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
hbl-ngocnd1 committed Oct 31, 2024
1 parent e843a09 commit 92b9c61
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 87 deletions.
20 changes: 12 additions & 8 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"fmt"
"runtime"
"sync"
Expand All @@ -24,9 +25,12 @@ func NewChain(c ...JobWrapper) Chain {
// Then decorates the given job with all JobWrappers in the chain.
//
// This:
// NewChain(m1, m2, m3).Then(job)
//
// NewChain(m1, m2, m3).Then(job)
//
// is equivalent to:
// m1(m2(m3(job)))
//
// m1(m2(m3(job)))
func (c Chain) Then(j Job) Job {
for i := range c.wrappers {
j = c.wrappers[len(c.wrappers)-i-1](j)
Expand All @@ -37,7 +41,7 @@ func (c Chain) Then(j Job) Job {
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
Expand All @@ -50,7 +54,7 @@ func Recover(logger Logger) JobWrapper {
logger.Error(err, "panic", "stack", "...\n"+string(buf))
}
}()
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -61,14 +65,14 @@ func Recover(logger Logger) JobWrapper {
func DelayIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var mu sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
start := time.Now()
mu.Lock()
defer mu.Unlock()
if dur := time.Since(start); dur > time.Minute {
logger.Info("delay", "duration", dur)
}
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -79,10 +83,10 @@ func SkipIfStillRunning(logger Logger) JobWrapper {
var ch = make(chan struct{}, 1)
ch <- struct{}{}
return func(j Job) Job {
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
select {
case v := <-ch:
j.Run()
j.Run(ctx)
ch <- v
default:
logger.Info("skip")
Expand Down
43 changes: 22 additions & 21 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"io/ioutil"
"log"
"reflect"
Expand All @@ -11,7 +12,7 @@ import (

func appendingJob(slice *[]int, value int) Job {
var m sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
m.Lock()
*slice = append(*slice, value)
m.Unlock()
Expand All @@ -20,9 +21,9 @@ func appendingJob(slice *[]int, value int) Job {

func appendingWrapper(slice *[]int, value int) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
appendingJob(slice, value).Run()
j.Run()
return FuncJob(func(ctx context.Context) {
appendingJob(slice, value).Run(ctx)
j.Run(ctx)
})
}
}
Expand All @@ -35,14 +36,14 @@ func TestChain(t *testing.T) {
append3 = appendingWrapper(&nums, 3)
append4 = appendingJob(&nums, 4)
)
NewChain(append1, append2, append3).Then(append4).Run()
NewChain(append1, append2, append3).Then(append4).Run(context.Background())
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
t.Error("unexpected order of calls:", nums)
}
}

func TestChainRecover(t *testing.T) {
panickingJob := FuncJob(func() {
panickingJob := FuncJob(func(ctx context.Context) {
panic("panickingJob panics")
})

Expand All @@ -53,19 +54,19 @@ func TestChainRecover(t *testing.T) {
}
}()
NewChain().Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("Recovering JobWrapper recovers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})
}

Expand All @@ -76,7 +77,7 @@ type countJob struct {
delay time.Duration
}

func (j *countJob) Run() {
func (j *countJob) Run(context.Context) {
j.m.Lock()
j.started++
j.m.Unlock()
Expand All @@ -103,7 +104,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -114,9 +115,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -129,9 +130,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -157,7 +158,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -168,9 +169,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -183,9 +184,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -209,7 +210,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
Expand Down
34 changes: 20 additions & 14 deletions cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ type Cron struct {
parser Parser
nextID EntryID
jobWaiter sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}

// Job is an interface for submitted cron jobs.
type Job interface {
Run()
Run(ctx context.Context)
}

// Schedule describes a job's duty cycle.
Expand Down Expand Up @@ -92,20 +94,21 @@ func (s byTime) Less(i, j int) bool {
//
// Available Settings
//
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
//
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
ctx, cancel := context.WithCancel(context.Background())
c := &Cron{
entries: nil,
chain: NewChain(),
Expand All @@ -118,6 +121,8 @@ func New(opts ...Option) *Cron {
logger: DefaultLogger,
location: time.Local,
parser: standardParser,
ctx: ctx,
cancel: cancel,
}
for _, opt := range opts {
opt(c)
Expand All @@ -126,14 +131,14 @@ func New(opts ...Option) *Cron {
}

// FuncJob is a wrapper that turns a func() into a cron.Job
type FuncJob func()
type FuncJob func(ctx context.Context)

func (f FuncJob) Run() { f() }
func (f FuncJob) Run(ctx context.Context) { f(ctx) }

// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
func (c *Cron) AddFunc(spec string, cmd func(ctx context.Context)) (EntryID, error) {
return c.AddJob(spec, FuncJob(cmd))
}

Expand Down Expand Up @@ -304,7 +309,7 @@ func (c *Cron) startJob(j Job) {
c.jobWaiter.Add(1)
go func() {
defer c.jobWaiter.Done()
j.Run()
j.Run(c.ctx)
}()
}

Expand All @@ -319,6 +324,7 @@ func (c *Cron) Stop() context.Context {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.cancel()
c.stop <- struct{}{}
c.running = false
}
Expand Down
Loading

0 comments on commit 92b9c61

Please sign in to comment.