Skip to content

Commit

Permalink
Refactor the intrusive linked list (Matthias247#30)
Browse files Browse the repository at this point in the history
This is an attempt to reduce the amount of unsafe code and
raw pointers in intrusive list usage.

The main unsafe function is now adding a node to the intrusive linked
list, which requires the caller to assure they will remove the item
properly later on.

Many other functions have been converted to safe functions. In order for
the iterators to be safe, they had been converted to internal iterators.
External iterators did not provide proper guarantees on when they will
drain the list.

The change also moved from raw pointers to using
`Option<NonNull<ListNode<T>>>` for intrusive list
  • Loading branch information
Matthias247 committed Feb 20, 2020
1 parent 34c02d5 commit 9945e76
Show file tree
Hide file tree
Showing 9 changed files with 534 additions and 465 deletions.
115 changes: 55 additions & 60 deletions src/channel/mpmc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,45 @@ use super::{
SendPollState, SendWaitQueueEntry, TryReceiveError, TrySendError,
};

fn wake_recv_waiters(waiters: LinkedList<RecvWaitQueueEntry>) {
unsafe {
// Reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.

for waiter in waiters.into_reverse_iter() {
if let Some(handle) = (*waiter).task.take() {
handle.wake();
}
// The only kind of waiter that could have been stored here are
// registered waiters (with a value), since others are removed
// whenever their value had been copied into the channel.
(*waiter).state = RecvPollState::Unregistered;
fn wake_recv_waiters(waiters: &mut LinkedList<RecvWaitQueueEntry>) {
// Remove all waiters from the waiting list in reverse order and wake them.
// We reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.
waiters.reverse_drain(|waiter| {
if let Some(handle) = waiter.task.take() {
handle.wake();
}
}
// The only kind of waiter that could have been stored here are
// registered waiters (with a value), since others are removed
// whenever their value had been copied into the channel.
waiter.state = RecvPollState::Unregistered;
});
}

fn wake_send_waiters<T>(waiters: LinkedList<SendWaitQueueEntry<T>>) {
unsafe {
// Reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.

for waiter in waiters.into_reverse_iter() {
if let Some(handle) = (*waiter).task.take() {
handle.wake();
}
(*waiter).state = SendPollState::Unregistered;
fn wake_send_waiters<T>(waiters: &mut LinkedList<SendWaitQueueEntry<T>>) {
// Remove all waiters from the waiting list in reverse order and wake them.
// We reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to send.
waiters.reverse_drain(|waiter| {
if let Some(handle) = waiter.task.take() {
handle.wake();
}
}
waiter.state = SendPollState::Unregistered;
});
}

/// Wakes up the last waiter and removes it from the wait queue
#[must_use]
fn return_oldest_receive_waiter(
waiters: &mut LinkedList<RecvWaitQueueEntry>,
) -> Option<Waker> {
// Safety: The list is is guaranteed to be in consistent state
let last_waiter = unsafe { waiters.remove_last() };

if !last_waiter.is_null() {
unsafe {
(*last_waiter).state = RecvPollState::Notified;
let last_waiter = waiters.remove_last();

(*last_waiter).task.take()
}
if let Some(last_waiter) = last_waiter {
last_waiter.state = RecvPollState::Notified;
last_waiter.task.take()
} else {
None
}
Expand Down Expand Up @@ -108,10 +100,8 @@ where

// Wakeup all send and receive waiters, since they are now guaranteed
// to make progress.
let recv_waiters = self.receive_waiters.take();
wake_recv_waiters(recv_waiters);
let send_waiters = self.send_waiters.take();
wake_send_waiters(send_waiters);
wake_recv_waiters(&mut self.receive_waiters);
wake_send_waiters(&mut self.send_waiters);

CloseStatus::NewlyClosed
}
Expand Down Expand Up @@ -200,11 +190,10 @@ where
/// If there is a send waiter, copy it's value into the channel buffer and complete it.
/// The method may only be called if there is space in the receive buffer.
#[must_use]
unsafe fn try_copy_value_from_oldest_waiter(&mut self) -> Option<Waker> {
fn try_copy_value_from_oldest_waiter(&mut self) -> Option<Waker> {
let last_waiter = self.send_waiters.remove_last();

if !last_waiter.is_null() {
let last_waiter = &mut (*last_waiter);
if let Some(last_waiter) = last_waiter {
let value = last_waiter
.value
.take()
Expand All @@ -222,24 +211,26 @@ where
/// Tries to extract a value from the sending waiter which has been waiting
/// longest on the send operation to complete.
fn try_take_value_from_sender(&mut self) -> Option<(T, Option<Waker>)> {
if self.send_waiters.is_empty() {
return None;
// Safety: The method is only called inside the lock on a consistent
// list.
match self.send_waiters.remove_last() {
Some(last_sender) => {
// This path should be only used for 0 capacity queues.
// Since the list is not empty, a value is available.
// Extract it from the sender in order to return it
debug_assert_eq!(0, self.buffer.capacity());

// Safety: The sender can't be invalid, since we only add valid
// senders to the queue
let val =
last_sender.value.take().expect("Value must be available");
last_sender.state = SendPollState::SendComplete;

// Return the waiter
Some((val, last_sender.task.take()))
}
None => None,
}
// This path should be only used for 0 capacity queues.
// Since the list is not empty, a value is available.
// Extract it from the sender in order to return it
debug_assert_eq!(0, self.buffer.capacity());
let last_sender = unsafe { self.send_waiters.remove_last() };
debug_assert!(!last_sender.is_null());

// Safety: The sender can't be null, since we only add valid
// senders to the queue
let last_sender = unsafe { &mut (*last_sender) };
let val = last_sender.value.take().expect("Value must be available");
last_sender.state = SendPollState::SendComplete;

// Return the waiter
Some((val, last_sender.task.take()))
}

/// Tries to receive a value from the channel without waiting.
Expand All @@ -249,7 +240,7 @@ where

// Since this means a space in the buffer had been freed,
// try to copy a value from a potential waiter into the channel.
let waker = unsafe { self.try_copy_value_from_oldest_waiter() };
let waker = self.try_copy_value_from_oldest_waiter();

Ok((val, waker))
} else if let Some((val, waker)) = self.try_take_value_from_sender() {
Expand Down Expand Up @@ -307,6 +298,8 @@ where
// This has happened in the SendPollState::Registered case.
match wait_node.state {
SendPollState::Registered => {
// Safety: Due to the state, we know that the node must be part
// of the waiter list
if !unsafe { self.send_waiters.remove(wait_node) } {
// Panic if the address isn't found. This can only happen if the contract was
// violated, e.g. the WaitQueueEntry got moved after the initial poll.
Expand All @@ -330,6 +323,8 @@ where
// the wait queue of the channel. This has happened in the RecvPollState::Registered case.
match wait_node.state {
RecvPollState::Registered => {
// Safety: Due to the state, we know that the node must be part
// of the waiter list
if !unsafe { self.receive_waiters.remove(wait_node) } {
// Panic if the address isn't found. This can only happen if the contract was
// violated, e.g. the WaitQueueEntry got moved after the initial poll.
Expand Down
33 changes: 12 additions & 21 deletions src/channel/oneshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ use core::marker::PhantomData;
use futures_core::task::{Context, Poll};
use lock_api::{Mutex, RawMutex};

unsafe fn wake_waiters(waiters: LinkedList<RecvWaitQueueEntry>) {
// Reverse the waiter list, so that the oldest waker (which is
fn wake_waiters(waiters: &mut LinkedList<RecvWaitQueueEntry>) {
// Remove all waiters from the waiting list in reverse order and wake them.
// We reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.

for waiter in waiters.into_reverse_iter() {
if let Some(handle) = (*waiter).task.take() {
waiters.reverse_drain(|waiter| {
if let Some(handle) = waiter.task.take() {
handle.wake();
}
(*waiter).state = RecvPollState::Unregistered;
}
waiter.state = RecvPollState::Unregistered;
});
}

/// Internal state of the oneshot channel
Expand Down Expand Up @@ -56,13 +56,7 @@ impl<T> ChannelState<T> {
self.is_fulfilled = true;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

Ok(())
}
Expand All @@ -74,13 +68,8 @@ impl<T> ChannelState<T> {
self.is_fulfilled = true;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

CloseStatus::NewlyClosed
}

Expand Down Expand Up @@ -135,6 +124,8 @@ impl<T> ChannelState<T> {
// ChannelReceiveFuture only needs to get removed if it had been added to
// the wait queue of the channel. This has happened in the RecvPollState::Waiting case.
if let RecvPollState::Registered = wait_node.state {
// Safety: Due to the state, we know that the node must be part
// of the waiter list
if !unsafe { self.waiters.remove(wait_node) } {
// Panic if the address isn't found. This can only happen if the contract was
// violated, e.g. the RecvWaitQueueEntry got moved after the initial poll.
Expand Down
32 changes: 11 additions & 21 deletions src/channel/oneshot_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ use core::marker::PhantomData;
use futures_core::task::{Context, Poll};
use lock_api::{Mutex, RawMutex};

unsafe fn wake_waiters(waiters: LinkedList<RecvWaitQueueEntry>) {
// Reverse the waiter list, so that the oldest waker (which is
fn wake_waiters(waiters: &mut LinkedList<RecvWaitQueueEntry>) {
// Remove all waiters from the waiting list in reverse order and wake them.
// We reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.

for waiter in waiters.into_reverse_iter() {
if let Some(handle) = (*waiter).task.take() {
waiters.reverse_drain(|waiter| {
if let Some(handle) = waiter.task.take() {
handle.wake();
}
(*waiter).state = RecvPollState::Unregistered;
}
waiter.state = RecvPollState::Unregistered;
});
}

/// Internal state of the oneshot channel
Expand Down Expand Up @@ -60,13 +60,7 @@ where
self.is_fulfilled = true;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

Ok(())
}
Expand All @@ -78,13 +72,7 @@ where
self.is_fulfilled = true;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

CloseStatus::NewlyClosed
}
Expand Down Expand Up @@ -141,6 +129,8 @@ where
// ChannelReceiveFuture only needs to get removed if it had been added to
// the wait queue of the channel. This has happened in the RecvPollState::Waiting case.
if let RecvPollState::Registered = wait_node.state {
// Safety: Due to the state, we know that the node must be part
// of the waiter list
if !unsafe { self.waiters.remove(wait_node) } {
// Panic if the address isn't found. This can only happen if the contract was
// violated, e.g. the RecvWaitQueueEntry got moved after the initial poll.
Expand Down
32 changes: 11 additions & 21 deletions src/channel/state_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,17 @@ impl<'a, MutexType, T: Clone> Drop for StateReceiveFuture<'a, MutexType, T> {
}
}

unsafe fn wake_waiters(waiters: LinkedList<RecvWaitQueueEntry>) {
// Reverse the waiter list, so that the oldest waker (which is
fn wake_waiters(waiters: &mut LinkedList<RecvWaitQueueEntry>) {
// Remove all waiters from the waiting list in reverse order and wake them.
// We reverse the waiter list, so that the oldest waker (which is
// at the end of the list), gets woken first and has the best
// chance to grab the channel value.

for waiter in waiters.into_reverse_iter() {
if let Some(handle) = (*waiter).task.take() {
waiters.reverse_drain(|waiter| {
if let Some(handle) = waiter.task.take() {
handle.wake();
}
(*waiter).state = RecvPollState::Unregistered;
}
waiter.state = RecvPollState::Unregistered;
});
}

/// Internal state of the state broadcast channel
Expand Down Expand Up @@ -215,13 +215,7 @@ where
self.state_id.0 += 1;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

Ok(())
}
Expand All @@ -233,13 +227,7 @@ where
self.is_closed = true;

// Wakeup all waiters
let waiters = self.waiters.take();
// Safety: The linked list is guaranteed to be only manipulated inside
// the mutex in scope of the ChannelState and is thereby guaranteed to
// be consistent.
unsafe {
wake_waiters(waiters);
}
wake_waiters(&mut self.waiters);

CloseStatus::NewlyClosed
}
Expand Down Expand Up @@ -310,6 +298,8 @@ where
// StateReceiveFuture only needs to get removed if it had been added to
// the wait queue of the channel. This has happened in the RecvPollState::Waiting case.
if let RecvPollState::Registered = wait_node.state {
// Safety: Due to the state, we know that the node must be part
// of the waiter list
if !unsafe { self.waiters.remove(wait_node) } {
// Panic if the address isn't found. This can only happen if the contract was
// violated, e.g. the RecvWaitQueueEntry got moved after the initial poll.
Expand Down
Loading

0 comments on commit 9945e76

Please sign in to comment.