From f583c4acb75b3093796abbce9137429bccf20dd1 Mon Sep 17 00:00:00 2001 From: Jeonghyeon Kim Date: Sat, 28 Sep 2024 01:40:54 +0900 Subject: [PATCH] Implement standard traits for pointers --- src/ebr_impl/pointers.rs | 29 ++++--- src/strong.rs | 171 +++++++++++++++++++++++++++++++++------ src/weak.rs | 81 +++++++++++++++++-- tests/harris_list.rs | 2 +- 4 files changed, 240 insertions(+), 43 deletions(-) diff --git a/src/ebr_impl/pointers.rs b/src/ebr_impl/pointers.rs index 8ffc879..2aa504e 100644 --- a/src/ebr_impl/pointers.rs +++ b/src/ebr_impl/pointers.rs @@ -3,6 +3,7 @@ use core::marker::PhantomData; use core::mem::align_of; use core::ptr::null_mut; use core::sync::atomic::AtomicUsize; +use std::fmt::{Debug, Formatter, Pointer}; use atomic::{Atomic, Ordering}; @@ -12,6 +13,18 @@ pub struct Tagged { ptr: *mut T, } +impl Debug for Tagged { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.as_raw(), f) + } +} + +impl Pointer for Tagged { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.as_raw(), f) + } +} + impl Default for Tagged { fn default() -> Self { Self { ptr: null_mut() } @@ -26,14 +39,6 @@ impl Clone for Tagged { impl Copy for Tagged {} -impl PartialEq for Tagged { - fn eq(&self, other: &Self) -> bool { - self.with_high_tag(0).ptr == other.with_high_tag(0).ptr - } -} - -impl Eq for Tagged {} - impl Hash for Tagged { fn hash(&self, state: &mut H) { self.ptr.hash(state) @@ -125,6 +130,11 @@ impl Tagged { Some(self.deref()) } } + + pub fn ptr_eq(self, other: Self) -> bool { + // Ignore the epoch tags, and compare between pointer values. + self.with_high_tag(0).ptr == other.with_high_tag(0).ptr + } } /// Returns a bitmask containing the unused least significant bits of an aligned pointer to `T`. @@ -196,6 +206,7 @@ impl RawAtomic { } } +// A shared pointer type only for the internal EBR implementation. pub(crate) struct RawShared<'g, T> { inner: Tagged, _marker: PhantomData<&'g T>, @@ -238,7 +249,7 @@ impl<'g, T> From> for RawShared<'g, T> { impl<'g, T> PartialEq for RawShared<'g, T> { fn eq(&self, other: &Self) -> bool { - self.inner == other.inner + self.inner.ptr_eq(other.inner) } } diff --git a/src/strong.rs b/src/strong.rs index 50245be..f2985e5 100644 --- a/src/strong.rs +++ b/src/strong.rs @@ -1,5 +1,7 @@ use std::{ array, + fmt::{Debug, Formatter, Pointer}, + hash::{Hash, Hasher}, marker::PhantomData, mem::{forget, size_of}, sync::atomic::{AtomicUsize, Ordering}, @@ -196,24 +198,24 @@ impl AtomicRc { failure: Ordering, guard: &'g Guard, ) -> Result, CompareExchangeError, Snapshot<'g, T>>> { - let mut expected_ptr = expected.ptr; - let desired_ptr = desired.ptr.with_timestamp(); + let mut expected_raw = expected.ptr; + let desired_raw = desired.ptr.with_timestamp(); loop { match self .link - .compare_exchange(expected_ptr, desired_ptr, success, failure) + .compare_exchange(expected_raw, desired_raw, success, failure) { Ok(_) => { // Skip decrementing a strong count of the inserted pointer. forget(desired); - let rc = Rc::from_raw(expected_ptr); + let rc = Rc::from_raw(expected_raw); return Ok(rc); } - Err(current) => { - if current.with_high_tag(0) == expected_ptr.with_high_tag(0) { - expected_ptr = current; + Err(current_raw) => { + if current_raw.ptr_eq(expected_raw) { + expected_raw = current_raw; } else { - let current = Snapshot::from_raw(current, guard); + let current = Snapshot::from_raw(current_raw, guard); return Err(CompareExchangeError { desired, current }); } } @@ -248,24 +250,24 @@ impl AtomicRc { failure: Ordering, guard: &'g Guard, ) -> Result, CompareExchangeError, Snapshot<'g, T>>> { - let mut expected_ptr = expected.ptr; - let desired_ptr = desired.ptr.with_timestamp(); + let mut expected_raw = expected.ptr; + let desired_raw = desired.ptr.with_timestamp(); loop { match self .link - .compare_exchange_weak(expected_ptr, desired_ptr, success, failure) + .compare_exchange_weak(expected_raw, desired_raw, success, failure) { Ok(_) => { // Skip decrementing a strong count of the inserted pointer. forget(desired); - let rc = Rc::from_raw(expected_ptr); + let rc = Rc::from_raw(expected_raw); return Ok(rc); } - Err(current) => { - if current.with_high_tag(0) == expected_ptr.with_high_tag(0) { - expected_ptr = current; + Err(current_raw) => { + if current_raw.ptr_eq(expected_raw) { + expected_raw = current_raw; } else { - let current = Snapshot::from_raw(current, guard); + let current = Snapshot::from_raw(current_raw, guard); return Err(CompareExchangeError { desired, current }); } } @@ -312,14 +314,14 @@ impl AtomicRc { .link .compare_exchange(expected_raw, desired_raw, success, failure) { - Ok(current) => return Ok(Snapshot::from_raw(current, guard)), - Err(current) => { - if current.with_high_tag(0) == expected_raw.with_high_tag(0) { - expected_raw = current; + Ok(current_raw) => return Ok(Snapshot::from_raw(current_raw, guard)), + Err(current_raw) => { + if current_raw.ptr_eq(expected_raw) { + expected_raw = current_raw; } else { return Err(CompareExchangeError { desired: Snapshot::from_raw(desired_raw, guard), - current: Snapshot::from_raw(current, guard), + current: Snapshot::from_raw(current_raw, guard), }); } } @@ -381,6 +383,25 @@ impl From> for AtomicRc { } } +impl From<&Rc> for AtomicRc { + #[inline] + fn from(value: &Rc) -> Self { + Self::from(value.clone()) + } +} + +impl Debug for AtomicRc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.link.load(Ordering::Relaxed), f) + } +} + +impl Pointer for AtomicRc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.link.load(Ordering::Relaxed), f) + } +} + /// A reference-counted pointer to an object of type `T`. /// /// When `T` implements [`Send`] and [`Sync`], [`Rc`] also implements these traits. @@ -588,6 +609,35 @@ impl Rc { Some(unsafe { self.deref_mut() }) } } + + /// Returns true if the two [`Rc`]s point to the same allocation in a vein similar to + /// [`std::ptr::eq`]. + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.ptr.ptr_eq(other.ptr) + } +} + +impl<'g, T: RcObject> From> for Rc { + fn from(value: Snapshot<'g, T>) -> Self { + value.counted() + } +} + +impl Debug for Rc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(cnt) = self.as_ref() { + f.debug_tuple("RcObject").field(cnt).finish() + } else { + f.write_str("Null") + } + } +} + +impl Pointer for Rc { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.ptr, f) + } } impl Default for Rc { @@ -608,10 +658,34 @@ impl Drop for Rc { } } -impl PartialEq for Rc { +impl PartialEq for Rc { #[inline(always)] fn eq(&self, other: &Self) -> bool { - self.ptr == other.ptr + match (self.as_ref(), other.as_ref()) { + (None, None) => true, + (None, Some(_)) | (Some(_), None) => false, + (Some(x), Some(y)) => x.eq(y), + } + } +} + +impl Eq for Rc {} + +impl Hash for Rc { + fn hash(&self, state: &mut H) { + self.as_ref().hash(state); + } +} + +impl PartialOrd for Rc { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_ref().partial_cmp(&other.as_ref()) + } +} + +impl Ord for Rc { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.as_ref().cmp(&other.as_ref()) } } @@ -775,6 +849,13 @@ impl<'g, T: RcObject> Snapshot<'g, T> { Some(unsafe { self.deref_mut() }) } } + + /// Returns true if the two [`Rc`]s point to the same allocation in a vein similar to + /// [`std::ptr::eq`]. + #[inline] + pub fn ptr_eq(self, other: Self) -> bool { + self.ptr.ptr_eq(other.ptr) + } } impl<'g, T> Snapshot<'g, T> { @@ -803,9 +884,49 @@ impl<'g, T: RcObject> Default for Snapshot<'g, T> { } } -impl<'g, T> PartialEq for Snapshot<'g, T> { +impl<'g, T: RcObject + PartialEq> PartialEq for Snapshot<'g, T> { #[inline(always)] fn eq(&self, other: &Self) -> bool { - self.ptr.eq(&other.ptr) + match (self.as_ref(), other.as_ref()) { + (None, None) => true, + (None, Some(_)) | (Some(_), None) => false, + (Some(x), Some(y)) => x.eq(y), + } + } +} + +impl<'g, T: RcObject + Eq> Eq for Snapshot<'g, T> {} + +impl<'g, T: RcObject + Hash> Hash for Snapshot<'g, T> { + fn hash(&self, state: &mut H) { + self.as_ref().hash(state); + } +} + +impl<'g, T: RcObject + PartialOrd> PartialOrd for Snapshot<'g, T> { + fn partial_cmp(&self, other: &Self) -> Option { + self.as_ref().partial_cmp(&other.as_ref()) + } +} + +impl<'g, T: RcObject + Ord> Ord for Snapshot<'g, T> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.as_ref().cmp(&other.as_ref()) + } +} + +impl<'g, T: RcObject + Debug> Debug for Snapshot<'g, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(cnt) = self.as_ref() { + f.debug_tuple("RcObject").field(cnt).finish() + } else { + f.write_str("Null") + } + } +} + +impl<'g, T: RcObject> Pointer for Snapshot<'g, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.ptr, f) } } diff --git a/src/weak.rs b/src/weak.rs index 837e9d0..0b7083e 100644 --- a/src/weak.rs +++ b/src/weak.rs @@ -1,4 +1,5 @@ use std::{ + fmt::{Debug, Formatter, Pointer}, marker::PhantomData, mem::{forget, size_of}, sync::atomic::{AtomicUsize, Ordering}, @@ -232,6 +233,32 @@ impl From> for AtomicWeak { } } +impl From<&Weak> for AtomicWeak { + #[inline] + fn from(value: &Weak) -> Self { + Self::from(value.clone()) + } +} + +impl From<&Rc> for AtomicWeak { + #[inline] + fn from(value: &Rc) -> Self { + Self::from(value.downgrade()) + } +} + +impl Debug for AtomicWeak { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.link.load(Ordering::Relaxed), f) + } +} + +impl Pointer for AtomicWeak { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.link.load(Ordering::Relaxed), f) + } +} + impl Drop for AtomicWeak { #[inline(always)] fn drop(&mut self) { @@ -336,6 +363,11 @@ impl Weak { ptr.increment_weak(1); } } + + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.ptr.ptr_eq(other.ptr) + } } impl Weak { @@ -364,10 +396,27 @@ impl Drop for Weak { } } -impl PartialEq for Weak { - #[inline(always)] - fn eq(&self, other: &Self) -> bool { - self.ptr == other.ptr +impl<'g, T> From> for Weak { + fn from(value: WeakSnapshot<'g, T>) -> Self { + value.counted() + } +} + +impl<'g, T: RcObject> From> for Weak { + fn from(value: Snapshot<'g, T>) -> Self { + value.downgrade().counted() + } +} + +impl Debug for Weak { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.ptr, f) + } +} + +impl Pointer for Weak { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.ptr, f) } } @@ -451,6 +500,11 @@ impl<'g, T> WeakSnapshot<'g, T> { _marker: PhantomData, } } + + #[inline] + pub fn ptr_eq(self, other: Self) -> bool { + self.ptr.ptr_eq(other.ptr) + } } impl<'g, T> Default for WeakSnapshot<'g, T> { @@ -460,9 +514,20 @@ impl<'g, T> Default for WeakSnapshot<'g, T> { } } -impl<'g, T> PartialEq for WeakSnapshot<'g, T> { - #[inline(always)] - fn eq(&self, other: &Self) -> bool { - self.ptr.eq(&other.ptr) +impl<'g, T: RcObject> From> for WeakSnapshot<'g, T> { + fn from(value: Snapshot<'g, T>) -> Self { + value.downgrade() + } +} + +impl<'g, T> Debug for WeakSnapshot<'g, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.ptr, f) + } +} + +impl<'g, T> Pointer for WeakSnapshot<'g, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Pointer::fmt(&self.ptr, f) } } diff --git a/tests/harris_list.rs b/tests/harris_list.rs index 3318efd..73e0ba3 100644 --- a/tests/harris_list.rs +++ b/tests/harris_list.rs @@ -105,7 +105,7 @@ impl<'g, K: Ord, V> Cursor<'g, K, V> { }; // If prev and curr WERE adjacent, no need to clean up - if prev_next == self.curr { + if prev_next.ptr_eq(self.curr) { return Ok(found); }