Skip to content

Commit b72936d

Browse files
committed
Only work-steal in the main loop
1 parent f192a48 commit b72936d

File tree

18 files changed

+284
-117
lines changed

18 files changed

+284
-117
lines changed

Diff for: rayon-core/Cargo.toml

+1
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ num_cpus = "1.2"
2222
crossbeam-channel = "0.5.0"
2323
crossbeam-deque = "0.8.1"
2424
crossbeam-utils = "0.8.0"
25+
smallvec = "1.11.0"
2526

2627
[dev-dependencies]
2728
rand = "0.8"

Diff for: rayon-core/src/broadcast/mod.rs

+21-3
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ use crate::registry::{Registry, WorkerThread};
44
use crate::scope::ScopeLatch;
55
use std::fmt;
66
use std::marker::PhantomData;
7+
use std::sync::atomic::{AtomicBool, Ordering};
78
use std::sync::Arc;
89

910
mod test;
@@ -100,13 +101,22 @@ where
100101
OP: Fn(BroadcastContext<'_>) -> R + Sync,
101102
R: Send,
102103
{
104+
let current_thread = WorkerThread::current();
105+
let current_thread_addr = current_thread as usize;
106+
let started = &AtomicBool::new(false);
103107
let f = move |injected: bool| {
104108
debug_assert!(injected);
109+
110+
// Mark as started if we are on the thread that initiated the broadcast.
111+
if current_thread_addr == WorkerThread::current() as usize {
112+
started.store(true, Ordering::Relaxed);
113+
}
114+
105115
BroadcastContext::with(&op)
106116
};
107117

108118
let n_threads = registry.num_threads();
109-
let current_thread = WorkerThread::current().as_ref();
119+
let current_thread = current_thread.as_ref();
110120
let tlv = crate::tlv::get();
111121
let latch = ScopeLatch::with_count(n_threads, current_thread);
112122
let jobs: Vec<_> = (0..n_threads)
@@ -116,8 +126,16 @@ where
116126

117127
registry.inject_broadcast(job_refs);
118128

129+
let current_thread_job_id = current_thread
130+
.and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
131+
.map(|worker| jobs[worker.index].as_job_ref().id());
132+
119133
// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
120-
latch.wait(current_thread);
134+
latch.wait(
135+
current_thread,
136+
|| started.load(Ordering::Relaxed),
137+
|job| Some(job.id()) == current_thread_job_id,
138+
);
121139
jobs.into_iter().map(|job| job.into_result()).collect()
122140
}
123141

@@ -133,7 +151,7 @@ where
133151
{
134152
let job = ArcJob::new({
135153
let registry = Arc::clone(registry);
136-
move || {
154+
move |_| {
137155
registry.catch_unwind(|| BroadcastContext::with(&op));
138156
registry.terminate(); // (*) permit registry to terminate now
139157
}

Diff for: rayon-core/src/broadcast/test.rs

+2
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ fn spawn_broadcast_self() {
6363
}
6464

6565
#[test]
66+
#[ignore]
6667
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
6768
fn broadcast_mutual() {
6869
let count = AtomicUsize::new(0);
@@ -97,6 +98,7 @@ fn spawn_broadcast_mutual() {
9798
}
9899

99100
#[test]
101+
#[ignore]
100102
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
101103
fn broadcast_mutual_sleepy() {
102104
let count = AtomicUsize::new(0);

Diff for: rayon-core/src/job.rs

+26-14
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,11 @@ pub(super) trait Job {
2626
unsafe fn execute(this: *const ());
2727
}
2828

29+
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
30+
pub(super) struct JobRefId {
31+
pointer: usize,
32+
}
33+
2934
/// Effectively a Job trait object. Each JobRef **must** be executed
3035
/// exactly once, or else data may leak.
3136
///
@@ -54,11 +59,11 @@ impl JobRef {
5459
}
5560
}
5661

57-
/// Returns an opaque handle that can be saved and compared,
58-
/// without making `JobRef` itself `Copy + Eq`.
5962
#[inline]
60-
pub(super) fn id(&self) -> impl Eq {
61-
(self.pointer, self.execute_fn)
63+
pub(super) fn id(&self) -> JobRefId {
64+
JobRefId {
65+
pointer: self.pointer as usize,
66+
}
6267
}
6368

6469
#[inline]
@@ -102,8 +107,13 @@ where
102107
JobRef::new(self)
103108
}
104109

105-
pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
106-
self.func.into_inner().unwrap()(stolen)
110+
pub(super) unsafe fn run_inline(&self, stolen: bool) {
111+
let func = (*self.func.get()).take().unwrap();
112+
(*self.result.get()) = match unwind::halt_unwinding(|| func(stolen)) {
113+
Ok(x) => JobResult::Ok(x),
114+
Err(x) => JobResult::Panic(x),
115+
};
116+
Latch::set(&self.latch);
107117
}
108118

109119
pub(super) unsafe fn into_result(self) -> R {
@@ -136,15 +146,15 @@ where
136146
/// (Probably `StackJob` should be refactored in a similar fashion.)
137147
pub(super) struct HeapJob<BODY>
138148
where
139-
BODY: FnOnce() + Send,
149+
BODY: FnOnce(JobRefId) + Send,
140150
{
141151
job: BODY,
142152
tlv: Tlv,
143153
}
144154

145155
impl<BODY> HeapJob<BODY>
146156
where
147-
BODY: FnOnce() + Send,
157+
BODY: FnOnce(JobRefId) + Send,
148158
{
149159
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
150160
Box::new(HeapJob { job, tlv })
@@ -168,27 +178,28 @@ where
168178

169179
impl<BODY> Job for HeapJob<BODY>
170180
where
171-
BODY: FnOnce() + Send,
181+
BODY: FnOnce(JobRefId) + Send,
172182
{
173183
unsafe fn execute(this: *const ()) {
184+
let pointer = this as usize;
174185
let this = Box::from_raw(this as *mut Self);
175186
tlv::set(this.tlv);
176-
(this.job)();
187+
(this.job)(JobRefId { pointer });
177188
}
178189
}
179190

180191
/// Represents a job stored in an `Arc` -- like `HeapJob`, but may
181192
/// be turned into multiple `JobRef`s and called multiple times.
182193
pub(super) struct ArcJob<BODY>
183194
where
184-
BODY: Fn() + Send + Sync,
195+
BODY: Fn(JobRefId) + Send + Sync,
185196
{
186197
job: BODY,
187198
}
188199

189200
impl<BODY> ArcJob<BODY>
190201
where
191-
BODY: Fn() + Send + Sync,
202+
BODY: Fn(JobRefId) + Send + Sync,
192203
{
193204
pub(super) fn new(job: BODY) -> Arc<Self> {
194205
Arc::new(ArcJob { job })
@@ -212,11 +223,12 @@ where
212223

213224
impl<BODY> Job for ArcJob<BODY>
214225
where
215-
BODY: Fn() + Send + Sync,
226+
BODY: Fn(JobRefId) + Send + Sync,
216227
{
217228
unsafe fn execute(this: *const ()) {
229+
let pointer = this as usize;
218230
let this = Arc::from_raw(this as *mut Self);
219-
(this.job)();
231+
(this.job)(JobRefId { pointer });
220232
}
221233
}
222234

Diff for: rayon-core/src/join/mod.rs

+28-54
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
1+
use crate::job::JobRef;
12
use crate::job::StackJob;
23
use crate::latch::SpinLatch;
3-
use crate::registry::{self, WorkerThread};
4-
use crate::tlv::{self, Tlv};
4+
use crate::registry;
5+
use crate::tlv;
56
use crate::unwind;
6-
use std::any::Any;
7+
use std::sync::atomic::{AtomicBool, Ordering};
78

89
use crate::FnContext;
910

@@ -135,68 +136,41 @@ where
135136
// Create virtual wrapper for task b; this all has to be
136137
// done here so that the stack frame can keep it all live
137138
// long enough.
138-
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
139+
let job_b_started = AtomicBool::new(false);
140+
let job_b = StackJob::new(
141+
tlv,
142+
|migrated| {
143+
job_b_started.store(true, Ordering::Relaxed);
144+
call_b(oper_b)(migrated)
145+
},
146+
SpinLatch::new(worker_thread),
147+
);
139148
let job_b_ref = job_b.as_job_ref();
140149
let job_b_id = job_b_ref.id();
141150
worker_thread.push(job_b_ref);
142151

143152
// Execute task a; hopefully b gets stolen in the meantime.
144153
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
145-
let result_a = match status_a {
146-
Ok(v) => v,
147-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
148-
};
149-
150-
// Now that task A has finished, try to pop job B from the
151-
// local stack. It may already have been popped by job A; it
152-
// may also have been stolen. There may also be some tasks
153-
// pushed on top of it in the stack, and we will have to pop
154-
// those off to get to it.
155-
while !job_b.latch.probe() {
156-
if let Some(job) = worker_thread.take_local_job() {
157-
if job_b_id == job.id() {
158-
// Found it! Let's run it.
159-
//
160-
// Note that this could panic, but it's ok if we unwind here.
161154

162-
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
163-
tlv::set(tlv);
164-
165-
let result_b = job_b.run_inline(injected);
166-
return (result_a, result_b);
167-
} else {
168-
worker_thread.execute(job);
169-
}
170-
} else {
171-
// Local deque is empty. Time to steal from other
172-
// threads.
173-
worker_thread.wait_until(&job_b.latch);
174-
debug_assert!(job_b.latch.probe());
175-
break;
176-
}
177-
}
155+
// Wait for job B or execute it if it's in the local queue.
156+
worker_thread.wait_for_jobs::<_, false>(
157+
&job_b.latch,
158+
|| job_b_started.load(Ordering::Relaxed),
159+
|job| job.id() == job_b_id,
160+
|job: JobRef| {
161+
debug_assert_eq!(job.id(), job_b_id);
162+
job_b.run_inline(injected);
163+
},
164+
);
178165

179166
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
180167
tlv::set(tlv);
181168

169+
let result_a = match status_a {
170+
Ok(v) => v,
171+
Err(err) => unwind::resume_unwinding(err),
172+
};
173+
182174
(result_a, job_b.into_result())
183175
})
184176
}
185-
186-
/// If job A panics, we still cannot return until we are sure that job
187-
/// B is complete. This is because it may contain references into the
188-
/// enclosing stack frame(s).
189-
#[cold] // cold path
190-
unsafe fn join_recover_from_panic(
191-
worker_thread: &WorkerThread,
192-
job_b_latch: &SpinLatch<'_>,
193-
err: Box<dyn Any + Send>,
194-
tlv: Tlv,
195-
) -> ! {
196-
worker_thread.wait_until(job_b_latch);
197-
198-
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
199-
tlv::set(tlv);
200-
201-
unwind::resume_unwinding(err)
202-
}

Diff for: rayon-core/src/join/test.rs

+1
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ fn join_context_both() {
9797
}
9898

9999
#[test]
100+
#[ignore]
100101
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
101102
fn join_context_neither() {
102103
// If we're already in a 1-thread pool, neither job should be stolen.

Diff for: rayon-core/src/latch.rs

-5
Original file line number | Diff line number | Diff line change
@@ -177,11 +177,6 @@ impl<'r> SpinLatch<'r> {
177177
..SpinLatch::new(thread)
178178
}
179179
}
180-
181-
#[inline]
182-
pub(super) fn probe(&self) -> bool {
183-
self.core_latch.probe()
184-
}
185180
}
186181

187182
impl<'r> AsCoreLatch for SpinLatch<'r> {

0 commit comments

Comments
 (0)