Skip to content

Commit

Permalink
Implement standard traits for pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
powergee committed Sep 27, 2024
1 parent f7825ba commit f583c4a
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 43 deletions.
29 changes: 20 additions & 9 deletions src/ebr_impl/pointers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use core::marker::PhantomData;
use core::mem::align_of;
use core::ptr::null_mut;
use core::sync::atomic::AtomicUsize;
use std::fmt::{Debug, Formatter, Pointer};

use atomic::{Atomic, Ordering};

Expand All @@ -12,6 +13,18 @@ pub struct Tagged<T: ?Sized> {
ptr: *mut T,
}

impl<T> Debug for Tagged<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.as_raw(), f)
}
}

impl<T> Pointer for Tagged<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.as_raw(), f)
}
}

impl<T> Default for Tagged<T> {
fn default() -> Self {
Self { ptr: null_mut() }
Expand All @@ -26,14 +39,6 @@ impl<T> Clone for Tagged<T> {

impl<T> Copy for Tagged<T> {}

impl<T> PartialEq for Tagged<T> {
fn eq(&self, other: &Self) -> bool {
self.with_high_tag(0).ptr == other.with_high_tag(0).ptr
}
}

impl<T> Eq for Tagged<T> {}

impl<T> Hash for Tagged<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.ptr.hash(state)
Expand Down Expand Up @@ -125,6 +130,11 @@ impl<T> Tagged<T> {
Some(self.deref())
}
}

pub fn ptr_eq(self, other: Self) -> bool {
// Ignore the epoch tags, and compare between pointer values.
self.with_high_tag(0).ptr == other.with_high_tag(0).ptr
}
}

/// Returns a bitmask containing the unused least significant bits of an aligned pointer to `T`.
Expand Down Expand Up @@ -196,6 +206,7 @@ impl<T> RawAtomic<T> {
}
}

// A shared pointer type only for the internal EBR implementation.
pub(crate) struct RawShared<'g, T> {
inner: Tagged<T>,
_marker: PhantomData<&'g T>,
Expand Down Expand Up @@ -238,7 +249,7 @@ impl<'g, T> From<Tagged<T>> for RawShared<'g, T> {

impl<'g, T> PartialEq for RawShared<'g, T> {
fn eq(&self, other: &Self) -> bool {
self.inner == other.inner
self.inner.ptr_eq(other.inner)
}
}

Expand Down
171 changes: 146 additions & 25 deletions src/strong.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{
array,
fmt::{Debug, Formatter, Pointer},
hash::{Hash, Hasher},
marker::PhantomData,
mem::{forget, size_of},
sync::atomic::{AtomicUsize, Ordering},
Expand Down Expand Up @@ -196,24 +198,24 @@ impl<T: RcObject> AtomicRc<T> {
failure: Ordering,
guard: &'g Guard,
) -> Result<Rc<T>, CompareExchangeError<Rc<T>, Snapshot<'g, T>>> {
let mut expected_ptr = expected.ptr;
let desired_ptr = desired.ptr.with_timestamp();
let mut expected_raw = expected.ptr;
let desired_raw = desired.ptr.with_timestamp();
loop {
match self
.link
.compare_exchange(expected_ptr, desired_ptr, success, failure)
.compare_exchange(expected_raw, desired_raw, success, failure)
{
Ok(_) => {
// Skip decrementing a strong count of the inserted pointer.
forget(desired);
let rc = Rc::from_raw(expected_ptr);
let rc = Rc::from_raw(expected_raw);
return Ok(rc);
}
Err(current) => {
if current.with_high_tag(0) == expected_ptr.with_high_tag(0) {
expected_ptr = current;
Err(current_raw) => {
if current_raw.ptr_eq(expected_raw) {
expected_raw = current_raw;
} else {
let current = Snapshot::from_raw(current, guard);
let current = Snapshot::from_raw(current_raw, guard);
return Err(CompareExchangeError { desired, current });
}
}
Expand Down Expand Up @@ -248,24 +250,24 @@ impl<T: RcObject> AtomicRc<T> {
failure: Ordering,
guard: &'g Guard,
) -> Result<Rc<T>, CompareExchangeError<Rc<T>, Snapshot<'g, T>>> {
let mut expected_ptr = expected.ptr;
let desired_ptr = desired.ptr.with_timestamp();
let mut expected_raw = expected.ptr;
let desired_raw = desired.ptr.with_timestamp();
loop {
match self
.link
.compare_exchange_weak(expected_ptr, desired_ptr, success, failure)
.compare_exchange_weak(expected_raw, desired_raw, success, failure)
{
Ok(_) => {
// Skip decrementing a strong count of the inserted pointer.
forget(desired);
let rc = Rc::from_raw(expected_ptr);
let rc = Rc::from_raw(expected_raw);
return Ok(rc);
}
Err(current) => {
if current.with_high_tag(0) == expected_ptr.with_high_tag(0) {
expected_ptr = current;
Err(current_raw) => {
if current_raw.ptr_eq(expected_raw) {
expected_raw = current_raw;
} else {
let current = Snapshot::from_raw(current, guard);
let current = Snapshot::from_raw(current_raw, guard);
return Err(CompareExchangeError { desired, current });
}
}
Expand Down Expand Up @@ -312,14 +314,14 @@ impl<T: RcObject> AtomicRc<T> {
.link
.compare_exchange(expected_raw, desired_raw, success, failure)
{
Ok(current) => return Ok(Snapshot::from_raw(current, guard)),
Err(current) => {
if current.with_high_tag(0) == expected_raw.with_high_tag(0) {
expected_raw = current;
Ok(current_raw) => return Ok(Snapshot::from_raw(current_raw, guard)),
Err(current_raw) => {
if current_raw.ptr_eq(expected_raw) {
expected_raw = current_raw;
} else {
return Err(CompareExchangeError {
desired: Snapshot::from_raw(desired_raw, guard),
current: Snapshot::from_raw(current, guard),
current: Snapshot::from_raw(current_raw, guard),
});
}
}
Expand Down Expand Up @@ -381,6 +383,25 @@ impl<T: RcObject> From<Rc<T>> for AtomicRc<T> {
}
}

impl<T: RcObject> From<&Rc<T>> for AtomicRc<T> {
#[inline]
fn from(value: &Rc<T>) -> Self {
Self::from(value.clone())
}
}

impl<T: RcObject> Debug for AtomicRc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.link.load(Ordering::Relaxed), f)
}
}

impl<T: RcObject> Pointer for AtomicRc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.link.load(Ordering::Relaxed), f)
}
}

/// A reference-counted pointer to an object of type `T`.
///
/// When `T` implements [`Send`] and [`Sync`], [`Rc<T>`] also implements these traits.
Expand Down Expand Up @@ -588,6 +609,35 @@ impl<T: RcObject> Rc<T> {
Some(unsafe { self.deref_mut() })
}
}

/// Returns true if the two [`Rc`]s point to the same allocation in a vein similar to
/// [`std::ptr::eq`].
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.ptr.ptr_eq(other.ptr)
}
}

impl<'g, T: RcObject> From<Snapshot<'g, T>> for Rc<T> {
fn from(value: Snapshot<'g, T>) -> Self {
value.counted()
}
}

impl<T: RcObject + Debug> Debug for Rc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if let Some(cnt) = self.as_ref() {
f.debug_tuple("RcObject").field(cnt).finish()
} else {
f.write_str("Null")
}
}
}

impl<T: RcObject> Pointer for Rc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}

impl<T: RcObject> Default for Rc<T> {
Expand All @@ -608,10 +658,34 @@ impl<T: RcObject> Drop for Rc<T> {
}
}

impl<T: RcObject> PartialEq for Rc<T> {
impl<T: RcObject + PartialEq> PartialEq for Rc<T> {
#[inline(always)]
fn eq(&self, other: &Self) -> bool {
self.ptr == other.ptr
match (self.as_ref(), other.as_ref()) {
(None, None) => true,
(None, Some(_)) | (Some(_), None) => false,
(Some(x), Some(y)) => x.eq(y),
}
}
}

impl<T: RcObject + Eq> Eq for Rc<T> {}

impl<T: RcObject + Hash> Hash for Rc<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_ref().hash(state);
}
}

impl<T: RcObject + PartialOrd> PartialOrd for Rc<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.as_ref().partial_cmp(&other.as_ref())
}
}

impl<T: RcObject + Ord> Ord for Rc<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.as_ref().cmp(&other.as_ref())
}
}

Expand Down Expand Up @@ -775,6 +849,13 @@ impl<'g, T: RcObject> Snapshot<'g, T> {
Some(unsafe { self.deref_mut() })
}
}

/// Returns true if the two [`Rc`]s point to the same allocation in a vein similar to
/// [`std::ptr::eq`].
#[inline]
pub fn ptr_eq(self, other: Self) -> bool {
self.ptr.ptr_eq(other.ptr)
}
}

impl<'g, T> Snapshot<'g, T> {
Expand Down Expand Up @@ -803,9 +884,49 @@ impl<'g, T: RcObject> Default for Snapshot<'g, T> {
}
}

impl<'g, T> PartialEq for Snapshot<'g, T> {
impl<'g, T: RcObject + PartialEq> PartialEq for Snapshot<'g, T> {
#[inline(always)]
fn eq(&self, other: &Self) -> bool {
self.ptr.eq(&other.ptr)
match (self.as_ref(), other.as_ref()) {
(None, None) => true,
(None, Some(_)) | (Some(_), None) => false,
(Some(x), Some(y)) => x.eq(y),
}
}
}

impl<'g, T: RcObject + Eq> Eq for Snapshot<'g, T> {}

impl<'g, T: RcObject + Hash> Hash for Snapshot<'g, T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_ref().hash(state);
}
}

impl<'g, T: RcObject + PartialOrd> PartialOrd for Snapshot<'g, T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.as_ref().partial_cmp(&other.as_ref())
}
}

impl<'g, T: RcObject + Ord> Ord for Snapshot<'g, T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.as_ref().cmp(&other.as_ref())
}
}

impl<'g, T: RcObject + Debug> Debug for Snapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if let Some(cnt) = self.as_ref() {
f.debug_tuple("RcObject").field(cnt).finish()
} else {
f.write_str("Null")
}
}
}

impl<'g, T: RcObject> Pointer for Snapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}
Loading

0 comments on commit f583c4a

Please sign in to comment.