Skip to content

Commit 63352b0

Browse files
committed
solver: recursively add merge source jobs to target and ancestors
Signed-off-by: Erik Sipsma <[email protected]>
1 parent 447857f commit 63352b0

File tree

2 files changed

+111
-18
lines changed

2 files changed

+111
-18
lines changed

solver/jobs.go

+46-5
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,7 @@ func (s *state) setEdge(index Index, targetEdge *edge, targetState *state) {
176176
targetEdge.takeOwnership(e)
177177

178178
if targetState != nil {
179-
targetState.mu.Lock()
180-
for j := range s.jobs {
181-
targetState.jobs[j] = struct{}{}
182-
}
183-
targetState.mu.Unlock()
179+
targetState.addJobs(s, map[*state]struct{}{})
184180

185181
if _, ok := targetState.allPw[s.mpw]; !ok {
186182
targetState.mpw.Add(s.mpw)
@@ -189,6 +185,51 @@ func (s *state) setEdge(index Index, targetEdge *edge, targetState *state) {
189185
}
190186
}
191187

188+
// addJobs recursively adds jobs to state and all its ancestors. currently
189+
// only used during edge merges to add jobs from the source of the merge to the
190+
// target and its ancestors.
191+
// requires that Solver.mu is read-locked and srcState.mu is locked
192+
func (s *state) addJobs(srcState *state, memo map[*state]struct{}) {
193+
if _, ok := memo[s]; ok {
194+
return
195+
}
196+
memo[s] = struct{}{}
197+
198+
s.mu.Lock()
199+
defer s.mu.Unlock()
200+
201+
for j := range srcState.jobs {
202+
s.jobs[j] = struct{}{}
203+
}
204+
205+
for _, inputEdge := range s.vtx.Inputs() {
206+
inputState, ok := s.solver.actives[inputEdge.Vertex.Digest()]
207+
if !ok {
208+
bklog.G(context.TODO()).
209+
WithField("vertex_digest", inputEdge.Vertex.Digest()).
210+
Error("input vertex not found during addJobs")
211+
continue
212+
}
213+
inputState.addJobs(srcState, memo)
214+
215+
// tricky case: if the inputState's edge was *already* merged we should
216+
// also add jobs to the merged edge's state
217+
mergedInputEdge := inputState.getEdge(inputEdge.Index)
218+
if mergedInputEdge == nil || mergedInputEdge.edge.Vertex.Digest() == inputEdge.Vertex.Digest() {
219+
// not merged
220+
continue
221+
}
222+
mergedInputState, ok := s.solver.actives[mergedInputEdge.edge.Vertex.Digest()]
223+
if !ok {
224+
bklog.G(context.TODO()).
225+
WithField("vertex_digest", mergedInputEdge.edge.Vertex.Digest()).
226+
Error("merged input vertex not found during addJobs")
227+
continue
228+
}
229+
mergedInputState.addJobs(srcState, memo)
230+
}
231+
}
232+
192233
func (s *state) combinedCacheManager() CacheManager {
193234
s.mu.Lock()
194235
cms := make([]CacheManager, 0, len(s.cache)+1)

solver/scheduler_test.go

+65-13
Original file line numberDiff line numberDiff line change
@@ -3457,18 +3457,19 @@ func TestStaleEdgeMerge(t *testing.T) {
34573457
})
34583458
defer s.Close()
34593459

3460+
depV0 := vtxConst(1, vtxOpt{name: "depV0"})
3461+
depV1 := vtxConst(1, vtxOpt{name: "depV1"})
3462+
depV2 := vtxConst(1, vtxOpt{name: "depV2"})
3463+
34603464
// These should all end up edge merged
34613465
v0 := vtxAdd(2, vtxOpt{name: "v0", inputs: []Edge{
3462-
{Vertex: vtxConst(3, vtxOpt{})},
3463-
{Vertex: vtxConst(4, vtxOpt{})},
3466+
{Vertex: depV0},
34643467
}})
34653468
v1 := vtxAdd(2, vtxOpt{name: "v1", inputs: []Edge{
3466-
{Vertex: vtxConst(3, vtxOpt{})},
3467-
{Vertex: vtxConst(4, vtxOpt{})},
3469+
{Vertex: depV1},
34683470
}})
34693471
v2 := vtxAdd(2, vtxOpt{name: "v2", inputs: []Edge{
3470-
{Vertex: vtxConst(3, vtxOpt{})},
3471-
{Vertex: vtxConst(4, vtxOpt{})},
3472+
{Vertex: depV2},
34723473
}})
34733474

34743475
j0, err := s.NewJob("job0")
@@ -3478,6 +3479,11 @@ func TestStaleEdgeMerge(t *testing.T) {
34783479
require.NoError(t, err)
34793480
require.NotNil(t, res)
34803481

3482+
require.Contains(t, s.actives, v0.Digest())
3483+
require.Contains(t, s.actives[v0.Digest()].jobs, j0)
3484+
require.Contains(t, s.actives, depV0.Digest())
3485+
require.Contains(t, s.actives[depV0.Digest()].jobs, j0)
3486+
34813487
// this edge should be merged with the one from j0
34823488
j1, err := s.NewJob("job1")
34833489
require.NoError(t, err)
@@ -3486,14 +3492,37 @@ func TestStaleEdgeMerge(t *testing.T) {
34863492
require.NoError(t, err)
34873493
require.NotNil(t, res)
34883494

3495+
require.Contains(t, s.actives, v0.Digest())
3496+
require.Contains(t, s.actives[v0.Digest()].jobs, j0)
3497+
require.Contains(t, s.actives[v0.Digest()].jobs, j1)
3498+
require.Contains(t, s.actives, depV0.Digest())
3499+
require.Contains(t, s.actives[depV0.Digest()].jobs, j0)
3500+
require.Contains(t, s.actives[depV0.Digest()].jobs, j1)
3501+
3502+
require.Contains(t, s.actives, v1.Digest())
3503+
require.NotContains(t, s.actives[v1.Digest()].jobs, j0)
3504+
require.Contains(t, s.actives[v1.Digest()].jobs, j1)
3505+
require.Contains(t, s.actives, depV1.Digest())
3506+
require.NotContains(t, s.actives[depV1.Digest()].jobs, j0)
3507+
require.Contains(t, s.actives[depV1.Digest()].jobs, j1)
3508+
34893509
// discard j0, verify that v0 is still active and it's state contains j1 since j1's
34903510
// edge was merged to v0's state
34913511
require.NoError(t, j0.Discard())
3512+
34923513
require.Contains(t, s.actives, v0.Digest())
3493-
require.Contains(t, s.actives, v1.Digest())
34943514
require.NotContains(t, s.actives[v0.Digest()].jobs, j0)
34953515
require.Contains(t, s.actives[v0.Digest()].jobs, j1)
3516+
require.Contains(t, s.actives, depV0.Digest())
3517+
require.NotContains(t, s.actives[depV0.Digest()].jobs, j0)
3518+
require.Contains(t, s.actives[depV0.Digest()].jobs, j1)
3519+
3520+
require.Contains(t, s.actives, v1.Digest())
3521+
require.NotContains(t, s.actives[v1.Digest()].jobs, j0)
34963522
require.Contains(t, s.actives[v1.Digest()].jobs, j1)
3523+
require.Contains(t, s.actives, depV1.Digest())
3524+
require.NotContains(t, s.actives[depV1.Digest()].jobs, j0)
3525+
require.Contains(t, s.actives[depV1.Digest()].jobs, j1)
34973526

34983527
// verify another job can still merge
34993528
j2, err := s.NewJob("job2")
@@ -3504,29 +3533,52 @@ func TestStaleEdgeMerge(t *testing.T) {
35043533
require.NotNil(t, res)
35053534

35063535
require.Contains(t, s.actives, v0.Digest())
3507-
require.Contains(t, s.actives, v1.Digest())
3508-
require.Contains(t, s.actives, v2.Digest())
3509-
require.NotContains(t, s.actives[v0.Digest()].jobs, j0)
35103536
require.Contains(t, s.actives[v0.Digest()].jobs, j1)
35113537
require.Contains(t, s.actives[v0.Digest()].jobs, j2)
3538+
require.Contains(t, s.actives, depV0.Digest())
3539+
require.Contains(t, s.actives[depV0.Digest()].jobs, j1)
3540+
require.Contains(t, s.actives[depV0.Digest()].jobs, j2)
3541+
3542+
require.Contains(t, s.actives, v1.Digest())
35123543
require.Contains(t, s.actives[v1.Digest()].jobs, j1)
3544+
require.NotContains(t, s.actives[v1.Digest()].jobs, j2)
3545+
require.Contains(t, s.actives, depV1.Digest())
3546+
require.Contains(t, s.actives[depV1.Digest()].jobs, j1)
3547+
require.NotContains(t, s.actives[depV1.Digest()].jobs, j2)
3548+
3549+
require.Contains(t, s.actives, v2.Digest())
3550+
require.NotContains(t, s.actives[v2.Digest()].jobs, j1)
35133551
require.Contains(t, s.actives[v2.Digest()].jobs, j2)
3552+
require.Contains(t, s.actives, depV2.Digest())
3553+
require.NotContains(t, s.actives[depV2.Digest()].jobs, j1)
3554+
require.Contains(t, s.actives[depV2.Digest()].jobs, j2)
35143555

35153556
// discard j1, verify only referenced edges still exist
35163557
require.NoError(t, j1.Discard())
3558+
35173559
require.Contains(t, s.actives, v0.Digest())
3518-
require.NotContains(t, s.actives, v1.Digest())
3519-
require.Contains(t, s.actives, v2.Digest())
3520-
require.NotContains(t, s.actives[v0.Digest()].jobs, j0)
35213560
require.NotContains(t, s.actives[v0.Digest()].jobs, j1)
35223561
require.Contains(t, s.actives[v0.Digest()].jobs, j2)
3562+
require.Contains(t, s.actives, depV0.Digest())
3563+
require.NotContains(t, s.actives[depV0.Digest()].jobs, j1)
3564+
require.Contains(t, s.actives[depV0.Digest()].jobs, j2)
3565+
3566+
require.NotContains(t, s.actives, v1.Digest())
3567+
require.NotContains(t, s.actives, depV1.Digest())
3568+
3569+
require.Contains(t, s.actives, v2.Digest())
35233570
require.Contains(t, s.actives[v2.Digest()].jobs, j2)
3571+
require.Contains(t, s.actives, depV2.Digest())
3572+
require.Contains(t, s.actives[depV2.Digest()].jobs, j2)
35243573

35253574
// discard the last job and verify everything was removed now
35263575
require.NoError(t, j2.Discard())
35273576
require.NotContains(t, s.actives, v0.Digest())
35283577
require.NotContains(t, s.actives, v1.Digest())
35293578
require.NotContains(t, s.actives, v2.Digest())
3579+
require.NotContains(t, s.actives, depV0.Digest())
3580+
require.NotContains(t, s.actives, depV1.Digest())
3581+
require.NotContains(t, s.actives, depV2.Digest())
35303582
}
35313583

35323584
func generateSubGraph(nodes int) (Edge, int) {

0 commit comments

Comments
 (0)