diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index e74cd61b22a..dd00e1405e4 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -6,10 +6,11 @@ //! structure's APIs are `unsafe` as they require the caller to ensure the //! specified node is actually contained by the list. +use core::cell::UnsafeCell; use core::fmt; -use core::marker::PhantomData; +use core::marker::{PhantomData, PhantomPinned}; use core::mem::ManuallyDrop; -use core::ptr::NonNull; +use core::ptr::{self, NonNull}; /// An intrusive linked list. /// @@ -60,11 +61,40 @@ pub(crate) unsafe trait Link { /// Previous / next pointers pub(crate) struct Pointers { + inner: UnsafeCell>, +} +/// We do not want the compiler to put the `noalias` attribute on mutable +/// references to this type, so the type has been made `!Unpin` with a +/// `PhantomPinned` field. +/// +/// Additionally, we never access the `prev` or `next` fields directly, as any +/// such access would implicitly involve the creation of a reference to the +/// field, which we want to avoid since the fields are not `!Unpin`, and would +/// hence be given the `noalias` attribute if we were to do such an access. +/// As an alternative to accessing the fields directly, the `Pointers` type +/// provides getters and setters for the two fields, and those are implemented +/// using raw pointer casts and offsets, which is valid since the struct is +/// #[repr(C)]. +/// +/// See this link for more information: +/// https://github.com/rust-lang/rust/pull/82834 +#[repr(C)] +struct PointersInner { /// The previous node in the list. null if there is no previous node. + /// + /// This field is accessed through pointer manipulation, so it is not dead code. + #[allow(dead_code)] prev: Option>, /// The next node in the list. null if there is no previous node. + /// + /// This field is accessed through pointer manipulation, so it is not dead code. + #[allow(dead_code)] next: Option>, + + /// This type is !Unpin due to the heuristic from: + /// https://github.com/rust-lang/rust/pull/82834 + _pin: PhantomPinned, } unsafe impl Send for Pointers {} @@ -91,11 +121,11 @@ impl LinkedList { let ptr = L::as_raw(&*val); assert_ne!(self.head, Some(ptr)); unsafe { - L::pointers(ptr).as_mut().next = self.head; - L::pointers(ptr).as_mut().prev = None; + L::pointers(ptr).as_mut().set_next(self.head); + L::pointers(ptr).as_mut().set_prev(None); if let Some(head) = self.head { - L::pointers(head).as_mut().prev = Some(ptr); + L::pointers(head).as_mut().set_prev(Some(ptr)); } self.head = Some(ptr); @@ -111,16 +141,16 @@ impl LinkedList { pub(crate) fn pop_back(&mut self) -> Option { unsafe { let last = self.tail?; - self.tail = L::pointers(last).as_ref().prev; + self.tail = L::pointers(last).as_ref().get_prev(); - if let Some(prev) = L::pointers(last).as_ref().prev { - L::pointers(prev).as_mut().next = None; + if let Some(prev) = L::pointers(last).as_ref().get_prev() { + L::pointers(prev).as_mut().set_next(None); } else { self.head = None } - L::pointers(last).as_mut().prev = None; - L::pointers(last).as_mut().next = None; + L::pointers(last).as_mut().set_prev(None); + L::pointers(last).as_mut().set_next(None); Some(L::from_raw(last)) } @@ -143,31 +173,35 @@ impl LinkedList { /// The caller **must** ensure that `node` is currently contained by /// `self` or not contained by any other list. pub(crate) unsafe fn remove(&mut self, node: NonNull) -> Option { - if let Some(prev) = L::pointers(node).as_ref().prev { - debug_assert_eq!(L::pointers(prev).as_ref().next, Some(node)); - L::pointers(prev).as_mut().next = L::pointers(node).as_ref().next; + if let Some(prev) = L::pointers(node).as_ref().get_prev() { + debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); + L::pointers(prev) + .as_mut() + .set_next(L::pointers(node).as_ref().get_next()); } else { if self.head != Some(node) { return None; } - self.head = L::pointers(node).as_ref().next; + self.head = L::pointers(node).as_ref().get_next(); } - if let Some(next) = L::pointers(node).as_ref().next { - debug_assert_eq!(L::pointers(next).as_ref().prev, Some(node)); - L::pointers(next).as_mut().prev = L::pointers(node).as_ref().prev; + if let Some(next) = L::pointers(node).as_ref().get_next() { + debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); + L::pointers(next) + .as_mut() + .set_prev(L::pointers(node).as_ref().get_prev()); } else { // This might be the last item in the list if self.tail != Some(node) { return None; } - self.tail = L::pointers(node).as_ref().prev; + self.tail = L::pointers(node).as_ref().get_prev(); } - L::pointers(node).as_mut().next = None; - L::pointers(node).as_mut().prev = None; + L::pointers(node).as_mut().set_next(None); + L::pointers(node).as_mut().set_prev(None); Some(L::from_raw(node)) } @@ -224,7 +258,7 @@ cfg_rt_multi_thread! { fn next(&mut self) -> Option<&'a T::Target> { let curr = self.curr?; // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.next; + self.curr = unsafe { T::pointers(curr).as_ref() }.get_next(); // safety: the value is still owned by the linked list. Some(unsafe { &*curr.as_ptr() }) @@ -265,7 +299,7 @@ cfg_io_readiness! { fn next(&mut self) -> Option { while let Some(curr) = self.curr { // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.next; + self.curr = unsafe { T::pointers(curr).as_ref() }.get_next(); // safety: the value is still owned by the linked list. if (self.filter)(unsafe { &mut *curr.as_ptr() }) { @@ -284,17 +318,58 @@ impl Pointers { /// Create a new set of empty pointers pub(crate) fn new() -> Pointers { Pointers { - prev: None, - next: None, + inner: UnsafeCell::new(PointersInner { + prev: None, + next: None, + _pin: PhantomPinned, + }), + } + } + + fn get_prev(&self) -> Option> { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option>; + ptr::read(prev) + } + } + fn get_next(&self) -> Option> { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option>; + let next = prev.add(1); + ptr::read(next) + } + } + + fn set_prev(&mut self, value: Option>) { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *mut Option>; + ptr::write(prev, value); + } + } + fn set_next(&mut self, value: Option>) { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *mut Option>; + let next = prev.add(1); + ptr::write(next, value); } } } impl fmt::Debug for Pointers { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let prev = self.get_prev(); + let next = self.get_next(); f.debug_struct("Pointers") - .field("prev", &self.prev) - .field("next", &self.next) + .field("prev", &prev) + .field("next", &next) .finish() } } @@ -321,7 +396,7 @@ mod tests { } unsafe fn from_raw(ptr: NonNull) -> Pin<&'a Entry> { - Pin::new(&*ptr.as_ptr()) + Pin::new_unchecked(&*ptr.as_ptr()) } unsafe fn pointers(mut target: NonNull) -> NonNull> { @@ -361,8 +436,8 @@ mod tests { macro_rules! assert_clean { ($e:ident) => {{ - assert!($e.pointers.next.is_none()); - assert!($e.pointers.prev.is_none()); + assert!($e.pointers.get_next().is_none()); + assert!($e.pointers.get_prev().is_none()); }}; } @@ -460,8 +535,8 @@ mod tests { assert_clean!(a); assert_ptr_eq!(b, list.head); - assert_ptr_eq!(c, b.pointers.next); - assert_ptr_eq!(b, c.pointers.prev); + assert_ptr_eq!(c, b.pointers.get_next()); + assert_ptr_eq!(b, c.pointers.get_prev()); let items = collect_list(&mut list); assert_eq!([31, 7].to_vec(), items); @@ -476,8 +551,8 @@ mod tests { assert!(list.remove(ptr(&b)).is_some()); assert_clean!(b); - assert_ptr_eq!(c, a.pointers.next); - assert_ptr_eq!(a, c.pointers.prev); + assert_ptr_eq!(c, a.pointers.get_next()); + assert_ptr_eq!(a, c.pointers.get_prev()); let items = collect_list(&mut list); assert_eq!([31, 5].to_vec(), items); @@ -493,7 +568,7 @@ mod tests { assert!(list.remove(ptr(&c)).is_some()); assert_clean!(c); - assert!(b.pointers.next.is_none()); + assert!(b.pointers.get_next().is_none()); assert_ptr_eq!(b, list.tail); let items = collect_list(&mut list); @@ -516,8 +591,8 @@ mod tests { assert_ptr_eq!(b, list.head); assert_ptr_eq!(b, list.tail); - assert!(b.pointers.next.is_none()); - assert!(b.pointers.prev.is_none()); + assert!(b.pointers.get_next().is_none()); + assert!(b.pointers.get_prev().is_none()); let items = collect_list(&mut list); assert_eq!([7].to_vec(), items); @@ -536,8 +611,8 @@ mod tests { assert_ptr_eq!(a, list.head); assert_ptr_eq!(a, list.tail); - assert!(a.pointers.next.is_none()); - assert!(a.pointers.prev.is_none()); + assert!(a.pointers.get_next().is_none()); + assert!(a.pointers.get_prev().is_none()); let items = collect_list(&mut list); assert_eq!([5].to_vec(), items);