diff --git a/build.gradle b/build.gradle index 4041edc..ed919bd 100644 --- a/build.gradle +++ b/build.gradle @@ -87,6 +87,9 @@ dependencies { implementation libs.hadoop.common implementation libs.hadoop.mapreduce.client.common + // HDF5 support for Vector Search + implementation libs.jhdf + // MCP Server dependencies implementation (libs.mcp.sdk) { exclude group: 'org.slf4j', module: 'slf4j-api' diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 10d134a..ef6868b 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -35,6 +35,7 @@ coroutines-test = "1.9.0" mcp-sdk = "0.7.2" ktor = "3.3.0" testcontainers = "1.19.8" +jhdf = "0.7.0" [libraries] jcommander = { module = "com.beust:jcommander", version.ref = "jcommander" } @@ -71,6 +72,8 @@ parquet-hadoop = { module = "org.apache.parquet:parquet-hadoop", version.ref = " hadoop-common = { module = "org.apache.hadoop:hadoop-common", version.ref = "hadoop" } hadoop-mapreduce-client-common = { module = "org.apache.hadoop:hadoop-mapreduce-client-common", version.ref = "hadoop" } +jhdf = { module = "io.jhdf:jhdf", version.ref = "jhdf" } + junit-jupiter-engine = { module = "org.junit.jupiter:junit-jupiter-engine", version.ref = "junit" } junit-jupiter-params = { module = "org.junit.jupiter:junit-jupiter-params", version.ref = "junit-params" } assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj" } diff --git a/src/main/kotlin/org/apache/cassandra/easystress/OperationCallback.kt b/src/main/kotlin/org/apache/cassandra/easystress/OperationCallback.kt index c61d5c2..94b64ee 100644 --- a/src/main/kotlin/org/apache/cassandra/easystress/OperationCallback.kt +++ b/src/main/kotlin/org/apache/cassandra/easystress/OperationCallback.kt @@ -65,8 +65,7 @@ class OperationCallback( // TODO (visibility): include details about paging? 
context.collect(op, Either.Left(result!!), startNanos, endNanos) - // do the callback for mutations - // might extend this to select, but I can't see a reason for it now + // do the callback for mutations and selects when (op) { is Operation.Mutation -> { runner.onSuccess(op, result) @@ -74,6 +73,9 @@ class OperationCallback( is Operation.DDL -> { runner.onSuccess(op, result) } + is Operation.SelectStatement -> { + runner.onSuccess(op, result) + } is Operation.Stop -> { throw OperationStopException() } diff --git a/src/main/kotlin/org/apache/cassandra/easystress/workloads/IStressWorkload.kt b/src/main/kotlin/org/apache/cassandra/easystress/workloads/IStressWorkload.kt index 23d5c70..41db08d 100644 --- a/src/main/kotlin/org/apache/cassandra/easystress/workloads/IStressWorkload.kt +++ b/src/main/kotlin/org/apache/cassandra/easystress/workloads/IStressWorkload.kt @@ -57,6 +57,11 @@ interface IStressRunner { op: Operation.DDL, result: AsyncResultSet?, ) { } + + fun onSuccess( + op: Operation.SelectStatement, + result: AsyncResultSet?, + ) { } } /** @@ -131,6 +136,7 @@ sealed class Operation( class SelectStatement( bound: BoundStatement, + val callbackPayload: Any? = null, ) : Operation(bound) class Deletion( diff --git a/src/main/kotlin/org/apache/cassandra/easystress/workloads/VectorSearch.kt b/src/main/kotlin/org/apache/cassandra/easystress/workloads/VectorSearch.kt new file mode 100644 index 0000000..975b4bf --- /dev/null +++ b/src/main/kotlin/org/apache/cassandra/easystress/workloads/VectorSearch.kt @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.easystress.workloads + +import com.datastax.oss.driver.api.core.CqlSession +import com.datastax.oss.driver.api.core.cql.AsyncResultSet +import com.datastax.oss.driver.api.core.cql.PreparedStatement +import com.datastax.oss.driver.api.core.data.CqlVector +import io.jhdf.HdfFile +import org.apache.cassandra.easystress.MinimumVersion +import org.apache.cassandra.easystress.PartitionKey +import org.apache.cassandra.easystress.StressContext +import org.apache.cassandra.easystress.WorkloadParameter +import org.apache.logging.log4j.kotlin.logger +import java.io.File +import java.nio.file.Paths +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong + +/** + * Workload for benchmarking Vector Search (Cassandra 5.0+). + * Supports both synthetic random vectors and realistic datasets via HDF5. + * + * ## Memory Considerations + * When using HDF5 datasets, vectors are loaded entirely into memory. + * Example memory requirements: + * - SIFT-1M (1M vectors × 128 dims × 4 bytes) ≈ 512 MB + * - GloVe-1M (1M vectors × 200 dims × 4 bytes) ≈ 800 MB + * - Deep1B subset (10M vectors × 96 dims × 4 bytes) ≈ 3.8 GB + * + * Ensure sufficient heap space with -Xmx JVM flag. + * + * ## Ground Truth / Recall Calculation + * When using standard ANN benchmark datasets (SIFT, GloVe, etc.) 
that include + * ground truth neighbors, enable recall calculation to measure search quality: + * - Set `calculateRecall=true` + * - Ensure your HDF5 file contains a 'neighbors' dataset (or configure `neighborsDataset`) + * - Recall@K is logged periodically based on `recallLogInterval` + */ +@MinimumVersion("5.0") +class VectorSearch : IStressWorkload { + @WorkloadParameter(description = "Vector dimensions. Default 1536 (OpenAI).") + var dimensions = 1536 + + @WorkloadParameter(description = "Similarity function: COSINE, EUCLIDEAN, or DOT_PRODUCT.") + var similarityFunction = "COSINE" + + @WorkloadParameter(description = "ANN search limit (TOP K).") + var limit = 10 + + @WorkloadParameter(description = "Path to HDF5 dataset file (e.g., glove.hdf5). If empty, uses random vectors.") + var datasetPath: String = "" + + @WorkloadParameter(description = "Name of the HDF5 dataset for training data (inserts). Default 'train'.") + var trainDataset = "train" + + @WorkloadParameter(description = "Name of the HDF5 dataset for query data (selects). Default 'test'.") + var queryDataset = "test" + + @WorkloadParameter(description = "Name of the HDF5 dataset for ground truth neighbors. Default 'neighbors'.") + var neighborsDataset = "neighbors" + + @WorkloadParameter(description = "Enable recall calculation (requires ground truth in HDF5). Default false.") + var calculateRecall = false + + @WorkloadParameter(description = "How often to log recall summary (every N queries). 
Default 1000.") + var recallLogInterval = 1000 + + lateinit var insert: PreparedStatement + lateinit var select: PreparedStatement + lateinit var delete: PreparedStatement + + private lateinit var trainVectors: Array<FloatArray> + private lateinit var queryVectors: Array<FloatArray> + private lateinit var groundTruth: Array<IntArray> + private var hdf5Loaded = false + private var hasGroundTruth = false + + // Recall tracking + private val queryCount = AtomicLong(0) + private val totalRecall = AtomicLong(0) + private val minRecall = AtomicInteger(Int.MAX_VALUE) + private val maxRecall = AtomicInteger(Int.MIN_VALUE) + private val skippedQueryCount = AtomicLong(0) + + // Sequential counter for inserting training vectors in order + private val insertCounter = AtomicLong(0) + + // Track which training indices have been inserted (for recall calculation) + private val insertedIndices = ConcurrentHashMap.newKeySet<Int>() + + val log = logger() + + override fun prepare(session: CqlSession) { + val validFunctions = listOf("COSINE", "EUCLIDEAN", "DOT_PRODUCT") + require(similarityFunction.uppercase() in validFunctions) { + "similarityFunction must be one of: $validFunctions (got: $similarityFunction)" + } + + // Validate limit ('K' in top 1 <= K <= 1000) + require(limit in 1..1000) { + "limit must be between 1 and 1000 (got: $limit)" + } + + log.info { "Preparing VectorSearch workload. Dimensions: $dimensions, Similarity: $similarityFunction" } + + if (datasetPath.isNotEmpty() && !hdf5Loaded) { + loadHdf5Data() + } + + insert = session.prepare("INSERT INTO vector_test (id, val) VALUES (?, ?)") + select = session.prepare("SELECT id FROM vector_test ORDER BY val ANN OF ?
LIMIT ?") + delete = session.prepare("DELETE FROM vector_test WHERE id = ?") + } + + private fun loadHdf5Data() { + val file = File(datasetPath) + if (!file.exists()) { + throw IllegalArgumentException("Dataset file not found: $datasetPath") + } + + log.info { "Loading vectors from HDF5: $datasetPath" } + HdfFile(Paths.get(datasetPath)).use { hdf -> + // Inserts + val trainData = hdf.getDatasetByPath(trainDataset).data + trainVectors = convertToFloatArray(trainData) + log.info { "Loaded ${trainVectors.size} training vectors." } + + // Selects + val queryData = hdf.getDatasetByPath(queryDataset).data + queryVectors = convertToFloatArray(queryData) + log.info { "Loaded ${queryVectors.size} query vectors." } + + // Validate dimensions + if (trainVectors.isNotEmpty() && trainVectors[0].size != dimensions) { + log.warn { + "Dataset dimensions (${trainVectors[0].size}) do not match configured dimensions ($dimensions). Updating." + } + dimensions = trainVectors[0].size + } + + // If recall calculation is enabled, ground truth values are required + if (calculateRecall) { + try { + val neighborsData = hdf.getDatasetByPath(neighborsDataset).data + groundTruth = convertToIntArray(neighborsData) + hasGroundTruth = true + log.info { "Loaded ground truth with ${groundTruth.size} query neighbors." } + } catch (e: Exception) { + log.warn { "Ground truth dataset '$neighborsDataset' not found. Recall calculation disabled." 
} + calculateRecall = false + } + } + + hdf5Loaded = true + } + } + + override fun schema(): List<String> { + if (datasetPath.isNotEmpty() && !hdf5Loaded) { + loadHdf5Data() + } + + return listOf( + """ + CREATE TABLE IF NOT EXISTS vector_test ( + id text PRIMARY KEY, + val vector<float, $dimensions> + ) + """.trimIndent(), + """ + CREATE INDEX IF NOT EXISTS ann_index ON vector_test(val) + USING 'sai' + WITH OPTIONS = {'similarity_function': '$similarityFunction'} + """.trimIndent(), + ) + } + + override fun getDefaultReadRate(): Double = 0.5 + + override fun getRunner(context: StressContext): IStressRunner { + return object : IStressRunner { + override fun getNextMutation(partitionKey: PartitionKey): Operation { + val vector: CqlVector<Float> + val id: String + + if (hdf5Loaded) { + val trainIdx = (insertCounter.getAndIncrement() % trainVectors.size).toInt() + vector = CqlVector.newInstance(trainVectors[trainIdx].toList()) + id = trainIdx.toString() + insertedIndices.add(trainIdx) + } else { + vector = generateRandomVector(dimensions) + id = partitionKey.getText() + } + + val bound = + insert + .bind() + .setString(0, id) + .setVector(1, vector, Float::class.javaObjectType) + + return Operation.Mutation(bound) + } + + override fun getNextSelect(partitionKey: PartitionKey): Operation { + val vector: CqlVector<Float> + var queryIdx: Int?
= null + + if (hdf5Loaded) { + queryIdx = ThreadLocalRandom.current().nextInt(queryVectors.size) + vector = CqlVector.newInstance(queryVectors[queryIdx].toList()) + } else { + vector = generateRandomVector(dimensions) + } + + val bound = + select + .bind() + .setVector(0, vector, Float::class.javaObjectType) + .setInt(1, limit) + + // Pass query index as callback payload for recall calculation + val payload = + if (calculateRecall && hasGroundTruth && queryIdx != null) { + RecallPayload(queryIdx) + } else { + null + } + + return Operation.SelectStatement(bound, payload) + } + + override fun getNextDelete(partitionKey: PartitionKey): Operation { + // No need to track deletes for recall calculation since we use training indices as IDs + val bound = delete.bind().setString(0, partitionKey.getText()) + return Operation.Deletion(bound) + } + + override fun onSuccess( + op: Operation.SelectStatement, + result: AsyncResultSet?, + ) { + if (result == null) return + val payload = op.callbackPayload as? 
RecallPayload ?: return + + // Parse returned IDs as training indices + val returnedIndices = mutableSetOf<Int>() + for (row in result.currentPage()) { + val id = row.getString("id") ?: continue + id.toIntOrNull()?.let { returnedIndices.add(it) } + } + + // take top K from ground truth for recall calculation + val truthIndices = groundTruth[payload.queryIndex].take(limit).toSet() + + // Only count neighbors that have actually been inserted + val relevantTruth = truthIndices.intersect(insertedIndices) + + // Skip recall calculation if no ground truth neighbors are inserted yet + if (relevantTruth.isEmpty()) { + skippedQueryCount.incrementAndGet() + return + } + + // number of relevant items successfully retrieved + val hits = returnedIndices.intersect(truthIndices).size + val recall = hits.toDouble() / relevantTruth.size + + val recallFixed = (recall * 10000).toInt() + totalRecall.addAndGet(recallFixed.toLong()) + minRecall.updateAndGet { min -> minOf(min, recallFixed) } + maxRecall.updateAndGet { max -> maxOf(max, recallFixed) } + + val count = queryCount.incrementAndGet() + + // Log periodic summary + if (count % recallLogInterval == 0L) { + val skipped = skippedQueryCount.get() + val avgRecall = totalRecall.get().toDouble() / count / 10000 + val minR = minRecall.get().toDouble() / 10000 + val maxR = maxRecall.get().toDouble() / 10000 + log.info { + "Recall@$limit after $count queries: avg=%.3f, min=%.2f, max=%.2f (indexed: ${insertedIndices.size}, skipped: $skipped)" + .format( + avgRecall, + minR, + maxR, + ) + } + } + } + } + } + + private fun generateRandomVector(dim: Int): CqlVector<Float> { + val rand = ThreadLocalRandom.current() + val list = ArrayList<Float>(dim) + for (i in 0 until dim) { + list.add(rand.nextFloat()) + } + return CqlVector.newInstance(list) + } + + private data class RecallPayload( + val queryIndex: Int, + ) + + companion object { + /** + * Converts HDF5 numeric array data to Array<FloatArray>.
+ * Handles both float[][] and double[][] formats commonly found in vector datasets. + */ + fun convertToFloatArray(data: Any): Array<FloatArray> { + return when (data) { + is Array<*> -> { + if (data.isArrayOf<FloatArray>()) { + @Suppress("UNCHECKED_CAST") + return data as Array<FloatArray> + } + if (data.isArrayOf<DoubleArray>()) { + @Suppress("UNCHECKED_CAST") + val doubleArray = data as Array<DoubleArray> + return Array(doubleArray.size) { i -> + FloatArray(doubleArray[i].size) { j -> doubleArray[i][j].toFloat() } + } + } + throw IllegalArgumentException("Unsupported array type in HDF5: ${data::class.java}") + } + + else -> throw IllegalArgumentException("Unsupported data format in HDF5: ${data::class.java}") + } + } + + /** + * Converts HDF5 integer array data to Array<IntArray>. + * Used for ground truth neighbor indices. + */ + fun convertToIntArray(data: Any): Array<IntArray> { + return when (data) { + is Array<*> -> { + if (data.isArrayOf<IntArray>()) { + @Suppress("UNCHECKED_CAST") + return data as Array<IntArray> + } + if (data.isArrayOf<LongArray>()) { + @Suppress("UNCHECKED_CAST") + val longArray = data as Array<LongArray> + return Array(longArray.size) { i -> + IntArray(longArray[i].size) { j -> longArray[i][j].toInt() } + } + } + throw IllegalArgumentException("Unsupported array type for neighbors: ${data::class.java}") + } + + else -> throw IllegalArgumentException("Unsupported data format for neighbors: ${data::class.java}") + } + } + } +} diff --git a/src/test/kotlin/org/apache/cassandra/easystress/workloads/VectorSearchTest.kt b/src/test/kotlin/org/apache/cassandra/easystress/workloads/VectorSearchTest.kt new file mode 100644 index 0000000..933c3f1 --- /dev/null +++ b/src/test/kotlin/org/apache/cassandra/easystress/workloads/VectorSearchTest.kt @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.easystress.workloads + +import org.apache.cassandra.easystress.Workload +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test + +class VectorSearchTest { + @Test + fun `schema should reflect configured dimensions and similarity`() { + val workload = + VectorSearch().apply { + dimensions = 768 + similarityFunction = "EUCLIDEAN" + } + + val schema = workload.schema() + assertThat(schema).hasSize(2) + + // Check table schema + assertThat(schema[0]).contains("vector") + + // Check index schema + assertThat(schema[1]).contains("USING 'sai'") + assertThat(schema[1]).contains("'similarity_function': 'EUCLIDEAN'") + } + + @Test + fun `test dynamic workload parameters`() { + val workload = Workload.getWorkloads()["VectorSearch"]!! 
+ workload.applyDynamicSettings(mapOf("dimensions" to "1024", "limit" to "50")) + + val instance = workload.instance as VectorSearch + assertThat(instance.dimensions).isEqualTo(1024) + assertThat(instance.limit).isEqualTo(50) + } + + @Test + fun `convertToFloatArray should handle float arrays`() { + val floatData: Array<FloatArray> = + arrayOf( + floatArrayOf(1.0f, 2.0f, 3.0f), + floatArrayOf(4.0f, 5.0f, 6.0f), + ) + + val result = VectorSearch.convertToFloatArray(floatData) + + assertThat(result).hasSize(2) + assertThat(result[0]).containsExactly(1.0f, 2.0f, 3.0f) + assertThat(result[1]).containsExactly(4.0f, 5.0f, 6.0f) + } + + @Test + fun `convertToFloatArray should handle double arrays`() { + val doubleData: Array<DoubleArray> = + arrayOf( + doubleArrayOf(1.0, 2.0, 3.0), + doubleArrayOf(4.0, 5.0, 6.0), + ) + + val result = VectorSearch.convertToFloatArray(doubleData) + + assertThat(result).hasSize(2) + assertThat(result[0]).containsExactly(1.0f, 2.0f, 3.0f) + assertThat(result[1]).containsExactly(4.0f, 5.0f, 6.0f) + } + + @Test + fun `convertToFloatArray should throw on unsupported types`() { + val stringData = "not an array" + + assertThatThrownBy { VectorSearch.convertToFloatArray(stringData) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessageContaining("Unsupported data format") + } + + @Test + fun `convertToIntArray should handle int arrays`() { + val intData: Array<IntArray> = + arrayOf( + intArrayOf(1, 2, 3), + intArrayOf(4, 5, 6), + ) + + val result = VectorSearch.convertToIntArray(intData) + + assertThat(result).hasSize(2) + assertThat(result[0]).containsExactly(1, 2, 3) + assertThat(result[1]).containsExactly(4, 5, 6) + } + + @Test + fun `convertToIntArray should handle long arrays`() { + val longData: Array<LongArray> = + arrayOf( + longArrayOf(1L, 2L, 3L), + longArrayOf(4L, 5L, 6L), + ) + + val result = VectorSearch.convertToIntArray(longData) + + assertThat(result).hasSize(2) + assertThat(result[0]).containsExactly(1, 2, 3) + assertThat(result[1]).containsExactly(4, 5, 6) +
} +}