Don't deadlock on mutual stealing
This commit is contained in:
parent
5b745553b0
commit
fb209c7ff9
1 changed files with 63 additions and 22 deletions
85
src/lib.rs
85
src/lib.rs
|
@ -185,6 +185,22 @@ impl ThreadState {
|
|||
}
|
||||
}
|
||||
|
||||
fn steal_from(&mut self) -> Option<Vec<Arc<Task>>> {
|
||||
let split_point = (self.woken.len() - (self.woken.len() % 2)) / 2;
|
||||
|
||||
if split_point > 0 {
|
||||
let v = self.woken.drain(split_point..).collect::<Vec<_>>();
|
||||
|
||||
for task in &v {
|
||||
self.pending.remove(&task.task_id);
|
||||
}
|
||||
|
||||
return Some(v);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn pop_available_head(&mut self) -> Option<Arc<Task>> {
|
||||
self.available.pop_front()
|
||||
}
|
||||
|
@ -335,10 +351,12 @@ impl Runner {
|
|||
return false;
|
||||
}
|
||||
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
let current_tid = guard.handle.id();
|
||||
let (last_steal, current_tid) = {
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
(guard.last_steal.take(), guard.handle.id())
|
||||
};
|
||||
|
||||
match guard.last_steal.take() {
|
||||
let (tid, stolen) = match last_steal {
|
||||
Some(id) => {
|
||||
if read_guard.contains_key(&id) {
|
||||
let opt = read_guard
|
||||
|
@ -350,13 +368,17 @@ impl Runner {
|
|||
|
||||
if let Some((tid, thread_state)) = opt {
|
||||
let mut state = thread_state.lock().unwrap();
|
||||
let split_point = (state.woken.len() - (state.woken.len() % 2)) / 2;
|
||||
if split_point > 0 {
|
||||
guard.woken.extend(state.woken.drain(split_point..));
|
||||
guard.last_steal = Some(*tid);
|
||||
return true;
|
||||
|
||||
if let Some(v) = state.steal_from() {
|
||||
(*tid, v)
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
|
@ -364,17 +386,32 @@ impl Runner {
|
|||
|
||||
if let Some((tid, thread_state)) = opt {
|
||||
let mut state = thread_state.lock().unwrap();
|
||||
let split_point = (state.woken.len() - (state.woken.len() % 2)) / 2;
|
||||
if split_point > 0 {
|
||||
guard.woken.extend(state.woken.drain(split_point..));
|
||||
guard.last_steal = Some(*tid);
|
||||
return true;
|
||||
if let Some(v) = state.steal_from() {
|
||||
(*tid, v)
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
false
|
||||
println!(
|
||||
"oh lawd {:?} stolen {} tasks from {:?}",
|
||||
std::thread::current().id(),
|
||||
stolen.len(),
|
||||
tid
|
||||
);
|
||||
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
for task in &stolen {
|
||||
guard.pending.insert(task.task_id, Arc::clone(task));
|
||||
}
|
||||
guard.woken.extend(stolen);
|
||||
guard.last_steal = Some(tid);
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn any_woken(&self) -> bool {
|
||||
|
@ -529,14 +566,18 @@ impl Future for Runner {
|
|||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.as_mut();
|
||||
|
||||
while this.tick(cx.waker().clone()) {
|
||||
// processing spawned tasks
|
||||
}
|
||||
loop {
|
||||
while this.tick(cx.waker().clone()) {
|
||||
// processing spawned tasks
|
||||
}
|
||||
|
||||
if this.stopping() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
Poll::Pending
|
||||
if this.stopping() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
if !this.steal() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue