diff --git a/Cargo.toml b/Cargo.toml index 93248ac4..4bf669b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ metrics = [] byteorder = "1.5.0" crossbeam-skiplist = "0.1.3" enum_dispatch = "0.3.13" +equivalent = "1.0.2" interval-heap = "0.0.5" log = "0.4.27" lz4_flex = { version = "0.11.5", optional = true, default-features = false } @@ -44,6 +45,7 @@ xxhash-rust = { version = "0.8.15", features = ["xxh3"] } criterion = { version = "0.5.1", features = ["html_reports"] } fs_extra = "1.3.0" nanoid = "0.4.0" +quickcheck = "1.0.3" rand = "0.9.2" test-log = "0.2.18" diff --git a/benches/memtable.rs b/benches/memtable.rs index e7d201fe..14d7813c 100644 --- a/benches/memtable.rs +++ b/benches/memtable.rs @@ -1,5 +1,5 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use lsm_tree::{InternalValue, Memtable}; +use lsm_tree::{InternalValue, Memtable, SeqNo}; use nanoid::nanoid; fn memtable_get_hit(c: &mut Criterion) { @@ -25,7 +25,10 @@ fn memtable_get_hit(c: &mut Criterion) { b.iter(|| { assert_eq!( [1, 2, 3], - &*memtable.get(b"abc_w5wa35aw35naw", None).unwrap().value, + &*memtable + .get(b"abc_w5wa35aw35naw", SeqNo::MAX) + .unwrap() + .value, ) }); }); @@ -60,7 +63,10 @@ fn memtable_get_snapshot(c: &mut Criterion) { b.iter(|| { assert_eq!( [1, 2, 3], - &*memtable.get(b"abc_w5wa35aw35naw", Some(1)).unwrap().value, + &*memtable + .get(b"abc_w5wa35aw35naw", SeqNo::MAX) + .unwrap() + .value, ); }); }); @@ -79,7 +85,7 @@ fn memtable_get_miss(c: &mut Criterion) { } c.bench_function("memtable get miss", |b| { - b.iter(|| assert!(memtable.get(b"abc_564321", None).is_none())); + b.iter(|| assert!(memtable.get(b"abc_564321", SeqNo::MAX).is_none())); }); } diff --git a/src/key.rs b/src/key.rs index 24e42032..6126e5a9 100644 --- a/src/key.rs +++ b/src/key.rs @@ -7,6 +7,7 @@ use crate::{ SeqNo, UserKey, ValueType, }; use byteorder::{ReadBytesExt, WriteBytesExt}; +use equivalent::{Comparable, Equivalent}; use std::{ cmp::Reverse, io::{Read, Write}, @@ -131,37 +132,27 @@ impl Ord for InternalKey { } } -// TODO: wait for new crossbeam-skiplist -// TODO: https://github.com/crossbeam-rs/crossbeam/pull/1162 -// -// impl Equivalent> for InternalKey { -// fn equivalent(&self, other: &InternalKeyRef<'_>) -> bool { -// self.user_key == other.user_key && self.seqno == other.seqno -// } -// } +impl Equivalent> for InternalKey { + fn equivalent(&self, other: &InternalKeyRef<'_>) -> bool { + self.user_key == other.user_key && self.seqno == other.seqno + } +} -// impl Comparable> for InternalKey { -// fn compare(&self, other: &InternalKeyRef<'_>) -> std::cmp::Ordering { -// (&*self.user_key, Reverse(self.seqno)).cmp(&(other.user_key, Reverse(other.seqno))) -// } -// } +impl Comparable> for InternalKey { + fn compare(&self, other: &InternalKeyRef<'_>) -> std::cmp::Ordering { + (&*self.user_key, Reverse(self.seqno)).cmp(&(other.user_key, Reverse(other.seqno))) + } +} -/* /// Temporary internal key without heap allocation -#[derive(Clone, Debug, Eq)] +/// Temporary internal key without heap allocation +#[derive(Debug, Eq)] pub struct InternalKeyRef<'a> { pub user_key: &'a [u8], pub seqno: SeqNo, pub value_type: ValueType, } -impl<'a> AsRef<[u8]> for InternalKeyRef<'a> { - fn as_ref(&self) -> &[u8] { - self.user_key - } -} - impl<'a> InternalKeyRef<'a> { - // Constructor for InternalKeyRef pub fn new(user_key: &'a [u8], seqno: u64, value_type: ValueType) -> Self { InternalKeyRef { user_key, @@ -187,4 +178,16 @@ impl<'a> Ord for InternalKeyRef<'a> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { 
(&self.user_key, Reverse(self.seqno)).cmp(&(&other.user_key, Reverse(other.seqno))) } -} */ +} + +impl Equivalent for InternalKeyRef<'_> { + fn equivalent(&self, other: &InternalKey) -> bool { + self.user_key == other.user_key && self.seqno == other.seqno + } +} + +impl Comparable for InternalKeyRef<'_> { + fn compare(&self, other: &InternalKey) -> std::cmp::Ordering { + (self.user_key, Reverse(self.seqno)).cmp(&(&other.user_key, Reverse(other.seqno))) + } +} diff --git a/src/memtable/mod.rs b/src/memtable/mod.rs index c54d53bd..895003a7 100644 --- a/src/memtable/mod.rs +++ b/src/memtable/mod.rs @@ -2,9 +2,12 @@ // This source code is licensed under both the Apache 2.0 and MIT License // (found in the LICENSE-* files in the repository) -use crate::key::InternalKey; +#[allow(unsafe_code)] +mod skiplist; + +use crate::key::{InternalKey, InternalKeyRef}; use crate::value::{InternalValue, SeqNo, UserValue, ValueType}; -use crossbeam_skiplist::SkipMap; +use skiplist::SkipMap; use std::ops::RangeBounds; use std::sync::atomic::AtomicU64; @@ -31,7 +34,7 @@ pub struct Memtable { impl Memtable { /// Clears the memtable. pub fn clear(&mut self) { - self.items.clear(); + self.items = SkipMap::default(); self.highest_seqno = AtomicU64::new(0); self.approximate_size .store(0, std::sync::atomic::Ordering::Release); @@ -81,7 +84,7 @@ impl Memtable { // abcdef -> 6 // abcdef -> 5 // - let lower_bound = InternalKey::new(key, seqno - 1, ValueType::Value); + let lower_bound = InternalKeyRef::new(key, seqno - 1, ValueType::Value); let mut iter = self .items @@ -126,7 +129,11 @@ impl Memtable { .fetch_add(item_size, std::sync::atomic::Ordering::AcqRel); let key = InternalKey::new(item.key.user_key, item.key.seqno, item.key.value_type); - self.items.insert(key, item.value); + // TODO(ajwerner): Decide what we want to do here. The panic is sort of + // extreme, but also seems right given the invariants. + if let Err((key, _value)) = self.items.insert(key, item.value) { + panic!("duplicate insert of {key:?} into memtable") + } self.highest_seqno .fetch_max(item.key.seqno, std::sync::atomic::Ordering::AcqRel); diff --git a/src/memtable/skiplist/arena.rs b/src/memtable/skiplist/arena.rs new file mode 100644 index 00000000..45dcb7bd --- /dev/null +++ b/src/memtable/skiplist/arena.rs @@ -0,0 +1,120 @@ +// Copyright (c) 2024-present, fjall-rs +// This source code is licensed under both the Apache 2.0 and MIT License +// (found in the LICENSE-* files in the repository) + +use std::{ + alloc::Layout, + mem::offset_of, + sync::{ + atomic::{AtomicPtr, AtomicUsize, Ordering}, + Mutex, + }, +}; + +// DEFAULT_BUFFER_SIZE needs to be at least big enough for one fullly-aligned node +// for the crate to work correctly. Anything larger than that will work. +// +// TODO: Justify this size. +const DEFAULT_BUFFER_SIZE: usize = (32 << 10) - size_of::(); + +impl Default for Arenas { + fn default() -> Self { + Self::new() + } +} + +unsafe impl Send for Arenas {} +unsafe impl Sync for Arenas {} + +pub(crate) struct Arenas { + // The current set of Arenas + arenas: Mutex>>, + // Cache of the currently open Arena. It'll be the last item in the buffers + // vec. This atomic is only ever written while holding the buffers Mutex. 
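+    // Allocation first attempts a lock-free bump (a compare-exchange on the buffer's
+    // offset) against this cached buffer; only when that buffer is full, or none has
+    // been created yet, does `alloc` take the `arenas` Mutex to install a fresh one.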
+ open_arena: AtomicPtr>, +} + +impl Arenas { + pub(crate) fn new() -> Self { + Self { + arenas: Mutex::default(), + open_arena: AtomicPtr::default(), + } + } +} + +impl Arenas { + pub(crate) fn alloc(&self, layout: Layout) -> *mut u8 { + loop { + let buffer_tail = self.open_arena.load(Ordering::Acquire); + if !buffer_tail.is_null() { + if let Some(offset) = try_alloc(buffer_tail, layout) { + return offset; + } + } + + let mut buffers = self.arenas.lock().expect("lock is poisoned"); + let buffer = buffers.last().unwrap_or(&std::ptr::null_mut()); + if *buffer != buffer_tail { + // Lost the race with somebody else. + continue; + } + + let new_buffer: Box> = Box::default(); + let new_buffer = Box::into_raw(new_buffer); + self.open_arena.store(new_buffer, Ordering::Release); + buffers.push(new_buffer); + } + } +} + +struct Buffer { + offset: AtomicUsize, + data: [u8; N], +} + +impl Default for Buffer { + fn default() -> Self { + Self { + offset: AtomicUsize::default(), + data: [0; N], + } + } +} + +impl Drop for Arenas { + fn drop(&mut self) { + let mut buffers = self.arenas.lock().expect("lock is poisoned"); + + for buffer in buffers.drain(..) { + drop(unsafe { Box::from_raw(buffer) }); + } + } +} + +fn try_alloc(buf: *mut Buffer, layout: Layout) -> Option<*mut u8> { + let mut cur_offset = unsafe { &(*buf).offset }.load(Ordering::Relaxed); + + loop { + let buf_start = unsafe { buf.byte_add(offset_of!(Buffer, data)) as *mut u8 }; + let free_start = unsafe { buf_start.byte_add(cur_offset) }; + let start_addr = unsafe { free_start.byte_add(free_start.align_offset(layout.align())) }; + let new_offset = ((start_addr as usize) + layout.size()) - (buf_start as usize); + if new_offset > N { + return None; + } + + // Note that we can get away with using relaxed ordering here because we're not + // asserting anything about the contents of the buffer. We're just trying to + // allocate a new node. + match unsafe { &(*buf).offset }.compare_exchange( + cur_offset, + new_offset, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_offset) => return Some(start_addr), + Err(offset) => cur_offset = offset, + } + } +} diff --git a/src/memtable/skiplist/mod.rs b/src/memtable/skiplist/mod.rs new file mode 100644 index 00000000..bb9378c9 --- /dev/null +++ b/src/memtable/skiplist/mod.rs @@ -0,0 +1,30 @@ +// Copyright (c) 2024-present, fjall-rs +// This source code is licensed under both the Apache 2.0 and MIT License +// (found in the LICENSE-* files in the repository) + +// This implementation was heavily inspired by: +// * https://github.com/andy-kimball/arenaskl/tree/f7010085 +// * https://github.com/crossbeam-rs/crossbeam/tree/983d56b6/crossbeam-skiplist + +//! This mod is a purpose-built concurrent skiplist intended for use +//! by the memtable. +//! +//! Due to the requirements of memtable, there are a number of notable in the +//! features it lacks: +//! - Updates +//! - Deletes +//! - Overwrites +//! +//! The main reasons for its existence are that it +//! - provides concurrent reads and inserts, and +//! - batches memory allocations +//! +//! Prior to this implementation, `crossbeam_skiplist` was used. 
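To make the contract above concrete, here is a minimal doctest-style sketch (illustrative only, not part of the patch: the test name is made up, and it assumes the `SkipMap` API defined in `skipmap.rs` below):

#[test]
fn skipmap_contract_sketch() {
    let map: SkipMap<u64, &str> = SkipMap::default();
    assert_eq!(map.insert(1, "one"), Ok(()));
    // No overwrites or deletes: a duplicate insert is rejected and the pair is
    // handed back to the caller so it is not leaked.
    assert_eq!(map.insert(1, "uno"), Err((1, "uno")));
    assert_eq!(map.insert(2, "two"), Ok(()));
    // Reads may run concurrently with inserts; iteration is in key order.
    let keys: Vec<u64> = map.range(0u64..=2).map(|e| *e.key()).collect();
    assert_eq!(keys, vec![1, 2]);
}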
+ +mod arena; +mod skipmap; + +pub use skipmap::SkipMap; + +#[cfg(test)] +mod test; diff --git a/src/memtable/skiplist/skipmap.rs b/src/memtable/skiplist/skipmap.rs new file mode 100644 index 00000000..a026992c --- /dev/null +++ b/src/memtable/skiplist/skipmap.rs @@ -0,0 +1,856 @@ +// Copyright (c) 2024-present, fjall-rs +// This source code is licensed under both the Apache 2.0 and MIT License +// (found in the LICENSE-* files in the repository) + +#![allow(unsafe_code)] + +use equivalent::Comparable; + +use super::arena::Arenas; +use std::{ + alloc::Layout, + borrow::Borrow, + hash::Hash, + marker::PhantomData, + mem::{offset_of, ManuallyDrop}, + ops::{Bound, RangeBounds}, + sync::{ + atomic::{AtomicPtr, AtomicU32, AtomicUsize, Ordering}, + LazyLock, + }, +}; + +/// A `SkipMap` is a concurrent, ordered map like a `BTreeMap` +/// but it allows for concurrent reads and writes. +/// +/// A tradeoff is that it does not allow for updates or deletions. +pub struct SkipMap { + arena: ArenasAllocator, + + head: BoundaryNode, + tail: BoundaryNode, + + seed: AtomicU32, + height: AtomicUsize, + len: AtomicUsize, +} + +impl Default for SkipMap { + fn default() -> Self { + const DEFAULT_SEED: u32 = 1; // arbitrary + Self::new(DEFAULT_SEED) + } +} + +impl SkipMap { + /// New constructs a new `[SkipMap]`. + #[warn(clippy::unwrap_used)] + pub fn new(seed: u32) -> Self { + let arena = ArenasAllocator::default(); + let head = arena.alloc(MAX_HEIGHT); + let head = NodePtr::new(head).unwrap(); + let tail = arena.alloc(MAX_HEIGHT); + let tail = NodePtr::new(tail).unwrap(); + for i in 0..MAX_HEIGHT { + head.init_next(i, tail); + tail.init_prev(i, head); + } + Self { + arena, + head: BoundaryNode::new(head), + tail: BoundaryNode::new(tail), + seed: AtomicU32::new(seed), + height: AtomicUsize::new(1), + len: AtomicUsize::new(0), + } + } + + /// Iter constructs an iterator over the complete + /// range. + pub fn iter(&self) -> Iter<'_, K, V> { + Iter::new(self) + } +} + +impl SkipMap +where + K: Ord, +{ + /// Inserts a key-value pair into the `SkipMap`. + /// + /// Returns `true` if the entry was inserted. + #[warn(clippy::unwrap_used)] + pub fn insert(&self, k: K, v: V) -> Result<(), (K, V)> { + let Some(splices) = self.seek_splices(&k) else { + return Err((k, v)); + }; + let (node, height) = self.new_node(k, v); + + #[warn(clippy::needless_range_loop)] + for level in 0..height { + #[warn(clippy::indexing_slicing)] + let mut splice = match splices[level].clone() { + Some(splice) => splice, + // This node increased the height. + None => Splice { + prev: self.head.load(), + next: self.tail.load(), + }, + }; + + loop { + let Splice { next, prev } = splice; + // +----------------+ +------------+ +----------------+ + // | prev | | nd | | next | + // | prevNextOffset |---->| | | | + // | |<----| prevOffset | | | + // | | | nextOffset |---->| | + // | | | |<----| nextPrevOffset | + // +----------------+ +------------+ +----------------+ + // + // 1. Initialize prevOffset and nextOffset to point to prev and next. + // 2. CAS prevNextOffset to repoint from next to nd. + // 3. CAS nextPrevOffset to repoint from prev to nd. + node.init_prev(level, prev); + node.init_next(level, next); + + // Check whether next has an updated link to prev. If it does not, + // that can mean one of two things: + // 1. The thread that added the next node hasn't yet had a chance + // to add the prev link (but will shortly). + // 2. Another thread has added a new node between prev and next. 
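+                // Case 1 is repaired just below: if prev still points at next, we CAS
+                // next's prev link back to prev. Case 2 shows up as a failed CAS on
+                // prev's next link, after which the splice is recomputed via
+                // find_splice_for_level and the loop retries.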
+ let next_prev = next.load_prev(level).unwrap(); + + if next_prev != prev { + // Determine whether #1 or #2 is true by checking whether prev + // is still pointing to next. As long as the atomic operations + // have at least acquire/release semantics (no need for + // sequential consistency), this works, as it is equivalent to + // the "publication safety" pattern. + let prev_next = prev.load_next(level).unwrap(); + if prev_next == next { + let _ = next.cas_prev(level, next_prev, prev); + } + } + + if prev.cas_next(level, next, node).is_ok() { + // Either we succeed, or somebody else fixed up our link above. + let _ = next.cas_prev(level, prev, node); + break; + } + + splice = match Self::find_splice_for_level(node.key(), level, prev) { + SpliceOrMatch::Splice(splice) => splice, + SpliceOrMatch::Match(_non_null) => { + if level == 0 { + // This means we encountered a race with somebody + // else to insert the same key. In that case, we + // fail on the insert but we need to make sure that + // K and V get returned to the caller so they aren't + // leaked. However, it's worth noting that in this + // scenario, we have wasted this node object. + let NodeData { key, value } = + unsafe { ManuallyDrop::take(&mut (*node.0).data) }; + return Err((key, value)); + } + + // This shouldn't be possible because we go from level 0 + // up the tower. If some other insert of the same key + // succeeded we should have found it and bailed. + panic!("concurrent insert of identical key") + } + } + } + } + + self.len.fetch_add(1, Ordering::Relaxed); + + Ok(()) + } + + /// Returns a ranged iterator over the `SkipMap`. + pub fn range(&self, range: R) -> Range<'_, K, V, Q, R> + where + K: Comparable, + R: RangeBounds, + Q: ?Sized, + { + Range { + map: self, + range, + exhausted: false, + next: None, + next_back: None, + called: 0, + _phantom: PhantomData, + } + } + + /// Returns `true` if the map is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the current number of entries in the `SkipMap`. + pub fn len(&self) -> usize { + self.len.load(Ordering::Relaxed) + } + + /// Returns the current height of the `SkipMap`. + pub fn height(&self) -> usize { + self.height.load(Ordering::Relaxed) + } + + // Search for the node that comes before the bound in the SkipMap. + #[warn(clippy::unwrap_used)] + fn find_from_node(&self, bounds: Bound<&Q>) -> NodePtr + where + K: Comparable, + Q: Comparable + Ord + ?Sized, + { + match bounds { + std::ops::Bound::Included(v) => match self.seek_for_base_splice(v) { + SpliceOrMatch::Splice(splice) => splice.prev, + SpliceOrMatch::Match(node) => { + // It is safe to unwrap here because matches can't match a boundary + // and there's always a boundary. + node.load_prev(0).unwrap() + } + }, + std::ops::Bound::Excluded(v) => match self.seek_for_base_splice(v) { + SpliceOrMatch::Splice(splice) => splice.prev, + SpliceOrMatch::Match(node) => node, + }, + std::ops::Bound::Unbounded => self.head.load(), + } + } + + // Search for the node that comes after the bound in the SkipMap. 
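+    // For an inclusive end bound an exact match must remain inside the range, so the
+    // returned sentinel is the node *after* the match; for an exclusive bound the
+    // match itself is the sentinel.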
+ #[warn(clippy::unwrap_used)] + fn find_to_node(&self, bounds: Bound<&Q>) -> NodePtr + where + K: Comparable, + Q: Comparable + Ord + ?Sized, + { + match bounds { + std::ops::Bound::Included(v) => match self.seek_for_base_splice(v) { + SpliceOrMatch::Splice(splice) => splice.next, + SpliceOrMatch::Match(node) => node.load_next(0).unwrap(), + }, + std::ops::Bound::Excluded(v) => match self.seek_for_base_splice(v) { + SpliceOrMatch::Splice(splice) => splice.next, + SpliceOrMatch::Match(node) => node, + }, + std::ops::Bound::Unbounded => self.tail.load(), + } + } + + fn new_node(&self, key: K, value: V) -> (NodePtr, usize) { + let height = self.random_height(); + let node = self.arena.alloc(height); + unsafe { (*node).data = ManuallyDrop::new(NodeData { key, value }) } + (NodePtr(node), height) + } + + fn random_height(&self) -> usize { + // Pseudorandom number generation from "Xorshift RNGs" by George Marsaglia. + // + // This particular set of operations generates 32-bit integers. See: + // https://en.wikipedia.org/wiki/Xorshift#Example_implementation + let mut num = self.seed.load(Ordering::Relaxed); + num ^= num << 13; + num ^= num >> 17; + num ^= num << 5; + self.seed.store(num, Ordering::Relaxed); + + let mut height = 1; + for &p in PROBABILITIES.iter() { + if num > p { + break; + } + height += 1; + } + + // Keep decreasing the height while it's much larger than all towers currently in the + // skip list. + let head = self.head.load(); + let tail = self.tail.load(); + while height >= 4 && head.load_next(height - 2) == Some(tail) { + height -= 1; + } + + // Track the max height to speed up lookups + let mut max_height = self.height.load(Ordering::Relaxed); + while height > max_height { + match self.height.compare_exchange( + max_height, + height, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(h) => max_height = h, + } + } + + height + } + + // Finds the splice between which this key should be placed in the SkipMap, + // or the Node with the matching key if one exists. + #[warn(clippy::unwrap_used)] + fn find_splice_for_level(key: &Q, level: usize, start: NodePtr) -> SpliceOrMatch + where + K: Comparable, + Q: Comparable + Ord + ?Sized, + { + let mut prev = start; + // We can unwrap here because we know that start must be before + // our key no matter what, and the tail node is after. + let mut next = start.load_next(level).unwrap(); + + loop { + // Assume prev.key < key. + let Some(after_next) = next.load_next(level) else { + // We know that next must be tail. + return Splice { prev, next }.into(); + }; + + match key.compare(next.key()) { + std::cmp::Ordering::Less => return Splice { next, prev }.into(), + std::cmp::Ordering::Equal => return SpliceOrMatch::Match(next), + std::cmp::Ordering::Greater => { + prev = next; + next = after_next; + } + } + } + } + + // Returns the set of splices for all the levels where a key should be + // inserted. If the key already exists in the SkipMap, None is returned. 
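+    // The search descends from the current maximum height, reusing each level's `prev`
+    // as the starting node for the scan one level down, which keeps the expected cost
+    // logarithmic in the number of entries.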
+ #[warn(clippy::indexing_slicing)] + fn seek_splices(&self, key: &K) -> Option> { + let mut splices = Splices::default(); + let mut level = self.height() - 1; + let mut prev = self.head.load(); + + loop { + match Self::find_splice_for_level(key.borrow(), level, prev) { + SpliceOrMatch::Splice(splice) => { + prev = splice.prev; + splices[level] = Some(splice) + } + SpliceOrMatch::Match(_match) => break None, + } + + if level == 0 { + break Some(splices); + } + + level -= 1; + } + } + + fn seek_for_base_splice(&self, key: &Q) -> SpliceOrMatch + where + K: Comparable, + Q: Comparable + Ord + ?Sized, + { + let mut level = self.height() - 1; + let mut prev = self.head.load(); + + loop { + match Self::find_splice_for_level(key, level, prev) { + n @ SpliceOrMatch::Match(_) => return n, + s @ SpliceOrMatch::Splice(_) if level == 0 => return s, + SpliceOrMatch::Splice(s) => { + prev = s.prev; + level -= 1; + } + } + } + } +} + +// It is important to run the drop action associated with the data +// inserted into the SkipMap in order to not leak memory. +// +// This implementation is somewhat unfortunate in that it's going to +// bounce around the SkipMap in sorted order. +// +// TODO: Perhaps a better design would be to keep nodes densely in +// the arenas so that it was possible to iterate through the initialized +// nodes without needing to traverse the links when dropping for better +// memory locality. A downside there is that we'd need to keep fixed-sized +// nodes. Perhaps a reasonable solution there might be to have only towers +// taller than 1 out-of-line and then we could iterate all the nodes more +// cheaply. +impl Drop for SkipMap { + fn drop(&mut self) { + if std::mem::needs_drop::() || std::mem::needs_drop::() { + self.iter() + .for_each(|entry| unsafe { ManuallyDrop::drop(&mut (*entry.node.0).data) }); + } + } +} + +const MAX_HEIGHT: usize = 20; + +// Precompute the value thresholds for given node heights for all levels other +// than the first level, where all nodes will have links. +static PROBABILITIES: LazyLock<[u32; MAX_HEIGHT - 1]> = LazyLock::new(|| { + let mut probabilities = [0u32; MAX_HEIGHT - 1]; + const P_VALUE: f64 = 1f64 / std::f64::consts::E; + let mut p = 1f64; + + for i in 0..MAX_HEIGHT { + // NOTE: i is >= 1 + #[allow(clippy::indexing_slicing)] + if i > 0 { + probabilities[i - 1] = ((u32::MAX as f64) * p) as u32; + } + p *= P_VALUE; + } + + probabilities +}); + +#[repr(C)] +struct Node { + data: ManuallyDrop>, + // Note that this is a lie! Sometimes this array is shorter than MAX_HEIGHT. + // and will instead point to garbage. That's okay because we'll use other + // bookkeeping invariants to ensure that we never actually access the garbage. + tower: [Links; MAX_HEIGHT], +} + +struct NodeData { + key: K, + value: V, +} + +// The forward and backward pointers in the tower for nodes. +#[repr(C)] +struct Links { + next: NodeCell, + prev: NodeCell, +} + +// BoundaryNodePtr points to either the head or tail node. It is never modified +// after it is created, so it can use Ordering::Relaxed without concern. It's +// only using atomics at all because it makes the object Send and Sync and they +// don't really have cost given there won't ever be contention. 
+struct BoundaryNode(AtomicPtr>); + +impl BoundaryNode { + fn load(&self) -> NodePtr { + let Self(ptr) = self; + NodePtr(ptr.load(Ordering::Relaxed)) + } + + fn new(node: NodePtr) -> Self { + Self(AtomicPtr::new(node.0)) + } +} + +struct NodePtr(*mut Node); + +impl Clone for NodePtr { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for NodePtr {} + +impl Eq for NodePtr {} + +impl PartialEq for NodePtr { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Hash for NodePtr { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl std::fmt::Debug for NodePtr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl NodePtr { + fn new(ptr: *mut Node) -> Option { + (!ptr.is_null()).then_some(Self(ptr)) + } + + fn init_next(self, level: usize, next: Self) { + self.links(level).next.store(next); + } + + fn init_prev(self, level: usize, prev: Self) { + self.links(level).prev.store(prev); + } + + fn cas_next(self, level: usize, current: Self, new: Self) -> Result<(), Option> { + self.links(level).next.cas(current, new) + } + + fn cas_prev(self, level: usize, current: Self, new: Self) -> Result<(), Option> { + self.links(level).prev.cas(current, new) + } + + fn load_next(self, level: usize) -> Option { + self.links(level).next.load() + } + + fn load_prev(self, level: usize) -> Option { + self.links(level).prev.load() + } + + #[warn(clippy::indexing_slicing)] + fn links(&self, level: usize) -> &'_ Links { + let Self(ptr) = self; + unsafe { &(**ptr).tower[level] } + } + + fn key(&self) -> &K { + let Self(ptr) = self; + &unsafe { &(**ptr) }.data.key + } +} + +#[repr(transparent)] +struct NodeCell(AtomicPtr>); + +impl NodeCell { + fn store(&self, value: NodePtr) { + let Self(ptr) = self; + ptr.store(value.0, Ordering::Release); + } + + fn cas(&self, current: NodePtr, new: NodePtr) -> Result<(), Option>> { + let Self(ptr) = self; + match ptr.compare_exchange(current.0, new.0, Ordering::AcqRel, Ordering::Acquire) { + Ok(_) => Ok(()), + Err(new) => Err(NodePtr::new(new)), + } + } + + fn load(&self) -> Option> { + let Self(ptr) = self; + NodePtr::new(ptr.load(Ordering::Acquire)) + } +} + +enum SpliceOrMatch { + Splice(Splice), + Match(NodePtr), +} + +impl From> for SpliceOrMatch { + fn from(value: Splice) -> Self { + Self::Splice(value) + } +} + +type Splices = [Option>; MAX_HEIGHT]; + +struct Splice { + prev: NodePtr, + next: NodePtr, +} + +impl Clone for Splice { + fn clone(&self) -> Self { + let &Self { prev, next } = self; + Self { prev, next } + } +} + +// Iter is an Iterator over all elements of a SkipMap. +pub struct Iter<'map, K, V> { + // Keeps the map alive. 
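+    // (Entries hand out &'map references to keys and values stored in the map's arenas,
+    // so the iterator has to borrow the map for 'map to keep that storage valid.)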
+ _map: &'map SkipMap, + exhausted: bool, + before: NodePtr, + after: NodePtr, +} + +impl<'map, K, V> Iter<'map, K, V> { + fn new(map: &'map SkipMap) -> Self { + Self { + _map: map, + exhausted: false, + before: map.head.load(), + after: map.tail.load(), + } + } +} + +impl<'map, K, V> Iterator for Iter<'map, K, V> { + type Item = Entry<'map, K, V>; + + #[warn(clippy::unwrap_used)] + fn next(&mut self) -> Option { + if self.exhausted { + return None; + } + + let next = self.before.load_next(0).unwrap(); + if next == self.after { + self.exhausted = true; + return None; + } + + self.before = next; + Some(Entry::new(next)) + } +} + +impl<'map, K, V> DoubleEndedIterator for Iter<'map, K, V> { + #[warn(clippy::unwrap_used)] + fn next_back(&mut self) -> Option { + if self.exhausted { + return None; + } + + let next = self.after.load_prev(0).unwrap(); + if next == self.before { + self.exhausted = true; + return None; + } + + self.after = next; + Some(Entry::new(next)) + } +} + +/// Range is an Iterator over a `SkipMap` for a range. +#[allow(clippy::struct_field_names)] +pub struct Range<'m, K, V, Q: ?Sized, R> { + map: &'m SkipMap, + range: R, + exhausted: bool, + next: Option>, + next_back: Option>, + called: usize, + _phantom: PhantomData, +} + +pub struct Entry<'m, K, V> { + node: NodePtr, + _phantom: PhantomData<&'m ()>, +} + +impl<'m, K, V> Entry<'m, K, V> { + fn new(node: NodePtr) -> Self { + Self { + node, + _phantom: PhantomData, + } + } + + pub fn key(&self) -> &'m K { + // Transmute because we're lying about the lifetime. + unsafe { core::mem::transmute(&(&*self.node.0).data.key) } + } + + pub fn value(&self) -> &'m V { + // Transmute because we're lying about the lifetime. + unsafe { core::mem::transmute(&(&*self.node.0).data.value) } + } +} + +impl<'m, K, V, Q: ?Sized, R> Range<'m, K, V, Q, R> { + fn exhaust(&mut self) { + self.exhausted = true; + self.next = None; + self.next_back = None; + } +} + +impl<'m, K, V, Q, R> Iterator for Range<'m, K, V, Q, R> +where + K: Ord + Comparable, + R: RangeBounds, + Q: Comparable + Ord + ?Sized, +{ + type Item = Entry<'m, K, V>; + + #[allow(unsafe_code)] + fn next(&mut self) -> Option { + if self.exhausted { + return None; + } + + self.called += 1; + + let next = if let Some(next) = self.next { + next + } else { + let before = self.map.find_from_node(self.range.start_bound()); + match before.load_next(0) { + Some(next) => next, + None => { + self.exhaust(); + return None; + } + } + }; + + // If after_next is None, then we're at the tail and are done. + let Some(after_next) = next.load_next(0) else { + self.exhaust(); + return None; + }; + + // If we're not at the tail, then the key is valid. + if match self.range.end_bound() { + Bound::Included(bound) => next.key().compare(bound).is_gt(), + Bound::Excluded(bound) => next.key().compare(bound).is_ge(), + Bound::Unbounded => false, + } { + self.exhaust(); + return None; + } + + // Make sure we haven't moved past reverse iteration. 
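+        // If the entry we are about to yield is exactly the one the reverse cursor
+        // would yield next, the two cursors have met: return it and mark the range
+        // exhausted.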
+ if self.next_back.is_none_or(|next_back| next_back != next) { + self.next = Some(after_next); + } else { + self.exhaust(); + }; + + Some(Entry::new(next)) + } +} + +impl<'m, K, V, Q, R> DoubleEndedIterator for Range<'m, K, V, Q, R> +where + K: Ord + Comparable, + R: RangeBounds, + Q: Comparable + Ord + ?Sized, +{ + fn next_back(&mut self) -> Option { + if self.exhausted { + return None; + } + + let next_back = if let Some(next_back) = self.next_back { + next_back + } else { + let after = self.map.find_to_node(self.range.end_bound()); + match after.load_prev(0) { + Some(next_back) => next_back, + None => { + self.exhaust(); + return None; + } + } + }; + + let Some(before_next_back) = next_back.load_prev(0) else { + self.exhaust(); + return None; + }; + + if match self.range.start_bound() { + Bound::Included(bound) => next_back.key().compare(bound).is_lt(), + Bound::Excluded(bound) => next_back.key().compare(bound).is_le(), + Bound::Unbounded => false, + } { + self.exhaust(); + return None; + } + + if self.next.is_none_or(|next| next_back != next) { + self.next_back = Some(before_next_back); + } else { + self.exhaust(); + }; + + Some(Entry::new(next_back)) + } +} + +#[cfg(test)] +impl SkipMap +where + K: Ord, +{ + #[allow(clippy::needless_pass_by_ref_mut)] + pub(crate) fn check_integrity(&mut self) { + use std::collections::HashSet; + + // We want to check that there are no cycles, that the forward and backwards + // directions have the same chains at all levels, and that the values are + // ordered. + let head_nodes = { + let mut cur = Some(self.head.load()); + let mut head_forward_nodes = HashSet::new(); + let mut head_nodes = Vec::new(); + + while let Some(node) = cur { + head_nodes.push(node); + assert!(head_forward_nodes.insert(node), "head"); + cur = node.load_next(0); + } + + head_nodes + }; + + let mut tail_nodes = { + let mut cur = Some(self.tail.load()); + let mut tail_backward_nodes = HashSet::new(); + let mut tail_nodes = Vec::new(); + + while let Some(node) = cur { + tail_nodes.push(node); + assert!(tail_backward_nodes.insert(node), "tail"); + cur = node.load_prev(0); + } + + tail_nodes + }; + + tail_nodes.reverse(); + + assert_eq!(head_nodes, tail_nodes); + } +} + +struct ArenasAllocator { + arenas: Arenas, + _phantom: PhantomData, +} + +impl Default for ArenasAllocator { + fn default() -> Self { + Self { + arenas: Arenas::default(), + _phantom: PhantomData, + } + } +} + +impl ArenasAllocator { + const ALIGNMENT: usize = align_of::>(); + const TOWER_OFFSET: usize = offset_of!(Node, tower); + + fn alloc(&self, height: usize) -> *mut Node { + let layout = unsafe { + Layout::from_size_align_unchecked( + Self::TOWER_OFFSET + (height * size_of::>()), + Self::ALIGNMENT, + ) + }; + + self.arenas.alloc(layout).cast::>() + } +} diff --git a/src/memtable/skiplist/test.rs b/src/memtable/skiplist/test.rs new file mode 100644 index 00000000..f88c26e3 --- /dev/null +++ b/src/memtable/skiplist/test.rs @@ -0,0 +1,280 @@ +// Copyright (c) 2024-present, fjall-rs +// This source code is licensed under both the Apache 2.0 and MIT License +// (found in the LICENSE-* files in the repository) + +use std::{ + collections::BTreeMap, + fmt::{Debug, Write}, + num::NonZero, + ops::RangeBounds, + sync::Barrier, +}; + +use super::*; +use quickcheck::{Arbitrary, Gen}; +use rand::{rng, RngCore}; + +#[test] +fn skip_map_basic() { + let v = SkipMap::::new(rng().next_u32()); + assert_eq!(v.insert(1, 1), Ok(())); + assert_eq!(v.len(), 1); + assert_eq!(v.insert(1, 2), Err((1, 2))); + assert_eq!(v.len(), 1); + 
assert_eq!(v.insert(2, 2), Ok(())); + assert_eq!(v.len(), 2); + assert_eq!(v.insert(2, 1), Err((2, 1))); + let got: Vec<_> = v.iter().map(|e| (*e.key(), *e.value())).collect(); + assert_eq!(got, vec![(1, 1), (2, 2)]); + let got_rev: Vec<_> = v.iter().rev().map(|e| (*e.key(), *e.value())).collect(); + assert_eq!(got_rev, vec![(2, 2), (1, 1)]); +} + +#[test] +#[allow(clippy::unwrap_used)] +fn skip_map_basic_strings() { + let v = SkipMap::::new(rng().next_u32()); + let mut foo = String::new(); + foo.write_str("foo").unwrap(); + assert_eq!(v.insert(foo, 1), Ok(())); + assert_eq!(v.len(), 1); + assert_eq!(v.insert("foo".into(), 2), Err(("foo".into(), 2))); + assert_eq!(v.len(), 1); + assert_eq!(v.insert("bar".into(), 2), Ok(())); + assert_eq!(v.len(), 2); + assert_eq!(v.insert("bar".into(), 1), Err(("bar".into(), 1))); + let got: Vec<_> = v.iter().map(|e| (e.key().clone(), *e.value())).collect(); + assert_eq!(got, vec![("bar".into(), 2), ("foo".into(), 1)]); +} + +#[derive(Clone, Debug)] +struct TestOperation { + key: K, + value: V, +} + +impl Arbitrary for TestOperation +where + K: Arbitrary, + V: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + key: K::arbitrary(g), + value: V::arbitrary(g), + } + } +} + +#[derive(Debug, Clone)] +struct TestOperations { + seed: u32, + threads: usize, + ops: Vec>, +} + +impl Arbitrary for TestOperations +where + K: Arbitrary, + V: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + let max_threads = std::thread::available_parallelism() + .map(NonZero::get) + .unwrap_or(64) + * 16; + Self { + seed: u32::arbitrary(g), + threads: 1usize.max(usize::arbitrary(g) % max_threads), + ops: > as Arbitrary>::arbitrary(g), + } + } +} + +fn prop(operations: TestOperations) -> bool +where + K: Arbitrary + Ord + Eq + Debug + Send + Sync + Clone, + V: Arbitrary + Eq + Debug + Send + Sync + Clone, +{ + #[cfg(not(miri))] + const TRACK_OUTCOMES: bool = true; + #[cfg(miri)] + const TRACK_OUTCOMES: bool = false; + + let mut skipmap = SkipMap::new(operations.seed); + let barrier = Barrier::new(operations.threads); + + let outcomes = std::thread::scope(|scope| { + let (mut ops, mut threads_to_launch) = (operations.ops.as_slice(), operations.threads); + let mut thread_outcomes = Vec::new(); + + while threads_to_launch > 0 { + let items = ops.len() / threads_to_launch; + let (subslice, remaining) = ops.split_at(items); + ops = remaining; + threads_to_launch -= 1; + let skipmap = &skipmap; + let barrier = &barrier; + + let spawned = scope.spawn(move || { + barrier.wait(); + let mut outcomes = Vec::new(); + for op in subslice { + outcomes.push(skipmap.insert(op.key.clone(), op.value.clone()).is_ok()); + } + outcomes + }); + + if TRACK_OUTCOMES { + thread_outcomes.push(spawned); + } + } + + thread_outcomes + .into_iter() + .flat_map(|v| v.join().unwrap()) + .collect::>() + }); + + #[cfg(miri)] + if true { + return true; + } + + let successful_ops = operations + .ops + .into_iter() + .zip(outcomes) + .filter_map(|(op, outcome)| outcome.then_some(op)) + .collect::>(); + + skipmap.check_integrity(); + + verify_ranges(&skipmap, &successful_ops); + + let skipmap_items: Vec<_> = skipmap + .iter() + .map(|e| (e.key().clone(), e.value().clone())) + .collect(); + let skipmap_items_rev: Vec<_> = skipmap + .iter() + .rev() + .map(|e| (e.key().clone(), e.value().clone())) + .collect(); + + let mut skipmap_items_rev_rev = skipmap_items_rev.clone(); + skipmap_items_rev_rev.reverse(); + + assert_eq!(successful_ops.len(), skipmap.len(), "len"); + assert_eq!(skipmap_items.len(), 
skipmap.len(), "items"); + assert_eq!(skipmap_items.len(), skipmap_items_rev.len(), "rev items"); + assert_eq!( + skipmap_items, skipmap_items_rev_rev, + "Forward iteration should match\n{skipmap_items:#?}\n{skipmap_items_rev_rev:#?}", + ); + + true +} + +#[test] +fn test_quickcheck_strings() { + quickcheck::quickcheck(prop as fn(TestOperations) -> bool); +} + +#[test] +fn test_quickcheck_ints() { + quickcheck::quickcheck(prop as fn(TestOperations) -> bool); +} + +#[allow(clippy::indexing_slicing)] +fn verify_ranges(skipmap: &SkipMap, successful_ops: &Vec>) +where + K: Ord + Eq + Debug + Clone, + V: Eq + Debug + Clone, +{ + let mut successful_keys_sorted = successful_ops + .iter() + .map(|op| op.key.clone()) + .collect::>(); + successful_keys_sorted.sort(); + + let btree = successful_ops + .iter() + .map(|TestOperation { key, value }| (key.clone(), value.clone())) + .collect::>(); + + for _ in 0..10 { + if successful_ops.is_empty() { + break; + } + let (a, b) = ( + rng().next_u32() as usize % successful_ops.len(), + rng().next_u32() as usize % successful_ops.len(), + ); + + let (start, end) = (a.min(b), a.max(b)); + + fn assert_range_eq + Clone + std::fmt::Debug>( + a: &BTreeMap, + b: &SkipMap, + bounds: B, + ) where + K: Ord + Eq + Debug + Clone, + V: Eq + Debug + Clone, + { + { + let ra = a + .range(bounds.clone()) + .map(|(a, b)| (a.clone(), b.clone())) + .collect::>(); + + let rb = b + .range(bounds.clone()) + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect::>(); + + assert_eq!( + ra, + rb, + "{} {:?} forward: {:#?} != {:#?}", + std::any::type_name::(), + bounds, + ra, + rb + ); + } + + { + let ra = a + .range(bounds.clone()) + .rev() + .map(|(a, b)| (a.clone(), b.clone())) + .collect::>(); + + let rb = b + .range(bounds.clone()) + .rev() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect::>(); + + assert_eq!( + ra, + rb, + "{} {:?} backwards: {:#?} != {:#?}", + std::any::type_name::(), + bounds, + ra, + rb + ); + } + } + + let (start, end) = (&successful_keys_sorted[start], &successful_keys_sorted[end]); + assert_range_eq(&btree, skipmap, ..); + assert_range_eq(&btree, skipmap, ..end); + assert_range_eq(&btree, skipmap, ..=end); + assert_range_eq(&btree, skipmap, start..); + assert_range_eq(&btree, skipmap, start..end); + assert_range_eq(&btree, skipmap, start..=end); + } +}
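As a companion to the property test above, a minimal hand-written illustration of the concurrent-insert claim (not part of the patch: the test name and constants are made up, and it reuses the same `std::thread::scope` pattern the prop test relies on):

#[test]
fn skipmap_concurrent_insert_sketch() {
    // Four writers insert disjoint key ranges into one shared map.
    let map = SkipMap::<u64, u64>::new(42);
    std::thread::scope(|scope| {
        for t in 0..4u64 {
            let map = &map;
            scope.spawn(move || {
                for i in 0..1_000 {
                    // Keys are disjoint per thread, so every insert succeeds.
                    map.insert(t * 1_000 + i, i).unwrap();
                }
            });
        }
    });
    assert_eq!(map.len(), 4_000);
    // Iteration observes every key, in sorted order.
    let keys: Vec<u64> = map.iter().map(|e| *e.key()).collect();
    assert_eq!(keys, (0..4_000).collect::<Vec<u64>>());
}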