diff --git a/chain.go b/chain.go index 118e5bbe..276843e0 100644 --- a/chain.go +++ b/chain.go @@ -1,6 +1,7 @@ package cron import ( + "context" "fmt" "runtime" "sync" @@ -24,9 +25,12 @@ func NewChain(c ...JobWrapper) Chain { // Then decorates the given job with all JobWrappers in the chain. // // This: -// NewChain(m1, m2, m3).Then(job) +// +// NewChain(m1, m2, m3).Then(job) +// // is equivalent to: -// m1(m2(m3(job))) +// +// m1(m2(m3(job))) func (c Chain) Then(j Job) Job { for i := range c.wrappers { j = c.wrappers[len(c.wrappers)-i-1](j) @@ -37,7 +41,7 @@ func (c Chain) Then(j Job) Job { // Recover panics in wrapped jobs and log them with the provided logger. func Recover(logger Logger) JobWrapper { return func(j Job) Job { - return FuncJob(func() { + return FuncJob(func(ctx context.Context) { defer func() { if r := recover(); r != nil { const size = 64 << 10 @@ -50,7 +54,7 @@ func Recover(logger Logger) JobWrapper { logger.Error(err, "panic", "stack", "...\n"+string(buf)) } }() - j.Run() + j.Run(ctx) }) } } @@ -61,14 +65,14 @@ func Recover(logger Logger) JobWrapper { func DelayIfStillRunning(logger Logger) JobWrapper { return func(j Job) Job { var mu sync.Mutex - return FuncJob(func() { + return FuncJob(func(ctx context.Context) { start := time.Now() mu.Lock() defer mu.Unlock() if dur := time.Since(start); dur > time.Minute { logger.Info("delay", "duration", dur) } - j.Run() + j.Run(ctx) }) } } @@ -79,10 +83,10 @@ func SkipIfStillRunning(logger Logger) JobWrapper { var ch = make(chan struct{}, 1) ch <- struct{}{} return func(j Job) Job { - return FuncJob(func() { + return FuncJob(func(ctx context.Context) { select { case v := <-ch: - j.Run() + j.Run(ctx) ch <- v default: logger.Info("skip") diff --git a/chain_test.go b/chain_test.go index 2561bd7f..c05f67de 100644 --- a/chain_test.go +++ b/chain_test.go @@ -1,6 +1,7 @@ package cron import ( + "context" "io/ioutil" "log" "reflect" @@ -11,7 +12,7 @@ import ( func appendingJob(slice *[]int, value int) Job { var m sync.Mutex - return FuncJob(func() { + return FuncJob(func(ctx context.Context) { m.Lock() *slice = append(*slice, value) m.Unlock() @@ -20,9 +21,9 @@ func appendingJob(slice *[]int, value int) Job { func appendingWrapper(slice *[]int, value int) JobWrapper { return func(j Job) Job { - return FuncJob(func() { - appendingJob(slice, value).Run() - j.Run() + return FuncJob(func(ctx context.Context) { + appendingJob(slice, value).Run(ctx) + j.Run(ctx) }) } } @@ -35,14 +36,14 @@ func TestChain(t *testing.T) { append3 = appendingWrapper(&nums, 3) append4 = appendingJob(&nums, 4) ) - NewChain(append1, append2, append3).Then(append4).Run() + NewChain(append1, append2, append3).Then(append4).Run(context.Background()) if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) { t.Error("unexpected order of calls:", nums) } } func TestChainRecover(t *testing.T) { - panickingJob := FuncJob(func() { + panickingJob := FuncJob(func(ctx context.Context) { panic("panickingJob panics") }) @@ -53,19 +54,19 @@ func TestChainRecover(t *testing.T) { } }() NewChain().Then(panickingJob). - Run() + Run(context.Background()) }) t.Run("Recovering JobWrapper recovers", func(t *testing.T) { NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))). Then(panickingJob). - Run() + Run(context.Background()) }) t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) { NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))). Then(panickingJob). - Run() + Run(context.Background()) }) } @@ -76,7 +77,7 @@ type countJob struct { delay time.Duration } -func (j *countJob) Run() { +func (j *countJob) Run(context.Context) { j.m.Lock() j.started++ j.m.Unlock() @@ -103,7 +104,7 @@ func TestChainDelayIfStillRunning(t *testing.T) { t.Run("runs immediately", func(t *testing.T) { var j countJob wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete. if c := j.Done(); c != 1 { t.Errorf("expected job run once, immediately, got %d", c) @@ -114,9 +115,9 @@ func TestChainDelayIfStillRunning(t *testing.T) { var j countJob wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j) go func() { - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(time.Millisecond) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) }() time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete. if c := j.Done(); c != 2 { @@ -129,9 +130,9 @@ func TestChainDelayIfStillRunning(t *testing.T) { j.delay = 10 * time.Millisecond wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j) go func() { - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(time.Millisecond) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) }() // After 5ms, the first job is still in progress, and the second job was @@ -157,7 +158,7 @@ func TestChainSkipIfStillRunning(t *testing.T) { t.Run("runs immediately", func(t *testing.T) { var j countJob wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete. if c := j.Done(); c != 1 { t.Errorf("expected job run once, immediately, got %d", c) @@ -168,9 +169,9 @@ func TestChainSkipIfStillRunning(t *testing.T) { var j countJob wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j) go func() { - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(time.Millisecond) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) }() time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete. if c := j.Done(); c != 2 { @@ -183,9 +184,9 @@ func TestChainSkipIfStillRunning(t *testing.T) { j.delay = 10 * time.Millisecond wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j) go func() { - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) time.Sleep(time.Millisecond) - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) }() // After 5ms, the first job is still in progress, and the second job was @@ -209,7 +210,7 @@ func TestChainSkipIfStillRunning(t *testing.T) { j.delay = 10 * time.Millisecond wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j) for i := 0; i < 11; i++ { - go wrappedJob.Run() + go wrappedJob.Run(context.Background()) } time.Sleep(200 * time.Millisecond) done := j.Done() diff --git a/cron.go b/cron.go index f6e451db..8e224551 100644 --- a/cron.go +++ b/cron.go @@ -24,11 +24,13 @@ type Cron struct { parser Parser nextID EntryID jobWaiter sync.WaitGroup + ctx context.Context + cancel context.CancelFunc } // Job is an interface for submitted cron jobs. type Job interface { - Run() + Run(ctx context.Context) } // Schedule describes a job's duty cycle. @@ -92,20 +94,21 @@ func (s byTime) Less(i, j int) bool { // // Available Settings // -// Time Zone -// Description: The time zone in which schedules are interpreted -// Default: time.Local +// Time Zone +// Description: The time zone in which schedules are interpreted +// Default: time.Local // -// Parser -// Description: Parser converts cron spec strings into cron.Schedules. -// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron +// Parser +// Description: Parser converts cron spec strings into cron.Schedules. +// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron // -// Chain -// Description: Wrap submitted jobs to customize behavior. -// Default: A chain that recovers panics and logs them to stderr. +// Chain +// Description: Wrap submitted jobs to customize behavior. +// Default: A chain that recovers panics and logs them to stderr. // // See "cron.With*" to modify the default behavior. func New(opts ...Option) *Cron { + ctx, cancel := context.WithCancel(context.Background()) c := &Cron{ entries: nil, chain: NewChain(), @@ -118,6 +121,8 @@ func New(opts ...Option) *Cron { logger: DefaultLogger, location: time.Local, parser: standardParser, + ctx: ctx, + cancel: cancel, } for _, opt := range opts { opt(c) @@ -126,14 +131,14 @@ func New(opts ...Option) *Cron { } // FuncJob is a wrapper that turns a func() into a cron.Job -type FuncJob func() +type FuncJob func(ctx context.Context) -func (f FuncJob) Run() { f() } +func (f FuncJob) Run(ctx context.Context) { f(ctx) } // AddFunc adds a func to the Cron to be run on the given schedule. // The spec is parsed using the time zone of this Cron instance as the default. // An opaque ID is returned that can be used to later remove it. -func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) { +func (c *Cron) AddFunc(spec string, cmd func(ctx context.Context)) (EntryID, error) { return c.AddJob(spec, FuncJob(cmd)) } @@ -304,7 +309,7 @@ func (c *Cron) startJob(j Job) { c.jobWaiter.Add(1) go func() { defer c.jobWaiter.Done() - j.Run() + j.Run(c.ctx) }() } @@ -319,6 +324,7 @@ func (c *Cron) Stop() context.Context { c.runningMu.Lock() defer c.runningMu.Unlock() if c.running { + c.cancel() c.stop <- struct{}{} c.running = false } diff --git a/cron_test.go b/cron_test.go index 36f06bf7..829ca018 100644 --- a/cron_test.go +++ b/cron_test.go @@ -2,6 +2,7 @@ package cron import ( "bytes" + "context" "fmt" "log" "strings" @@ -44,7 +45,7 @@ func TestFuncPanicRecovery(t *testing.T) { WithChain(Recover(newBufLogger(&buf)))) cron.Start() defer cron.Stop() - cron.AddFunc("* * * * * ?", func() { + cron.AddFunc("* * * * * ?", func(ctx context.Context) { panic("YOLO") }) @@ -59,7 +60,7 @@ func TestFuncPanicRecovery(t *testing.T) { type DummyJob struct{} -func (d DummyJob) Run() { +func (d DummyJob) Run(ctx context.Context) { panic("YOLO") } @@ -102,7 +103,7 @@ func TestStopCausesJobsToNotRun(t *testing.T) { cron := newWithSeconds() cron.Start() cron.Stop() - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) select { case <-time.After(OneSecond): @@ -118,7 +119,7 @@ func TestAddBeforeRunning(t *testing.T) { wg.Add(1) cron := newWithSeconds() - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) cron.Start() defer cron.Stop() @@ -138,7 +139,7 @@ func TestAddWhileRunning(t *testing.T) { cron := newWithSeconds() cron.Start() defer cron.Stop() - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) select { case <-time.After(OneSecond): @@ -154,7 +155,7 @@ func TestAddWhileRunningWithDelay(t *testing.T) { defer cron.Stop() time.Sleep(5 * time.Second) var calls int64 - cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) }) + cron.AddFunc("* * * * * *", func(ctx context.Context) { atomic.AddInt64(&calls, 1) }) <-time.After(OneSecond) if atomic.LoadInt64(&calls) != 1 { @@ -168,7 +169,7 @@ func TestRemoveBeforeRunning(t *testing.T) { wg.Add(1) cron := newWithSeconds() - id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() }) + id, _ := cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) cron.Remove(id) cron.Start() defer cron.Stop() @@ -189,7 +190,7 @@ func TestRemoveWhileRunning(t *testing.T) { cron := newWithSeconds() cron.Start() defer cron.Stop() - id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() }) + id, _ := cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) cron.Remove(id) select { @@ -205,7 +206,7 @@ func TestSnapshotEntries(t *testing.T) { wg.Add(1) cron := New() - cron.AddFunc("@every 2s", func() { wg.Done() }) + cron.AddFunc("@every 2s", func(ctx context.Context) { wg.Done() }) cron.Start() defer cron.Stop() @@ -232,12 +233,12 @@ func TestMultipleEntries(t *testing.T) { wg.Add(2) cron := newWithSeconds() - cron.AddFunc("0 0 0 1 1 ?", func() {}) - cron.AddFunc("* * * * * ?", func() { wg.Done() }) - id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() }) - id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() }) - cron.AddFunc("0 0 0 31 12 ?", func() {}) - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("0 0 0 1 1 ?", func(ctx context.Context) {}) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) + id1, _ := cron.AddFunc("* * * * * ?", func(ctx context.Context) { t.Fatal() }) + id2, _ := cron.AddFunc("* * * * * ?", func(ctx context.Context) { t.Fatal() }) + cron.AddFunc("0 0 0 31 12 ?", func(ctx context.Context) {}) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) cron.Remove(id1) cron.Start() @@ -257,9 +258,9 @@ func TestRunningJobTwice(t *testing.T) { wg.Add(2) cron := newWithSeconds() - cron.AddFunc("0 0 0 1 1 ?", func() {}) - cron.AddFunc("0 0 0 31 12 ?", func() {}) - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("0 0 0 1 1 ?", func(ctx context.Context) {}) + cron.AddFunc("0 0 0 31 12 ?", func(ctx context.Context) {}) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) cron.Start() defer cron.Stop() @@ -276,12 +277,12 @@ func TestRunningMultipleSchedules(t *testing.T) { wg.Add(2) cron := newWithSeconds() - cron.AddFunc("0 0 0 1 1 ?", func() {}) - cron.AddFunc("0 0 0 31 12 ?", func() {}) - cron.AddFunc("* * * * * ?", func() { wg.Done() }) - cron.Schedule(Every(time.Minute), FuncJob(func() {})) - cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() })) - cron.Schedule(Every(time.Hour), FuncJob(func() {})) + cron.AddFunc("0 0 0 1 1 ?", func(ctx context.Context) {}) + cron.AddFunc("0 0 0 31 12 ?", func(ctx context.Context) {}) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) + cron.Schedule(Every(time.Minute), FuncJob(func(ctx context.Context) {})) + cron.Schedule(Every(time.Second), FuncJob(func(ctx context.Context) { wg.Done() })) + cron.Schedule(Every(time.Hour), FuncJob(func(ctx context.Context) {})) cron.Start() defer cron.Stop() @@ -310,7 +311,7 @@ func TestLocalTimezone(t *testing.T) { now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month()) cron := newWithSeconds() - cron.AddFunc(spec, func() { wg.Done() }) + cron.AddFunc(spec, func(ctx context.Context) { wg.Done() }) cron.Start() defer cron.Stop() @@ -344,7 +345,7 @@ func TestNonLocalTimezone(t *testing.T) { now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month()) cron := New(WithLocation(loc), WithParser(secondParser)) - cron.AddFunc(spec, func() { wg.Done() }) + cron.AddFunc(spec, func(ctx context.Context) { wg.Done() }) cron.Start() defer cron.Stop() @@ -367,7 +368,7 @@ type testJob struct { name string } -func (t testJob) Run() { +func (t testJob) Run(ctx context.Context) { t.wg.Done() } @@ -386,7 +387,7 @@ func TestBlockingRun(t *testing.T) { wg.Add(1) cron := newWithSeconds() - cron.AddFunc("* * * * * ?", func() { wg.Done() }) + cron.AddFunc("* * * * * ?", func(ctx context.Context) { wg.Done() }) var unblockChan = make(chan struct{}) @@ -410,7 +411,7 @@ func TestStartNoop(t *testing.T) { var tickChan = make(chan struct{}, 2) cron := newWithSeconds() - cron.AddFunc("* * * * * ?", func() { + cron.AddFunc("* * * * * ?", func(ctx context.Context) { tickChan <- struct{}{} }) @@ -501,8 +502,8 @@ func TestScheduleAfterRemoval(t *testing.T) { var mu sync.Mutex cron := newWithSeconds() - hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {})) - cron.Schedule(Every(time.Second), FuncJob(func() { + hourJob := cron.Schedule(Every(time.Hour), FuncJob(func(ctx context.Context) {})) + cron.Schedule(Every(time.Second), FuncJob(func(ctx context.Context) { mu.Lock() defer mu.Unlock() switch calls { @@ -545,8 +546,8 @@ func (*ZeroSchedule) Next(time.Time) time.Time { func TestJobWithZeroTimeDoesNotRun(t *testing.T) { cron := newWithSeconds() var calls int64 - cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) }) - cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") })) + cron.AddFunc("* * * * * *", func(ctx context.Context) { atomic.AddInt64(&calls, 1) }) + cron.Schedule(new(ZeroSchedule), FuncJob(func(ctx context.Context) { t.Error("expected zero task will not run") })) cron.Start() defer cron.Stop() <-time.After(OneSecond) @@ -582,11 +583,11 @@ func TestStopAndWait(t *testing.T) { t.Run("a couple fast jobs added, still returns immediately", func(t *testing.T) { cron := newWithSeconds() - cron.AddFunc("* * * * * *", func() {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) cron.Start() - cron.AddFunc("* * * * * *", func() {}) - cron.AddFunc("* * * * * *", func() {}) - cron.AddFunc("* * * * * *", func() {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) time.Sleep(time.Second) ctx := cron.Stop() select { @@ -598,10 +599,10 @@ func TestStopAndWait(t *testing.T) { t.Run("a couple fast jobs and a slow job added, waits for slow job", func(t *testing.T) { cron := newWithSeconds() - cron.AddFunc("* * * * * *", func() {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) cron.Start() - cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) }) - cron.AddFunc("* * * * * *", func() {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) { time.Sleep(2 * time.Second) }) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) time.Sleep(time.Second) ctx := cron.Stop() @@ -625,10 +626,10 @@ func TestStopAndWait(t *testing.T) { t.Run("repeated calls to stop, waiting for completion and after", func(t *testing.T) { cron := newWithSeconds() - cron.AddFunc("* * * * * *", func() {}) - cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) }) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) { time.Sleep(2 * time.Second) }) cron.Start() - cron.AddFunc("* * * * * *", func() {}) + cron.AddFunc("* * * * * *", func(ctx context.Context) {}) time.Sleep(time.Second) ctx := cron.Stop() ctx2 := cron.Stop() diff --git a/option_test.go b/option_test.go index 8aef1682..35251b36 100644 --- a/option_test.go +++ b/option_test.go @@ -1,6 +1,7 @@ package cron import ( + "context" "log" "strings" "testing" @@ -30,7 +31,7 @@ func TestWithVerboseLogger(t *testing.T) { t.Error("expected provided logger") } - c.AddFunc("@every 1s", func() {}) + c.AddFunc("@every 1s", func(ctx context.Context) {}) c.Start() time.Sleep(OneSecond) c.Stop()