From 1f368433cf63a56a3f444a2f3b1a241df87ea804 Mon Sep 17 00:00:00 2001 From: macronova Date: Mon, 27 Oct 2025 15:03:49 -0700 Subject: [PATCH 1/2] [BUG] Set minimum element size --- src/hnsw.rs | 258 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 200 insertions(+), 58 deletions(-) diff --git a/src/hnsw.rs b/src/hnsw.rs index 64efd1e6..8ef1e4be 100644 --- a/src/hnsw.rs +++ b/src/hnsw.rs @@ -5,6 +5,8 @@ use std::{ }; use thiserror::Error; +const MIN_ELEMENT_SIZE: usize = 100; + // https://doc.rust-lang.org/nomicon/ffi.html#representing-opaque-structs #[repr(C)] struct HnswIndexPtrFFI { @@ -83,19 +85,43 @@ extern "C" { fn resize_index(index: *const HnswIndexPtrFFI, new_size: usize); fn get_last_error(index: *const HnswIndexPtrFFI) -> *const c_char; - // Functions for Rust-managed buffer allocation fn get_required_header_size(index: *const HnswIndexPtrFFI) -> usize; fn get_required_data_level0_size(index: *const HnswIndexPtrFFI) -> usize; fn get_required_length_size(index: *const HnswIndexPtrFFI) -> usize; fn get_required_link_list_size(index: *const HnswIndexPtrFFI) -> usize; - fn create_mutable_hnsw_data(header_buffer: *mut c_uchar, header_size: usize, data_level0_buffer: *mut c_uchar, data_level0_size: usize, length_buffer: *mut c_uchar, length_size: usize, link_list_buffer: *mut c_uchar, link_list_size: usize) -> *const DataFFI; - fn create_hnsw_data_view(header_buffer: *const c_uchar, header_size: usize, data_level0_buffer: *const c_uchar, data_level0_size: usize, length_buffer: *const c_uchar, length_size: usize, link_list_buffer: *const c_uchar, link_list_size: usize) -> *const DataViewFFI; - fn hnsw_data_to_view(hnsw_data :*const DataFFI) -> *const DataViewFFI; - - fn serialize_index_to_hnsw_data(index: *const HnswIndexPtrFFI, hnsw_data: *const DataFFI) -> bool; - fn load_index_from_hnsw_data(index: *const HnswIndexPtrFFI, buffers: *const DataViewFFI, max_elements: usize); + fn create_mutable_hnsw_data( + header_buffer: *mut c_uchar, + header_size: usize, + data_level0_buffer: *mut c_uchar, + data_level0_size: usize, + length_buffer: *mut c_uchar, + length_size: usize, + link_list_buffer: *mut c_uchar, + link_list_size: usize, + ) -> *const DataFFI; + fn create_hnsw_data_view( + header_buffer: *const c_uchar, + header_size: usize, + data_level0_buffer: *const c_uchar, + data_level0_size: usize, + length_buffer: *const c_uchar, + length_size: usize, + link_list_buffer: *const c_uchar, + link_list_size: usize, + ) -> *const DataViewFFI; + fn hnsw_data_to_view(hnsw_data: *const DataFFI) -> *const DataViewFFI; + + fn serialize_index_to_hnsw_data( + index: *const HnswIndexPtrFFI, + hnsw_data: *const DataFFI, + ) -> bool; + fn load_index_from_hnsw_data( + index: *const HnswIndexPtrFFI, + buffers: *const DataViewFFI, + max_elements: usize, + ); } #[derive(Error, Debug)] @@ -106,7 +132,7 @@ pub enum HnswError { #[error(transparent)] ErrorStringRead(#[from] Utf8Error), #[error("HnswDataError: `{0}`")] - ErrorHnswData(String) + ErrorHnswData(String), } #[derive(Clone, Copy, Debug, PartialEq)] @@ -191,7 +217,8 @@ impl HnswIndex { unsafe { init_index( ffi_ptr, - config.max_elements, + // NOTE(sicheng): Avoid init with zero element + config.max_elements.max(MIN_ELEMENT_SIZE), config.m, config.ef_construction, config.random_seed, @@ -228,7 +255,8 @@ impl HnswIndex { let path = CString::new(path).map_err(|e| HnswInitError::InvalidPath(e.to_string()))?; unsafe { - load_index(ffi_ptr, path.as_ptr(), true, true, 0); + // NOTE(sicheng): Avoid load with zero element + load_index(ffi_ptr, path.as_ptr(), true, true, MIN_ELEMENT_SIZE); } read_and_return_hnsw_error(ffi_ptr)?; @@ -264,7 +292,8 @@ impl HnswIndex { } pub fn resize(&mut self, new_size: usize) -> Result<(), HnswError> { - unsafe { resize_index(self.ffi_ptr, new_size) } + // NOTE(sicheng): Avoid resize to zero element + unsafe { resize_index(self.ffi_ptr, new_size.max(MIN_ELEMENT_SIZE)) } read_and_return_hnsw_error(self.ffi_ptr) } @@ -373,7 +402,10 @@ impl HnswIndex { } /// Load index from memory buffers - pub fn load_from_hnsw_data(config: HnswIndexMemoryLoadConfig, load_data: &HnswData) -> Result { + pub fn load_from_hnsw_data( + config: HnswIndexMemoryLoadConfig, + load_data: &HnswData, + ) -> Result { let distance_function_string: String = config.distance_function.into(); let space_name = CString::new(distance_function_string) .map_err(|e| HnswInitError::InvalidDistanceFunction(e.to_string()))?; @@ -382,7 +414,8 @@ impl HnswIndex { read_and_return_hnsw_error(ffi_ptr)?; unsafe { - load_index_from_hnsw_data(ffi_ptr, load_data.ffi_ptr, 0); + // NOTE(sicheng): Avoid load with zero element + load_index_from_hnsw_data(ffi_ptr, load_data.ffi_ptr, MIN_ELEMENT_SIZE); } read_and_return_hnsw_error(ffi_ptr)?; @@ -412,33 +445,42 @@ unsafe impl Send for HnswData {} impl HnswData { /// Create new HnswData with Rust-managed buffers from index serialization - fn new_from_index( - index_ptr: *const HnswIndexPtrFFI, - ) -> Result { + fn new_from_index(index_ptr: *const HnswIndexPtrFFI) -> Result { // Allocate mutable buffers in Rust let header_size = unsafe { get_required_header_size(index_ptr) }; let data_level0_size = unsafe { get_required_data_level0_size(index_ptr) }; let length_size = unsafe { get_required_length_size(index_ptr) }; let link_list_size = unsafe { get_required_link_list_size(index_ptr) }; - + let mut header_buffer = vec![0u8; header_size]; let mut data_level0_buffer = vec![0u8; data_level0_size]; let mut length_buffer = vec![0u8; length_size]; let mut link_list_buffer = vec![0u8; link_list_size]; - let ffi_ptr = unsafe { create_mutable_hnsw_data(header_buffer.as_mut_ptr(), header_size, data_level0_buffer.as_mut_ptr(), data_level0_size, length_buffer.as_mut_ptr(), length_size, link_list_buffer.as_mut_ptr(), link_list_size) }; - + let ffi_ptr = unsafe { + create_mutable_hnsw_data( + header_buffer.as_mut_ptr(), + header_size, + data_level0_buffer.as_mut_ptr(), + data_level0_size, + length_buffer.as_mut_ptr(), + length_size, + link_list_buffer.as_mut_ptr(), + link_list_size, + ) + }; + read_and_return_hnsw_error(index_ptr)?; - + // Call C++ to serialize directly into our Rust-allocated buffers - let success = unsafe { - serialize_index_to_hnsw_data(index_ptr, ffi_ptr) - }; - + let success = unsafe { serialize_index_to_hnsw_data(index_ptr, ffi_ptr) }; + if !success { - return Err(HnswError::FFIException("Failed to serialize to external buffers".to_string())); + return Err(HnswError::FFIException( + "Failed to serialize to external buffers".to_string(), + )); } - + Ok(HnswData { ffi_ptr: unsafe { hnsw_data_to_view(ffi_ptr) }, _marker: std::marker::PhantomData, @@ -457,11 +499,24 @@ impl HnswData { link_list_buffer: std::sync::Arc>, ) -> Result { // Create HnswData FFI structure and set the buffers - let ffi_ptr = unsafe { create_hnsw_data_view(header_buffer.as_ptr(), header_buffer.len(), data_level0_buffer.as_ptr(), data_level0_buffer.len(), length_buffer.as_ptr(), length_buffer.len(), link_list_buffer.as_ptr(), link_list_buffer.len()) }; + let ffi_ptr = unsafe { + create_hnsw_data_view( + header_buffer.as_ptr(), + header_buffer.len(), + data_level0_buffer.as_ptr(), + data_level0_buffer.len(), + length_buffer.as_ptr(), + length_buffer.len(), + link_list_buffer.as_ptr(), + link_list_buffer.len(), + ) + }; if ffi_ptr.is_null() { - return Err(HnswError::ErrorHnswData("Failed to create HnswData structure".to_string())); + return Err(HnswError::ErrorHnswData( + "Failed to create HnswData structure".to_string(), + )); } - + Ok(HnswData { ffi_ptr, _marker: std::marker::PhantomData, @@ -525,7 +580,7 @@ impl HnswDataBuilder { self.data_level0_buffer = Some(buffer); self } - + pub fn length_buffer(mut self, buffer: std::sync::Arc>) -> Self { self.length_buffer = Some(buffer); self @@ -539,27 +594,41 @@ impl HnswDataBuilder { /// Validate that all required buffers are set and non-empty fn validate(&self) -> Result<(), HnswError> { if self.header_buffer.is_none() { - return Err(HnswError::ErrorHnswData("Header buffer is required".to_string())); + return Err(HnswError::ErrorHnswData( + "Header buffer is required".to_string(), + )); } if self.data_level0_buffer.is_none() { - return Err(HnswError::ErrorHnswData("Data level0 buffer is required".to_string())); + return Err(HnswError::ErrorHnswData( + "Data level0 buffer is required".to_string(), + )); } if self.length_buffer.is_none() { - return Err(HnswError::ErrorHnswData("Length buffer is required".to_string())); + return Err(HnswError::ErrorHnswData( + "Length buffer is required".to_string(), + )); } if self.link_list_buffer.is_none() { - return Err(HnswError::ErrorHnswData("Link list buffer is required".to_string())); + return Err(HnswError::ErrorHnswData( + "Link list buffer is required".to_string(), + )); } // Check that buffers are not empty if self.header_buffer.as_ref().unwrap().is_empty() { - return Err(HnswError::ErrorHnswData("Header buffer cannot be empty".to_string())); + return Err(HnswError::ErrorHnswData( + "Header buffer cannot be empty".to_string(), + )); } if self.data_level0_buffer.as_ref().unwrap().is_empty() { - return Err(HnswError::ErrorHnswData("Data level0 buffer cannot be empty".to_string())); + return Err(HnswError::ErrorHnswData( + "Data level0 buffer cannot be empty".to_string(), + )); } if self.length_buffer.as_ref().unwrap().is_empty() { - return Err(HnswError::ErrorHnswData("Length buffer cannot be empty".to_string())); + return Err(HnswError::ErrorHnswData( + "Length buffer cannot be empty".to_string(), + )); } // It's ok for link list buffer to be empty in an empty index. @@ -945,11 +1014,14 @@ pub mod test { .link_list_buffer(src_buffers[3].clone()) .build(); - let index = HnswIndex::load_from_hnsw_data(HnswIndexMemoryLoadConfig { - distance_function, - dimensionality: d as i32, - ef_search: 100, - }, &hnsw_data.expect("Failed to create HnswData")); + let index = HnswIndex::load_from_hnsw_data( + HnswIndexMemoryLoadConfig { + distance_function, + dimensionality: d as i32, + ef_search: 100, + }, + &hnsw_data.expect("Failed to create HnswData"), + ); let index = match index { Err(e) => panic!("Error loading index: {}", e), @@ -978,7 +1050,7 @@ pub mod test { let distance_function = HnswDistanceFunction::Euclidean; let tmp_dir = tempdir().unwrap(); let persist_path = tmp_dir.path(); - + // Create and populate original index let original_index = HnswIndex::init(HnswIndexInitConfig { distance_function, @@ -989,7 +1061,8 @@ pub mod test { ef_search: 100, random_seed: 42, persist_path: Some(persist_path.to_path_buf()), - }).expect("Failed to create original index"); + }) + .expect("Failed to create original index"); let data: Vec = generate_random_data(n, d); let ids: Vec = (0..n).collect(); @@ -997,7 +1070,9 @@ pub mod test { // Add data to original index for i in 0..n { let data_slice = &data[i * d..(i + 1) * d]; - original_index.add(ids[i], data_slice).expect("Should not error"); + original_index + .add(ids[i], data_slice) + .expect("Should not error"); } // Verify original index has correct data @@ -1005,14 +1080,27 @@ pub mod test { index_data_same(&original_index, &ids, &data, d); // Serialize to memory buffers - let hnsw_data = original_index.serialize_index_to_hnsw_data() + let hnsw_data = original_index + .serialize_index_to_hnsw_data() .expect("Failed to serialize to memory buffers"); // Verify buffers are not empty - assert!(!hnsw_data.header_buffer().is_empty(), "Header buffer should not be empty"); - assert!(!hnsw_data.data_level0_buffer().is_empty(), "Data level0 buffer should not be empty"); - assert!(!hnsw_data.length_buffer().is_empty(), "Length buffer should not be empty"); - assert!(!hnsw_data.link_list_buffer().is_empty(), "Link list buffer should not be empty"); + assert!( + !hnsw_data.header_buffer().is_empty(), + "Header buffer should not be empty" + ); + assert!( + !hnsw_data.data_level0_buffer().is_empty(), + "Data level0 buffer should not be empty" + ); + assert!( + !hnsw_data.length_buffer().is_empty(), + "Length buffer should not be empty" + ); + assert!( + !hnsw_data.link_list_buffer().is_empty(), + "Link list buffer should not be empty" + ); // Create new index from memory buffers let loaded_index = HnswIndex::load_from_hnsw_data( @@ -1022,7 +1110,8 @@ pub mod test { ef_search: 100, }, &hnsw_data, - ).expect("Failed to load from memory buffers"); + ) + .expect("Failed to load from memory buffers"); // Verify loaded index has same data assert_eq!(loaded_index.len(), n); @@ -1031,17 +1120,25 @@ pub mod test { // Test querying both indices to ensure they behave the same let query_vector = &data[0..d]; // Use first vector as query let k = 5; - - let (original_ids, original_distances) = original_index.query(query_vector, k, &[], &[]) + + let (original_ids, original_distances) = original_index + .query(query_vector, k, &[], &[]) .expect("Query should not error"); - let (loaded_ids, loaded_distances) = loaded_index.query(query_vector, k, &[], &[]) + let (loaded_ids, loaded_distances) = loaded_index + .query(query_vector, k, &[], &[]) .expect("Query should not error"); // Results should be identical - assert_eq!(original_ids, loaded_ids, "Query results should be identical"); + assert_eq!( + original_ids, loaded_ids, + "Query results should be identical" + ); assert_eq!(original_distances.len(), loaded_distances.len()); for (orig_dist, loaded_dist) in original_distances.iter().zip(loaded_distances.iter()) { - assert!((orig_dist - loaded_dist).abs() < EPS, "Distances should be nearly identical"); + assert!( + (orig_dist - loaded_dist).abs() < EPS, + "Distances should be nearly identical" + ); } } @@ -1132,7 +1229,7 @@ pub mod test { #[test] fn it_can_catch_error() { - let n = 10; + let n = 1000; let d: usize = 960; let distance_function = HnswDistanceFunction::Euclidean; let tmp_dir = tempdir().unwrap(); @@ -1237,7 +1334,50 @@ pub mod test { #[test] fn it_can_resize_correctly() { - let n: usize = 10; + let n: usize = 100; + let d: usize = 960; + let distance_function = HnswDistanceFunction::Euclidean; + let tmp_dir = tempdir().unwrap(); + let persist_path = tmp_dir.path(); + let index = HnswIndex::init(HnswIndexInitConfig { + distance_function, + dimensionality: d as i32, + max_elements: n, + m: 16, + ef_construction: 100, + ef_search: 100, + random_seed: 0, + persist_path: Some(persist_path.to_path_buf()), + }); + + let mut index = match index { + Err(e) => panic!("Error initializing index: {}", e), + Ok(index) => index, + }; + + let data: Vec = generate_random_data(n, d); + let ids: Vec = (0..n).collect(); + + (0..n).for_each(|i| { + let data = &data[i * d..(i + 1) * d]; + index.add(ids[i], data).expect("Should not error"); + }); + + index.delete(0).unwrap(); + let data = &data[d..2 * d]; + + let index_len = index.len_with_deleted(); + let index_capacity = index.capacity(); + if index_len + 1 > index_capacity { + index.resize(index_capacity * 2).unwrap(); + } + // this will fail if the index is not resized correctly + index.add(100, data).unwrap(); + } + + #[test] + fn it_can_resize_with_floor() { + let mut n: usize = 0; let d: usize = 960; let distance_function = HnswDistanceFunction::Euclidean; let tmp_dir = tempdir().unwrap(); @@ -1258,6 +1398,8 @@ pub mod test { Ok(index) => index, }; + n = 10; + let data: Vec = generate_random_data(n, d); let ids: Vec = (0..n).collect(); From 437692d33d834c1d6f784018173092e768297038 Mon Sep 17 00:00:00 2001 From: macronova Date: Mon, 27 Oct 2025 16:01:26 -0700 Subject: [PATCH 2/2] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e11df9d6..97944020 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "hnswlib" edition = "2021" -version = "0.8.1" +version = "0.8.2" [lib] path = "src/lib.rs"