diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle
index 6c77f13db9..200324d730 100644
--- a/fdb-extensions/fdb-extensions.gradle
+++ b/fdb-extensions/fdb-extensions.gradle
@@ -27,6 +27,7 @@ dependencies {
     }
     api(libs.fdbJava)
     implementation(libs.guava)
+    implementation(libs.half4j)
     implementation(libs.slf4j.api)
     compileOnly(libs.jsr305)
@@ -41,6 +42,38 @@ dependencies {
     testFixturesAnnotationProcessor(libs.autoService)
 }
 
+def siftSmallFile = layout.buildDirectory.file('downloads/siftsmall.tar.gz')
+def extractDir = layout.buildDirectory.dir("extracted")
+
+// Task that downloads the siftsmall dataset archive exactly once unless it has changed
+tasks.register('downloadSiftSmall', de.undercouch.gradle.tasks.download.Download) {
+    src 'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz'
+    dest siftSmallFile.get().asFile
+    onlyIfModified true
+    tempAndMove true
+    retries 3
+}
+
+tasks.register('extractSiftSmall', Copy) {
+    dependsOn 'downloadSiftSmall'
+    from(tarTree(resources.gzip(siftSmallFile)))
+    into extractDir
+
+    doLast {
+        println "Extracted files into: ${extractDir.get().asFile}"
+        fileTree(extractDir).visit { details ->
+            if (!details.isDirectory()) {
+                println " - ${details.file}"
+            }
+        }
+    }
+}
+
+test {
+    dependsOn tasks.named('extractSiftSmall')
+    inputs.dir extractDir
+}
+
 publishing {
     publications {
         library(MavenPublication) {
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java
index 563dec11a6..e696512fdd 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java
@@ -23,12 +23,14 @@
 import com.apple.foundationdb.annotation.API;
 import com.apple.foundationdb.util.LoggableException;
 import com.google.common.base.Suppliers;
+import com.google.common.collect.Lists;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
@@ -42,9 +44,13 @@
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.function.IntPredicate;
+import java.util.function.IntUnaryOperator;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 
@@ -1051,6 +1057,93 @@ public static CompletableFuture swallowException(@Nonnull CompletableFutur
         return result;
     }
 
+    /**
+     * Method that provides the functionality of a for loop in an asynchronous way. The result of this method
+     * is a {@link CompletableFuture} that represents the result of the last iteration of the loop body.
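+     * <p>
+     * For illustration only ({@code computeAsync} and {@code executor} are placeholders), summing ten
+     * asynchronously-computed values could look like this:
+     * <pre>{@code
+     * CompletableFuture<Integer> sum =
+     *         MoreAsyncUtil.forLoop(0, 0,
+     *                 i -> i < 10,
+     *                 i -> i + 1,
+     *                 (i, partialSum) -> computeAsync(i).thenApply(v -> partialSum + v),
+     *                 executor);
+     * }</pre>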
+     * @param startI an integer analogous to the starting value of a loop variable in a for loop
+     * @param startU an object of some type {@code U} that represents the initial state that is passed to the loop's
+     *        first iteration
+     * @param conditionPredicate a predicate on the loop variable that must be true before the next iteration is
+     *        entered; analogous to the condition in a for loop
+     * @param stepFunction a unary operator used for modifying the loop variable after each iteration
+     * @param body a bi-function to be called for each iteration; this function is initially invoked using
+     *        {@code startI} and {@code startU}; the result of the body is then passed into the next iteration's body
+     *        together with a new value for the loop variable. In this way callers can access state inside an iteration
+     *        that was computed in a previous iteration.
+     * @param executor the executor
+     * @param <U> the type of the result of the body {@link BiFunction}
+     * @return a {@link CompletableFuture} containing the result of the last iteration's body invocation.
+     */
+    @Nonnull
+    public static <U> CompletableFuture<U> forLoop(final int startI, @Nullable final U startU,
+                                                   @Nonnull final IntPredicate conditionPredicate,
+                                                   @Nonnull final IntUnaryOperator stepFunction,
+                                                   @Nonnull final BiFunction<Integer, U, CompletableFuture<U>> body,
+                                                   @Nonnull final Executor executor) {
+        final AtomicInteger loopVariableAtomic = new AtomicInteger(startI);
+        final AtomicReference<U> lastResultAtomic = new AtomicReference<>(startU);
+        return whileTrue(() -> {
+            final int loopVariable = loopVariableAtomic.get();
+            if (!conditionPredicate.test(loopVariable)) {
+                return AsyncUtil.READY_FALSE;
+            }
+            return body.apply(loopVariable, lastResultAtomic.get())
+                    .thenApply(result -> {
+                        loopVariableAtomic.set(stepFunction.applyAsInt(loopVariable));
+                        lastResultAtomic.set(result);
+                        return true;
+                    });
+        }, executor).thenApply(ignored -> lastResultAtomic.get());
+    }
+
+    /**
+     * Method to iterate over some items, for each of which a body is executed asynchronously. The result of each such
+     * execution is then collected in a list and returned as a {@link CompletableFuture} over that list.
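+     * <p>
+     * A usage sketch (with a hypothetical {@code fetchAsync} and result type {@code Rec}) that keeps at
+     * most eight bodies in flight at any point in time:
+     * <pre>{@code
+     * CompletableFuture<List<Rec>> all =
+     *         MoreAsyncUtil.forEach(ids, id -> fetchAsync(id), 8, executor);
+     * }</pre>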
+     * @param items the items to iterate over
+     * @param body a function to be called for each item
+     * @param parallelism the maximum degree of parallelism this method should use
+     * @param executor the executor
+     * @param <T> the type of item
+     * @param <U> the type of the result
+     * @return a {@link CompletableFuture} containing a list of results collected from the individual body invocations
+     */
+    @Nonnull
+    @SuppressWarnings("unchecked")
+    public static <T, U> CompletableFuture<List<U>> forEach(@Nonnull final Iterable<T> items,
+                                                            @Nonnull final Function<T, CompletableFuture<U>> body,
+                                                            final int parallelism,
+                                                            @Nonnull final Executor executor) {
+        // this deque is populated once upon creation; afterwards, items are only ever removed
+        final ArrayDeque<T> toBeProcessed = new ArrayDeque<>();
+        for (final T item : items) {
+            toBeProcessed.addLast(item);
+        }
+
+        final List<CompletableFuture<Void>> working = Lists.newArrayList();
+        final AtomicInteger indexAtomic = new AtomicInteger(0);
+        final Object[] resultArray = new Object[toBeProcessed.size()];
+
+        return whileTrue(() -> {
+            working.removeIf(CompletableFuture::isDone);
+
+            while (working.size() < parallelism) {
+                final T currentItem = toBeProcessed.pollFirst();
+                if (currentItem == null) {
+                    break;
+                }
+
+                final int index = indexAtomic.getAndIncrement();
+                working.add(body.apply(currentItem)
+                        .thenAccept(result -> resultArray[index] = result));
+            }
+
+            if (working.isEmpty()) {
+                return AsyncUtil.READY_FALSE;
+            }
+            return whenAny(working).thenApply(ignored -> true);
+        }, executor).thenApply(ignored -> Arrays.asList((U[])resultArray));
+    }
+
     /**
      * A {@code Boolean} function that is always true.
      * @param <T> the type of the (ignored) argument to the function
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java
new file mode 100644
index 0000000000..252185f38b
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java
@@ -0,0 +1,98 @@
+/*
+ * AbstractNode.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.tuple.Tuple;
+import com.google.common.collect.ImmutableList;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+/**
+ * An abstract base class implementing the {@link Node} interface.
+ * <p>
+ * This class provides the fundamental structure for a node within the HNSW graph,
+ * managing a unique {@link Tuple} primary key and an immutable list of its neighbors.
+ * Subclasses are expected to provide concrete implementations, potentially adding
+ * more state or behavior.
+ *
+ * @param <N> the type of the node reference used for neighbors, which must extend {@link NodeReference}
+ */
+abstract class AbstractNode<N extends NodeReference> implements Node<N> {
+    @Nonnull
+    private final Tuple primaryKey;
+
+    @Nonnull
+    private final List<N> neighbors;
+
+    /**
+     * Constructs a new {@code AbstractNode} with a specified primary key and a list of neighbors.
+     * <p>
+     * This constructor creates a defensive, immutable copy of the provided {@code neighbors} list.
+     * This ensures that the internal state of the node cannot be modified by external
+     * changes to the original list after construction.
+     *
+     * @param primaryKey the unique identifier for this node; must not be {@code null}
+     * @param neighbors the list of nodes connected to this node; must not be {@code null}
+     */
+    protected AbstractNode(@Nonnull final Tuple primaryKey,
+                           @Nonnull final List<N> neighbors) {
+        this.primaryKey = primaryKey;
+        this.neighbors = ImmutableList.copyOf(neighbors);
+    }
+
+    /**
+     * Gets the primary key that uniquely identifies this object.
+     * @return the primary key {@link Tuple}, which will never be {@code null}.
+     */
+    @Nonnull
+    @Override
+    public Tuple getPrimaryKey() {
+        return primaryKey;
+    }
+
+    /**
+     * Gets the list of neighbors connected to this node.
+     * <p>
+     * This method returns a direct reference to the internal list, which is
+     * immutable.
+     * @return a non-null, possibly empty, list of neighbors.
+     */
+    @Nonnull
+    @Override
+    public List<N> getNeighbors() {
+        return neighbors;
+    }
+
+    /**
+     * Gets the neighbor at the specified index.
+     * <p>
+ * This method provides access to a specific neighbor by its zero-based position + * in the internal list of neighbors. + * @param index the zero-based index of the neighbor to retrieve. + * @return the neighbor at the specified index, guaranteed to be non-null. + */ + @Nonnull + @Override + public N getNeighbor(final int index) { + return neighbors.get(index); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java new file mode 100644 index 0000000000..2b0e17da69 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -0,0 +1,276 @@ +/* + * AbstractStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * An abstract base class for {@link StorageAdapter} implementations. + *
+ * <p>
+ * This class provides the common infrastructure for managing HNSW graph data within a {@link Subspace}.
+ * It handles the configuration, node creation, and listener management, while delegating the actual
+ * storage-specific read and write operations to concrete subclasses through the {@code fetchNodeInternal}
+ * and {@code writeNodeInternal} abstract methods.
+ *
+ * @param <N> the type of {@link NodeReference} used to reference nodes in the graph
+ */
+abstract class AbstractStorageAdapter<N extends NodeReference> implements StorageAdapter<N> {
+    @Nonnull
+    private static final Logger logger = LoggerFactory.getLogger(AbstractStorageAdapter.class);
+
+    @Nonnull
+    private final HNSW.Config config;
+    @Nonnull
+    private final NodeFactory<N> nodeFactory;
+    @Nonnull
+    private final Subspace subspace;
+    @Nonnull
+    private final OnWriteListener onWriteListener;
+    @Nonnull
+    private final OnReadListener onReadListener;
+
+    private final Subspace dataSubspace;
+
+    /**
+     * Constructs a new {@code AbstractStorageAdapter}.
+     * <p>
+     * This constructor initializes the adapter with the necessary configuration,
+     * factories, and listeners for managing an HNSW graph. It also sets up a
+     * dedicated data subspace within the provided main subspace for storing node data.
+     *
+     * @param config the HNSW graph configuration
+     * @param nodeFactory the factory to create new nodes of type {@code N}
+     * @param subspace the primary subspace for storing all graph-related data
+     * @param onWriteListener the listener to be called on write operations
+     * @param onReadListener the listener to be called on read operations
+     */
+    protected AbstractStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory<N> nodeFactory,
+                                     @Nonnull final Subspace subspace,
+                                     @Nonnull final OnWriteListener onWriteListener,
+                                     @Nonnull final OnReadListener onReadListener) {
+        this.config = config;
+        this.nodeFactory = nodeFactory;
+        this.subspace = subspace;
+        this.onWriteListener = onWriteListener;
+        this.onReadListener = onReadListener;
+        this.dataSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA));
+    }
+
+    /**
+     * Returns the configuration used to build and search this HNSW graph.
+     *
+     * @return the current {@link HNSW.Config} object, never {@code null}.
+     */
+    @Override
+    @Nonnull
+    public HNSW.Config getConfig() {
+        return config;
+    }
+
+    /**
+     * Gets the factory responsible for creating new nodes.
+     * <p>
+     * This factory is used to instantiate nodes of the generic type {@code N}
+     * for the current context.
+     *
+     * @return the non-null {@link NodeFactory} instance.
+     */
+    @Nonnull
+    @Override
+    public NodeFactory<N> getNodeFactory() {
+        return nodeFactory;
+    }
+
+    /**
+     * Gets the node kind used by this storage adapter.
+     * <p>
+     * This provides a way to determine the concrete type of the nodes this adapter
+     * manages without using {@code instanceof} checks.
+     *
+     * @return the non-null {@link NodeKind} of the nodes managed by this storage adapter.
+     */
+    @Nonnull
+    @Override
+    public NodeKind getNodeKind() {
+        return getNodeFactory().getNodeKind();
+    }
+
+    /**
+     * Gets the subspace in which this graph is stored.
+     * <p>
+ * This subspace provides a logical separation for keys within the underlying key-value store. + * + * @return the non-null {@link Subspace} for this context + */ + @Override + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + /** + * Gets the subspace for the data associated with this component. + *
+     * <p>
+ * The data subspace defines the portion of the directory space where the data + * for this component is stored. + * + * @return the non-null {@link Subspace} for the data + */ + @Override + @Nonnull + public Subspace getDataSubspace() { + return dataSubspace; + } + + /** + * Returns the listener that is notified upon write events. + *
+     * <p>
+ * This method is an override and guarantees a non-null return value, + * as indicated by the {@code @Nonnull} annotation. + * + * @return the configured {@link OnWriteListener} instance; will never be {@code null}. + */ + @Override + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + /** + * Gets the listener that is notified upon completion of a read operation. + *
+     * <p>
+ * This method is an override and provides the currently configured listener instance. + * The returned listener is guaranteed to be non-null as indicated by the + * {@code @Nonnull} annotation. + * + * @return the non-null {@link OnReadListener} instance. + */ + @Override + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + /** + * Asynchronously fetches a node from a specific layer of the HNSW. + *
+     * <p>
+     * The node is identified by its {@code layer} and {@code primaryKey}. The entire fetch operation is
+     * performed within the given {@link ReadTransaction}. After the underlying
+     * fetch operation completes, the retrieved node is validated by the
+     * {@link #checkNode(Node)} method before the returned future is completed.
+     *
+     * @param readTransaction the non-null transaction to use for the read operation
+     * @param layer the layer of the graph from which to fetch the node
+     * @param primaryKey the non-null primary key that identifies the node to fetch
+     *
+     * @return a {@link CompletableFuture} that will complete with the fetched {@link Node}
+     *         once it has been read from storage and validated
+     */
+    @Nonnull
+    @Override
+    public CompletableFuture<Node<N>> fetchNode(@Nonnull final ReadTransaction readTransaction,
+                                                int layer, @Nonnull Tuple primaryKey) {
+        return fetchNodeInternal(readTransaction, layer, primaryKey).thenApply(this::checkNode);
+    }
+
+    /**
+     * Asynchronously fetches a specific node from the data store for a given layer and primary key.
+     * <p>
+     * This is an internal, abstract method that concrete subclasses must implement to define
+     * the storage-specific logic for retrieving a node. The operation is performed within the
+     * context of the provided {@link ReadTransaction}.
+     *
+     * @param readTransaction the transaction to use for the read operation; must not be {@code null}
+     * @param layer the layer index from which to fetch the node
+     * @param primaryKey the primary key that uniquely identifies the node to be fetched; must not be {@code null}
+     *
+     * @return a {@link CompletableFuture} that will be completed with the fetched {@link Node}.
+     *         The future will complete with {@code null} if no node is found for the given key and layer.
+     */
+    @Nonnull
+    protected abstract CompletableFuture<Node<N>> fetchNodeInternal(@Nonnull ReadTransaction readTransaction,
+                                                                    int layer, @Nonnull Tuple primaryKey);
+
+    /**
+     * Method to perform basic invariant check(s) on a newly-fetched node.
+     *
+     * @param node the node to check
+     *
+     * @return the node that was passed in
+     */
+    @Nullable
+    private Node<N> checkNode(@Nullable final Node<N> node) {
+        return node;
+    }
+
+    /**
+     * Writes a given node and its neighbor modifications to the underlying storage.
+     * <p>
+     * This operation is executed within the context of the provided {@link Transaction}.
+     * It handles persisting the node's data at a specific {@code layer} and applies
+     * the changes to its neighbors as defined in the {@link NeighborsChangeSet}.
+     * This method delegates the core writing logic to an internal method and provides
+     * debug logging upon completion.
+     *
+     * @param transaction the non-null {@link Transaction} context for this write operation
+     * @param node the non-null {@link Node} to be written to storage
+     * @param layer the layer index where the node is being written
+     * @param changeSet the non-null {@link NeighborsChangeSet} detailing the modifications
+     *        to the node's neighbors
+     */
+    @Override
+    public void writeNode(@Nonnull Transaction transaction, @Nonnull Node<N> node, int layer,
+                          @Nonnull NeighborsChangeSet<N> changeSet) {
+        writeNodeInternal(transaction, node, layer, changeSet);
+        if (logger.isDebugEnabled()) {
+            logger.debug("written node with key={} at layer={}", node.getPrimaryKey(), layer);
+        }
+    }
+
+    /**
+     * Writes a single node to the data store as part of a larger transaction.
+     * <p>
+     * This is an abstract method that concrete implementations must provide.
+     * It is responsible for the low-level persistence of the given {@code node} at a
+     * specific {@code layer}. The implementation should also handle the modifications
+     * to the node's neighbors, as detailed in the {@code changeSet}.
+     *
+     * @param transaction the non-null transaction context for the write operation
+     * @param node the non-null {@link Node} to write
+     * @param layer the layer or level of the node in the structure
+     * @param changeSet the non-null {@link NeighborsChangeSet} detailing additions or
+     *        removals of neighbor links
+     */
+    protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Node<N> node, int layer,
+                                              @Nonnull NeighborsChangeSet<N> changeSet);
+
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java
new file mode 100644
index 0000000000..5d27783b9e
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java
@@ -0,0 +1,95 @@
+/*
+ * BaseNeighborsChangeSet.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.tuple.Tuple;
+import com.google.common.collect.ImmutableList;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.List;
+import java.util.function.Predicate;
+
+/**
+ * A base implementation of the {@link NeighborsChangeSet} interface.
+ * <p>
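+ * Change sets are meant to be layered: a base set can be wrapped by other change sets, for instance a
+ * {@code DeleteNeighborsChangeSet}, so that {@code merge()} yields a filtered view. A sketch (the
+ * identifiers are placeholders):
+ * <pre>{@code
+ * NeighborsChangeSet<NodeReference> base = new BaseNeighborsChangeSet<>(currentNeighbors);
+ * NeighborsChangeSet<NodeReference> withDeletes =
+ *         new DeleteNeighborsChangeSet<>(base, ImmutableList.of(removedPrimaryKey));
+ * Iterable<NodeReference> merged = withDeletes.merge(); // excludes the removed neighbor
+ * }</pre>
+ * <p>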
+ * This class represents a complete, non-delta state of a node's neighbors. It holds a fixed, immutable
+ * list of neighbors provided at construction time. As such, it does not support parent change sets or writing deltas.
+ *
+ * @param <N> the type of the node reference, which must extend {@link NodeReference}
+ */
+class BaseNeighborsChangeSet<N extends NodeReference> implements NeighborsChangeSet<N> {
+    @Nonnull
+    private final List<N> neighbors;
+
+    /**
+     * Creates a new change set with the specified neighbors.
+     * <p>
+     * This constructor creates an immutable copy of the provided list.
+     *
+     * @param neighbors the list of neighbors for this change set; must not be null.
+     */
+    public BaseNeighborsChangeSet(@Nonnull final List<N> neighbors) {
+        this.neighbors = ImmutableList.copyOf(neighbors);
+    }
+
+    /**
+     * Gets the parent change set.
+     * <p>
+     * This implementation always returns {@code null}, as this type of change set
+     * does not have a parent.
+     *
+     * @return always {@code null}.
+     */
+    @Nullable
+    @Override
+    public BaseNeighborsChangeSet<N> getParent() {
+        return null;
+    }
+
+    /**
+     * Retrieves the list of neighbors associated with this object.
+     * <p>
+     * This implementation fulfills the {@code merge} contract by simply returning the
+     * existing list of neighbors without performing any additional merging logic.
+     * @return a non-null list of neighbors. The generic type {@code N} represents
+     *         the type of the neighboring elements.
+     */
+    @Nonnull
+    @Override
+    public List<N> merge() {
+        return neighbors;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>
+     * This implementation is a no-op and does not write any delta information,
+     * as indicated by the empty method body.
+     */
+    @Override
+    public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction,
+                           final int layer, @Nonnull final Node<N> node,
+                           @Nonnull final Predicate<Tuple> primaryKeyPredicate) {
+        // nothing to be written
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java
new file mode 100644
index 0000000000..b594e70a2f
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java
@@ -0,0 +1,163 @@
+/*
+ * CompactNode.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings;
+import com.apple.foundationdb.tuple.Tuple;
+import com.christianheina.langx.half4j.Half;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Represents a compact node within a graph structure, extending {@link AbstractNode}.
+ * <p>
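+ * Instances are typically obtained through the factory returned by {@link #factory()}; a sketch (the
+ * arguments are placeholders, and the vector must not be {@code null} for compact nodes):
+ * <pre>{@code
+ * Node<NodeReference> node =
+ *         CompactNode.factory().create(primaryKey, vector, ImmutableList.of());
+ * }</pre>
+ * <p>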
+ * This node type is considered "compact" because it directly stores its associated + * data vector of type {@link Vector}. It is used to represent a vector in a + * vector space and maintains references to its neighbors via {@link NodeReference} objects. + * + * @see AbstractNode + * @see NodeReference + */ +public class CompactNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public Node create(@Nonnull final Tuple primaryKey, @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new CompactNode(primaryKey, Objects.requireNonNull(vector), (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.COMPACT; + } + }; + + @Nonnull + private final Vector vector; + + /** + * Constructs a new {@code CompactNode} instance. + *
+     * <p>
+ * This constructor initializes the node with its primary key, a data vector, + * and a list of its neighbors. It delegates the initialization of the + * {@code primaryKey} and {@code neighbors} to the superclass constructor. + * + * @param primaryKey the primary key that uniquely identifies this node; must not be {@code null}. + * @param vector the data vector of type {@code Vector} associated with this node; must not be {@code null}. + * @param neighbors a list of {@link NodeReference} objects representing the neighbors of this node; must not be + * {@code null}. + */ + public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + this.vector = vector; + } + + /** + * Returns a {@link NodeReference} that uniquely identifies this node. + *
+     * <p>
+ * This implementation creates the reference using the node's primary key, obtained via {@code getPrimaryKey()}. It + * ignores the provided {@code vector} parameter, which exists to fulfill the contract of the overridden method. + * + * @param vector the vector context, which is ignored in this implementation. + * Per the {@code @Nullable} annotation, this can be {@code null}. + * + * @return a non-null {@link NodeReference} to this node. + */ + @Nonnull + @Override + public NodeReference getSelfReference(@Nullable final Vector vector) { + return new NodeReference(getPrimaryKey()); + } + + /** + * Gets the kind of this node. + * This implementation always returns {@link NodeKind#COMPACT}. + * @return the node kind, which is guaranteed to be {@link NodeKind#COMPACT}. + */ + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.COMPACT; + } + + /** + * Gets the vector of {@code Half} objects. + * @return the non-null vector of {@link Half} objects. + */ + @Nonnull + public Vector getVector() { + return vector; + } + + /** + * Returns this node as a {@code CompactNode}. As this class is already a {@code CompactNode}, this method provides + * {@code this}. + * @return this object cast as a {@code CompactNode}, which is guaranteed to be non-null. + */ + @Nonnull + @Override + public CompactNode asCompactNode() { + return this; + } + + /** + * Returns this node as an {@link InliningNode}. + *
+     * <p>
+ * This override is for node types that are not inlining nodes. As such, it + * will always fail. + * @return this node as a non-null {@link InliningNode} + * @throws IllegalStateException always, as this is not an inlining node + */ + @Nonnull + @Override + public InliningNode asInliningNode() { + throw new IllegalStateException("this is not an inlining node"); + } + + /** + * Gets the shared factory instance for creating {@link NodeReference} objects. + *
+     * <p>
+ * This static factory method is the preferred way to obtain a {@code NodeFactory} + * for {@link NodeReference} instances, as it returns a shared, pre-configured object. + * + * @return a shared, non-null instance of {@code NodeFactory} + */ + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "C[primaryKey=" + getPrimaryKey() + + ";vector=" + vector + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java new file mode 100644 index 0000000000..826ba57f9b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -0,0 +1,303 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Verify; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * The {@code CompactStorageAdapter} class is a concrete implementation of {@link StorageAdapter} for managing HNSW + * graph data in a compact format. + *
+ * <p>
+ * It handles the serialization and deserialization of graph nodes to and from a persistent data store. This
+ * implementation is optimized for space efficiency by storing nodes with their accompanying vector data and by storing
+ * just neighbor primary keys. It extends {@link AbstractStorageAdapter} to inherit common storage logic.
+ */
+class CompactStorageAdapter extends AbstractStorageAdapter<NodeReference> implements StorageAdapter<NodeReference> {
+    @Nonnull
+    private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class);
+
+    /**
+     * Constructs a new {@code CompactStorageAdapter}.
+     * <p>
+ * This constructor initializes the adapter by delegating to the superclass, + * setting up the necessary components for managing an HNSW graph. + * + * @param config the HNSW graph configuration, must not be null. See {@link HNSW.Config}. + * @param nodeFactory the factory used to create new nodes of type {@link NodeReference}, must not be null. + * @param subspace the {@link Subspace} where the graph data is stored, must not be null. + * @param onWriteListener the listener to be notified of write events, must not be null. + * @param onReadListener the listener to be notified of read events, must not be null. + */ + public CompactStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + /** + * Returns this storage adapter instance, as it is already a compact storage adapter. + * @return the current instance, which serves as its own compact representation. + * This will never be {@code null}. + */ + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + return this; + } + + /** + * Returns this adapter as a {@code StorageAdapter} that supports inlining. + *
+     * <p>
+ * This operation is not supported by a compact storage adapter. Calling this method on this implementation will + * always result in an {@code IllegalStateException}. + * + * @return an instance of {@code StorageAdapter} that supports inlining + * + * @throws IllegalStateException unconditionally, as this operation is not supported + * on a compact storage adapter. + */ + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + throw new IllegalStateException("cannot call this method on a compact storage adapter"); + } + + /** + * Asynchronously fetches a node from the database for a given layer and primary key. + *
+     * <p>
+ * This internal method constructs a raw byte key from the {@code layer} and {@code primaryKey} + * within the store's data subspace. It then uses the provided {@link ReadTransaction} to + * retrieve the raw value. If a value is found, it is deserialized into a {@link Node} object + * using the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for the read operation + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a future that will complete with the fetched {@link Node} + * + * @throws IllegalStateException if the node cannot be found in the database for the given key + */ + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] keyBytes = getDataSubspace().pack(Tuple.from(layer, primaryKey)); + + return readTransaction.get(keyBytes) + .thenApply(valueBytes -> { + if (valueBytes == null) { + throw new IllegalStateException("cannot fetch node"); + } + return nodeFromRaw(layer, primaryKey, keyBytes, valueBytes); + }); + } + + /** + * Deserializes a raw key-value byte array pair into a {@code Node}. + *
+     * <p>
+ * This method first converts the {@code valueBytes} into a {@link Tuple} and then, + * along with the {@code primaryKey}, constructs the final {@code Node} object. + * It also notifies any registered {@link OnReadListener} about the raw key-value + * read and the resulting node creation. + * + * @param layer the layer of the HNSW where this node resides + * @param primaryKey the primary key for the node + * @param keyBytes the raw byte representation of the node's key + * @param valueBytes the raw byte representation of the node's value, which will be deserialized + * + * @return a non-null, deserialized {@link Node} object + */ + @Nonnull + private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, + @Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) { + final Tuple nodeTuple = Tuple.fromBytes(valueBytes); + final Node node = nodeFromKeyValuesTuples(primaryKey, nodeTuple); + final OnReadListener onReadListener = getOnReadListener(); + onReadListener.onNodeRead(layer, node); + onReadListener.onKeyValueRead(layer, keyBytes, valueBytes); + return node; + } + + /** + * Constructs a compact {@link Node} from its representation as stored key and value tuples. + *
+     * <p>
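+     * The expected value layout can be unpacked as sketched here (mirroring the checks described below):
+     * <pre>{@code
+     * NodeKind kind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0));
+     * Tuple vectorTuple = valueTuple.getNestedTuple(1);
+     * Tuple neighborsTuple = valueTuple.getNestedTuple(2);
+     * }</pre>
+     * <p>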
+ * This method deserializes a node by extracting its components from the provided tuples. It verifies that the + * node is of type {@link NodeKind#COMPACT} before delegating the final construction to + * {@link #compactNodeFromTuples(Tuple, Tuple, Tuple)}. The {@code valueTuple} is expected to have a specific + * structure: the serialized node kind at index 0, a nested tuple for the vector at index 1, and a nested + * tuple for the neighbors at index 2. + * + * @param primaryKey the tuple representing the primary key of the node + * @param valueTuple the tuple containing the serialized node data, including kind, vector, and neighbors + * + * @return the reconstructed compact {@link Node} + * + * @throws com.google.common.base.VerifyException if the node kind encoded in {@code valueTuple} is not + * {@link NodeKind#COMPACT} + */ + @Nonnull + private Node nodeFromKeyValuesTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple valueTuple) { + final NodeKind nodeKind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0)); + Verify.verify(nodeKind == NodeKind.COMPACT); + + final Tuple vectorTuple; + final Tuple neighborsTuple; + + vectorTuple = valueTuple.getNestedTuple(1); + neighborsTuple = valueTuple.getNestedTuple(2); + return compactNodeFromTuples(primaryKey, vectorTuple, neighborsTuple); + } + + /** + * Creates a compact in-memory representation of a graph node from its constituent storage tuples. + *
+     * <p>
+ * This method deserializes the raw data stored in {@code Tuple} objects into their + * corresponding in-memory types. It extracts the vector, constructs a list of + * {@link NodeReference} objects for the neighbors, and then uses a factory to + * assemble the final {@code Node} object. + *
+     * <p>
+ * + * @param primaryKey the tuple representing the node's primary key + * @param vectorTuple the tuple containing the node's vector data + * @param neighborsTuple the tuple containing a list of nested tuples, where each nested tuple represents a neighbor + * + * @return a new {@code Node} instance containing the deserialized data from the input tuples + */ + @Nonnull + private Node compactNodeFromTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple vectorTuple, + @Nonnull final Tuple neighborsTuple) { + final Vector vector = StorageAdapter.vectorFromTuple(vectorTuple); + final List nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size()); + + for (int i = 0; i < neighborsTuple.size(); i ++) { + final Tuple neighborTuple = neighborsTuple.getNestedTuple(i); + nodeReferences.add(new NodeReference(neighborTuple)); + } + + return getNodeFactory().create(primaryKey, vector, nodeReferences); + } + + /** + * Writes the internal representation of a compact node to the data store within a given transaction. + * This method handles the serialization of the node's vector and its final set of neighbors based on the + * provided {@code neighborsChangeSet}. + * + *
+     * <p>
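+     * In concrete terms, the written key/value pair is assembled as sketched below (the neighbor keys are
+     * placeholders; the exact layout is described in the paragraph that follows):
+     * <pre>{@code
+     * CompactNode compactNode = node.asCompactNode();
+     * byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey()));
+     * Tuple value = Tuple.from(
+     *         NodeKind.COMPACT.getSerialized(),                        // node kind discriminator
+     *         StorageAdapter.tupleFromVector(compactNode.getVector()), // vector as a nested tuple
+     *         Tuple.from(neighbor1Key, neighbor2Key));                 // neighbor primary keys
+     * transaction.set(key, value.pack());
+     * }</pre>
+     * <p>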
The node is stored as a {@link Tuple} with the structure {@code (NodeKind, Vector, NeighborPrimaryKeys)}. + * The key for the storage is derived from the node's layer and its primary key. After writing, it notifies any + * registered write listeners via {@code onNodeWritten} and {@code onKeyValueWritten}. + * + * @param transaction the {@link Transaction} to use for the write operation. + * @param node the {@link Node} to be serialized and written; it is processed as a {@link CompactNode}. + * @param layer the graph layer index for the node, used to construct the storage key. + * @param neighborsChangeSet a {@link NeighborsChangeSet} containing the additions and removals, which are + * merged to determine the final set of neighbors to be written. + */ + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())); + + final List nodeItems = Lists.newArrayListWithExpectedSize(3); + nodeItems.add(NodeKind.COMPACT.getSerialized()); + final CompactNode compactNode = node.asCompactNode(); + nodeItems.add(StorageAdapter.tupleFromVector(compactNode.getVector())); + + final Iterable neighbors = neighborsChangeSet.merge(); + + final List neighborItems = Lists.newArrayList(); + for (final NodeReference neighborReference : neighbors) { + neighborItems.add(neighborReference.getPrimaryKey()); + } + nodeItems.add(Tuple.fromList(neighborItems)); + + final Tuple nodeTuple = Tuple.fromList(nodeItems); + + final byte[] value = nodeTuple.pack(); + transaction.set(key, value); + getOnWriteListener().onNodeWritten(layer, node); + getOnWriteListener().onKeyValueWritten(layer, key, value); + + if (logger.isDebugEnabled()) { + logger.debug("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(), + node.getNeighbors().size(), neighborItems.size()); + } + } + + /** + * Scans a given layer for nodes, returning an iterable over the results. + *
+     * <p>
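+     * A pagination sketch (assuming an adapter and a read transaction are at hand; the layer and batch
+     * size are arbitrary):
+     * <pre>{@code
+     * Tuple lastKey = null;
+     * while (true) {
+     *     List<Node<NodeReference>> batch =
+     *             ImmutableList.copyOf(adapter.scanLayer(readTransaction, 0, lastKey, 100));
+     *     if (batch.isEmpty()) {
+     *         break;
+     *     }
+     *     lastKey = batch.get(batch.size() - 1).getPrimaryKey();
+     *     // process the batch ...
+     * }
+     * }</pre>
+     * <p>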
+ * This method reads a limited number of nodes from a specific layer in the underlying data store. + * The scan can be started from a specific point using the {@code lastPrimaryKey} parameter, which is + * useful for paginating through the nodes in a large layer. + * + * @param readTransaction the transaction to use for reading data; must not be {@code null} + * @param layer the layer to scan for nodes + * @param lastPrimaryKey the primary key of the last node from a previous scan. If {@code null}, + * the scan starts from the beginning of the layer. + * @param maxNumRead the maximum number of nodes to read in this scan + * + * @return an {@link Iterable} of {@link Node} objects found in the specified layer, + * limited by {@code maxNumRead} + */ + @Nonnull + @Override + public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, + @Nullable final Tuple lastPrimaryKey, int maxNumRead) { + final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer)); + final Range range = + lastPrimaryKey == null + ? Range.startsWith(layerPrefix) + : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))), + ByteArrayUtil.strinc(layerPrefix)); + final AsyncIterable itemsIterable = + readTransaction.getRange(range, maxNumRead, false, StreamingMode.ITERATOR); + + return AsyncUtil.mapIterable(itemsIterable, keyValue -> { + final byte[] key = keyValue.getKey(); + final byte[] value = keyValue.getValue(); + final Tuple primaryKey = getDataSubspace().unpack(key).getNestedTuple(1); + return nodeFromRaw(layer, primaryKey, key, value); + }); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java new file mode 100644 index 0000000000..a4852b66a1 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -0,0 +1,137 @@ +/* + * DeleteNeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Set; +import java.util.function.Predicate; + +/** + * A {@link NeighborsChangeSet} that represents the deletion of a set of neighbors from a parent change set. + *
+ * <p>
+ * This class acts as a filter, wrapping a parent {@link NeighborsChangeSet} and providing a view of the neighbors + * that excludes those whose primary keys have been marked for deletion. + * + * @param the type of the node reference, which must extend {@link NodeReference} + */ +class DeleteNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(DeleteNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Set deletedNeighborsPrimaryKeys; + + /** + * Constructs a new {@code DeleteNeighborsChangeSet}. + *
+     * <p>
+ * This object represents a set of changes where specific neighbors are marked for deletion. + * It holds a reference to a parent {@link NeighborsChangeSet} and creates an immutable copy + * of the primary keys for the neighbors to be deleted. + * + * @param parent the parent {@link NeighborsChangeSet} to which this deletion change belongs. Must not be null. + * @param deletedNeighborsPrimaryKeys a {@link Collection} of primary keys, represented as {@link Tuple}s, + * identifying the neighbors to be deleted. Must not be null. + */ + public DeleteNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final Collection deletedNeighborsPrimaryKeys) { + this.parent = parent; + this.deletedNeighborsPrimaryKeys = ImmutableSet.copyOf(deletedNeighborsPrimaryKeys); + } + + /** + * Gets the parent change set from which this change set was derived. + *
+     * <p>
+ * In a sequence of modifications, each {@code NeighborsChangeSet} is derived from a previous state, which is + * considered its parent. This method allows traversing the history of changes backward. + * + * @return the parent {@link NeighborsChangeSet} + */ + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + /** + * Merges the neighbors from the parent context, filtering out any neighbors that have been marked as deleted. + *
+     * <p>
+ * This implementation retrieves the collection of neighbors from its parent by calling + * {@code getParent().merge()}. + * It then filters this collection, removing any neighbor whose primary key is present in the + * {@code deletedNeighborsPrimaryKeys} set. + * This ensures the resulting {@link Iterable} represents a consistent view of neighbors, respecting deletions made + * in the current context. + * + * @return an {@link Iterable} of the merged neighbors, excluding those marked as deleted. This method never returns + * {@code null}. + */ + @Nonnull + @Override + public Iterable merge() { + return Iterables.filter(getParent().merge(), + current -> !deletedNeighborsPrimaryKeys.contains(current.getPrimaryKey())); + } + + /** + * Writes the delta of changes for a given node to the storage layer. + *
+     * <p>
+ * This implementation first delegates to the parent's {@code writeDelta} method to handle its changes, but modifies + * the predicate to exclude any neighbors that are marked for deletion in this delta. + *
+     * <p>
+ * It then iterates through the set of locally deleted neighbor primary keys. For each key that matches the supplied + * {@code tuplePredicate}, it instructs the {@link InliningStorageAdapter} to delete the corresponding neighbor + * relationship for the given {@code node}. + * + * @param storageAdapter the storage adapter to which the changes are written + * @param transaction the transaction context for the write operations + * @param layer the layer index where the write operations should occur + * @param node the node for which the delta is being written + * @param tuplePredicate a predicate to filter which neighbor tuples should be processed; + * only deletions matching this predicate will be written + */ + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !deletedNeighborsPrimaryKeys.contains(tuple))); + + for (final Tuple deletedNeighborPrimaryKey : deletedNeighborsPrimaryKeys) { + if (tuplePredicate.test(deletedNeighborPrimaryKey)) { + storageAdapter.deleteNeighbor(transaction, layer, node.asInliningNode(), deletedNeighborPrimaryKey); + if (logger.isDebugEnabled()) { + logger.debug("deleted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + deletedNeighborPrimaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java new file mode 100644 index 0000000000..f8b9587bdd --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -0,0 +1,94 @@ +/* + * EntryNodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents an entry reference to a node within a hierarchical graph structure. + *
+ * <p>
+ * This class extends {@link NodeReferenceWithVector} by adding a {@code layer} + * attribute. It is used to encapsulate all the necessary information for an + * entry point into a specific layer of the graph, including its unique identifier + * (primary key), its vector representation, and its hierarchical level. + */ +class EntryNodeReference extends NodeReferenceWithVector { + private final int layer; + + /** + * Constructs a new reference to an entry node. + *
+     * <p>
+ * This constructor initializes the node with its primary key, its associated vector, + * and the specific layer it belongs to within a hierarchical graph structure. It calls the + * superclass constructor to set the {@code primaryKey} and {@code vector}. + * + * @param primaryKey the primary key identifying the node. Must not be {@code null}. + * @param vector the vector data associated with the node. Must not be {@code null}. + * @param layer the layer number where this entry node is located. + */ + public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + /** + * Gets the layer value for this object. + * @return the integer representing the layer + */ + public int getLayer() { + return layer; + } + + /** + * Compares this {@code EntryNodeReference} to the specified object for equality. + *
+     * <p>
+ * The result is {@code true} if and only if the argument is an instance of {@code EntryNodeReference}, the + * superclass's {@link #equals(Object)} method returns {@code true}, and the {@code layer} fields of both objects + * are equal. + * @param o the object to compare this {@code EntryNodeReference} against. + * @return {@code true} if the given object is equal to this one; {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof EntryNodeReference)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((EntryNodeReference)o).layer; + } + + /** + * Generates a hash code for this object. + *
+     * <p>
+ * The hash code is computed by combining the hash code of the superclass with the hash code of the {@code layer} + * field. This implementation is consistent with the contract of {@link Object#hashCode()}. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java new file mode 100644 index 0000000000..a1875d4988 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -0,0 +1,1786 @@ +/* + * HNSW.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.MoreAsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.Streams; +import com.google.common.collect.TreeMultimap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; +import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; + +/** + * An implementation of the Hierarchical Navigable Small World (HNSW) algorithm for + * efficient approximate nearest neighbor (ANN) search. + *
+ * <p>
+ * HNSW constructs a multi-layer graph, where each layer is a subset of the one below it. + * The top layers serve as fast entry points to navigate the graph, while the bottom layer + * contains all the data points. This structure allows for logarithmic-time complexity + * for search operations, making it suitable for large-scale, high-dimensional datasets. + *

+ * This class provides methods for building the graph ({@link #insert(Transaction, Tuple, Vector)}) + * and performing k-NN searches ({@link #kNearestNeighborsSearch(ReadTransaction, int, int, Vector)}). + * It is designed to be used with a transactional storage backend, managed via a {@link Subspace}. + * + * @see Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs + */ +@API(API.Status.EXPERIMENTAL) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSW { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(HNSW.class); + + public static final int MAX_CONCURRENT_NODE_READS = 16; + public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3; + public static final int MAX_CONCURRENT_SEARCHES = 10; + @Nonnull public static final Random DEFAULT_RANDOM = new Random(0L); + @Nonnull public static final Metric DEFAULT_METRIC = new Metric.EuclideanMetric(); + public static final boolean DEFAULT_USE_INLINING = false; + public static final int DEFAULT_M = 16; + public static final int DEFAULT_M_MAX = DEFAULT_M; + public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; + public static final int DEFAULT_EF_CONSTRUCTION = 200; + public static final boolean DEFAULT_EXTEND_CANDIDATES = false; + public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; + + @Nonnull + public static final Config DEFAULT_CONFIG = new Config(); + + @Nonnull + private final Subspace subspace; + @Nonnull + private final Executor executor; + @Nonnull + private final Config config; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + /** + * Configuration settings for a {@link HNSW}. + */ + @SuppressWarnings("checkstyle:MemberName") + public static class Config { + @Nonnull + private final Random random; + @Nonnull + private final Metric metric; + private final boolean useInlining; + private final int m; + private final int mMax; + private final int mMax0; + private final int efConstruction; + private final boolean extendCandidates; + private final boolean keepPrunedConnections; + + protected Config() { + this.random = DEFAULT_RANDOM; + this.metric = DEFAULT_METRIC; + this.useInlining = DEFAULT_USE_INLINING; + this.m = DEFAULT_M; + this.mMax = DEFAULT_M_MAX; + this.mMax0 = DEFAULT_M_MAX_0; + this.efConstruction = DEFAULT_EF_CONSTRUCTION; + this.extendCandidates = DEFAULT_EXTEND_CANDIDATES; + this.keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; + } + + protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining, + final int m, final int mMax, final int mMax0, final int efConstruction, + final boolean extendCandidates, final boolean keepPrunedConnections) { + this.random = random; + this.metric = metric; + this.useInlining = useInlining; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efConstruction = efConstruction; + this.extendCandidates = extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + } + + @Nonnull + public Random getRandom() { + return random; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + public boolean isUseInlining() { + return useInlining; + } + + public int getM() { + return m; + } + + public int getMMax() { + return mMax; + } + + public int getMMax0() { + return mMax0; + } + + public int getEfConstruction() { + return efConstruction; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + public boolean 
isKeepPrunedConnections() {
+            return keepPrunedConnections;
+        }
+
+        @Nonnull
+        public ConfigBuilder toBuilder() {
+            return new ConfigBuilder(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(),
+                    getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections());
+        }
+
+        @Override
+        @Nonnull
+        public String toString() {
+            return "Config[metric=" + getMetric() + ", isUseInlining=" + isUseInlining() + ", M=" + getM() +
+                   ", MMax=" + getMMax() + ", MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() +
+                   ", isExtendCandidates=" + isExtendCandidates() +
+                   ", isKeepPrunedConnections=" + isKeepPrunedConnections() + "]";
+        }
+    }
+
+    /**
+     * Builder for {@link Config}.
+     *
+     * @see #newConfigBuilder
+     */
+    @CanIgnoreReturnValue
+    @SuppressWarnings("checkstyle:MemberName")
+    public static class ConfigBuilder {
+        @Nonnull
+        private Random random = DEFAULT_RANDOM;
+        @Nonnull
+        private Metric metric = DEFAULT_METRIC;
+        private boolean useInlining = DEFAULT_USE_INLINING;
+        private int m = DEFAULT_M;
+        private int mMax = DEFAULT_M_MAX;
+        private int mMax0 = DEFAULT_M_MAX_0;
+        private int efConstruction = DEFAULT_EF_CONSTRUCTION;
+        private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES;
+        private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS;
+
+        public ConfigBuilder() {
+        }
+
+        public ConfigBuilder(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining,
+                             final int m, final int mMax, final int mMax0, final int efConstruction,
+                             final boolean extendCandidates, final boolean keepPrunedConnections) {
+            this.random = random;
+            this.metric = metric;
+            this.useInlining = useInlining;
+            this.m = m;
+            this.mMax = mMax;
+            this.mMax0 = mMax0;
+            this.efConstruction = efConstruction;
+            this.extendCandidates = extendCandidates;
+            this.keepPrunedConnections = keepPrunedConnections;
+        }
+
+        @Nonnull
+        public Random getRandom() {
+            return random;
+        }
+
+        @Nonnull
+        public ConfigBuilder setRandom(@Nonnull final Random random) {
+            this.random = random;
+            return this;
+        }
+
+        @Nonnull
+        public Metric getMetric() {
+            return metric;
+        }
+
+        @Nonnull
+        public ConfigBuilder setMetric(@Nonnull final Metric metric) {
+            this.metric = metric;
+            return this;
+        }
+
+        public boolean isUseInlining() {
+            return useInlining;
+        }
+
+        public ConfigBuilder setUseInlining(final boolean useInlining) {
+            this.useInlining = useInlining;
+            return this;
+        }
+
+        public int getM() {
+            return m;
+        }
+
+        @Nonnull
+        public ConfigBuilder setM(final int m) {
+            this.m = m;
+            return this;
+        }
+
+        public int getMMax() {
+            return mMax;
+        }
+
+        @Nonnull
+        public ConfigBuilder setMMax(final int mMax) {
+            this.mMax = mMax;
+            return this;
+        }
+
+        public int getMMax0() {
+            return mMax0;
+        }
+
+        @Nonnull
+        public ConfigBuilder setMMax0(final int mMax0) {
+            this.mMax0 = mMax0;
+            return this;
+        }
+
+        public int getEfConstruction() {
+            return efConstruction;
+        }
+
+        public ConfigBuilder setEfConstruction(final int efConstruction) {
+            this.efConstruction = efConstruction;
+            return this;
+        }
+
+        public boolean isExtendCandidates() {
+            return extendCandidates;
+        }
+
+        public ConfigBuilder setExtendCandidates(final boolean extendCandidates) {
+            this.extendCandidates = extendCandidates;
+            return this;
+        }
+
+        public boolean isKeepPrunedConnections() {
+            return keepPrunedConnections;
+        }
+
+        public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnections) {
+            this.keepPrunedConnections = keepPrunedConnections;
+            return this;
+        }
+
+        public Config build() {
+            return
new Config(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + } + } + + /** + * Start building a {@link Config}. + * @return a new {@code Config} that can be altered and then built for use with a {@link HNSW} + * @see ConfigBuilder#build + */ + public static ConfigBuilder newConfigBuilder() { + return new ConfigBuilder(); + } + + /** + * Creates a new {@code HNSW} instance using the default configuration, write listener, and read listener. + *
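<p>
+ * For example, the following two instances behave identically (a sketch; {@code subspace} and
+ * {@code executor} are assumed to be supplied by the caller):
+ * <pre>{@code
+ * HNSW byDefault = new HNSW(subspace, executor);
+ * HNSW explicit = new HNSW(subspace, executor, HNSW.DEFAULT_CONFIG,
+ *         OnWriteListener.NOOP, OnReadListener.NOOP);
+ * }</pre>
+ *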

+ * This constructor delegates to the main constructor, providing default values for configuration + * and listeners, simplifying the instantiation process for common use cases. + * + * @param subspace the non-null {@link Subspace} to build the HNSW graph for. + * @param executor the non-null {@link Executor} for concurrent operations, such as building the graph. + */ + public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor) { + this(subspace, executor, DEFAULT_CONFIG, OnWriteListener.NOOP, OnReadListener.NOOP); + } + + /** + * Constructs a new HNSW graph instance. + *

+ * This constructor initializes the HNSW graph with the necessary components for storage,
+ * execution, configuration, and event handling. All parameters are mandatory and must not be null.
+ *
+ * @param subspace the {@link Subspace} where the graph data is stored.
+ * @param executor the {@link Executor} service to use for concurrent operations.
+ * @param config the {@link Config} object containing HNSW algorithm parameters.
+ * @param onWriteListener a listener to be notified of write events on the graph.
+ * @param onReadListener a listener to be notified of read events on the graph.
+ *
+ * @throws NullPointerException if any of the parameters are {@code null}.
+ */
+ public HNSW(@Nonnull final Subspace subspace,
+             @Nonnull final Executor executor, @Nonnull final Config config,
+             @Nonnull final OnWriteListener onWriteListener,
+             @Nonnull final OnReadListener onReadListener) {
+     this.subspace = subspace;
+     this.executor = executor;
+     this.config = config;
+     this.onWriteListener = onWriteListener;
+     this.onReadListener = onReadListener;
+ }
+
+ /**
+  * Gets the subspace associated with this object.
+  *
+  * @return the non-null subspace
+  */
+ @Nonnull
+ public Subspace getSubspace() {
+     return subspace;
+ }
+
+ /**
+  * Get the executor used by this HNSW.
+  * @return executor used when running asynchronous tasks
+  */
+ @Nonnull
+ public Executor getExecutor() {
+     return executor;
+ }
+
+ /**
+  * Get this HNSW's configuration.
+  * @return HNSW configuration
+  */
+ @Nonnull
+ public Config getConfig() {
+     return config;
+ }
+
+ /**
+  * Get the on-write listener.
+  * @return the on-write listener
+  */
+ @Nonnull
+ public OnWriteListener getOnWriteListener() {
+     return onWriteListener;
+ }
+
+ /**
+  * Get the on-read listener.
+  * @return the on-read listener
+  */
+ @Nonnull
+ public OnReadListener getOnReadListener() {
+     return onReadListener;
+ }
+
+ //
+ // Read Path
+ //
+
+ /**
+  * Performs a k-nearest neighbors (k-NN) search for a given query vector.
+  *
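<p>
+ * A typical invocation may look as follows (a sketch; {@code db}, {@code hnsw}, and {@code queryVector}
+ * are assumed to exist already):
+ * <pre>{@code
+ * final var neighbors =
+ *         db.readAsync(tx -> hnsw.kNearestNeighborsSearch(tx, 10, 100, queryVector)).join();
+ * if (neighbors != null) { // a null result signals an empty index
+ *     for (final var neighbor : neighbors) {
+ *         final var reference = neighbor.getNodeReferenceWithDistance();
+ *         System.out.println(reference.getPrimaryKey() + " at distance " + reference.getDistance());
+ *     }
+ * }
+ * }</pre>
+ *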

+ * This method implements the search algorithm for an HNSW graph. The search begins at an entry point in the + * highest layer and greedily traverses down through the layers. In each layer, it finds the node closest to the + * {@code queryVector}. This node then serves as the entry point for the search in the layer below. + *

+ * Once the search reaches the base layer (layer 0), it performs a more exhaustive search starting from the + * determined entry point. It explores the graph, maintaining a dynamic list of the best candidates found so far. + * The size of this candidate list is controlled by the {@code efSearch} parameter. Finally, the method selects + * the top {@code k} nodes from the search results, sorted by their distance to the query vector. + * + * @param readTransaction the transaction to use for reading from the database + * @param k the number of nearest neighbors to return + * @param efSearch the size of the dynamic candidate list for the search. A larger value increases accuracy + * at the cost of performance. + * @param queryVector the vector to find the nearest neighbors for + * + * @return a {@link CompletableFuture} that will complete with a list of the {@code k} nearest neighbors, + * sorted by distance in ascending order. The future completes with {@code null} if the index is empty. + */ + @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper + @Nonnull + public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, + final int k, + final int efSearch, + @Nonnull final Vector queryVector) { + return StorageAdapter.fetchEntryNodeReference(readTransaction, getSubspace(), getOnReadListener()) + .thenCompose(entryPointAndLayer -> { + if (entryPointAndLayer == null) { + return CompletableFuture.completedFuture(null); // not a single node in the index + } + + final Metric metric = getConfig().getMetric(); + + final NodeReferenceWithDistance entryState = + new NodeReferenceWithDistance(entryPointAndLayer.getPrimaryKey(), + entryPointAndLayer.getVector(), + Vector.comparativeDistance(metric, entryPointAndLayer.getVector(), queryVector)); + + final var entryLayer = entryPointAndLayer.getLayer(); + if (entryLayer == 0) { + // entry data points to a node in layer 0 directly + return CompletableFuture.completedFuture(entryState); + } + + return forLoop(entryLayer, entryState, + layer -> layer > 0, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final var storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, readTransaction, previousNodeReference, + layer, queryVector); + }, executor); + }).thenCompose(nodeReference -> { + if (nodeReference == null) { + return CompletableFuture.completedFuture(null); + } + + final var storageAdapter = getStorageAdapterForLayer(0); + + return searchLayer(storageAdapter, readTransaction, + ImmutableList.of(nodeReference), 0, efSearch, + Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> { + // reverse the original queue + final TreeMultimap> sortedTopK = + TreeMultimap.create(Comparator.naturalOrder(), + Comparator.comparing(nodeReferenceAndNode -> nodeReferenceAndNode.getNode().getPrimaryKey())); + + for (final NodeReferenceAndNode nodeReferenceAndNode : searchResult) { + if (sortedTopK.size() < k || sortedTopK.keySet().last() > + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance()) { + sortedTopK.put(nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance(), + nodeReferenceAndNode); + } + + if (sortedTopK.size() > k) { + final Double lastKey = sortedTopK.keySet().last(); + final NodeReferenceAndNode lastNode = sortedTopK.get(lastKey).last(); + sortedTopK.remove(lastKey, lastNode); + } + } + + return ImmutableList.copyOf(sortedTopK.values()); + }); + }); + } + + /** + * Performs a greedy search on a single layer of 
the HNSW graph. + *

+ * This method finds the node on the specified layer that is closest to the given query vector, + * starting the search from a designated entry point. The search is "greedy" because it aims to find + * only the single best neighbor. + *

+ * The implementation strategy depends on the {@link NodeKind} of the provided {@link StorageAdapter}. + * If the node kind is {@code INLINING}, it delegates to the specialized {@link #greedySearchInliningLayer} method. + * Otherwise, it uses the more general {@link #searchLayer} method with a search size (ef) of 1. + * The operation is asynchronous. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} for accessing the graph data + * @param readTransaction the {@link ReadTransaction} to use for the search + * @param entryNeighbor the starting point for the search on this layer, which includes the node and its distance to + * the query vector + * @param layer the zero-based index of the layer to search within + * @param queryVector the query vector for which to find the nearest neighbor + * + * @return a {@link CompletableFuture} that, upon completion, will contain the closest node found on the layer, + * represented as a {@link NodeReferenceWithDistance} + */ + @Nonnull + private CompletableFuture greedySearchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + if (storageAdapter.getNodeKind() == NodeKind.INLINING) { + return greedySearchInliningLayer(storageAdapter.asInliningStorageAdapter(), readTransaction, entryNeighbor, layer, queryVector); + } else { + return searchLayer(storageAdapter, readTransaction, ImmutableList.of(entryNeighbor), layer, 1, Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> Iterables.getOnlyElement(searchResult).getNodeReferenceWithDistance()); + } + } + + /** + * Performs a greedy search for the nearest neighbor to a query vector within a single, non-zero layer of the HNSW + * graph. + *

+ * This search is performed on layers that use {@code InliningNode}s, where neighbor vectors are stored directly + * within the node. + * The search starts from a given {@code entryNeighbor} and iteratively moves to the closest neighbor in the current + * node's + * neighbor list, until no closer neighbor can be found. + *
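<p>
+ * The descent is equivalent to the following synchronous sketch, in which {@code fetch} and
+ * {@code distance} are shorthand for the storage-adapter fetch and the metric computation, not methods
+ * of this class:
+ * <pre>{@code
+ * NodeReferenceWithDistance current = entryNeighbor;
+ * while (true) {
+ *     InliningNode node = fetch(layer, current.getPrimaryKey()).asInliningNode();
+ *     NodeReferenceWithVector closest = null;
+ *     double minDistance = current.getDistance();
+ *     for (NodeReferenceWithVector neighbor : node.getNeighbors()) {
+ *         double d = distance(neighbor.getVector(), queryVector);
+ *         if (d < minDistance) {
+ *             minDistance = d;
+ *             closest = neighbor;
+ *         }
+ *     }
+ *     if (closest == null) {
+ *         break; // local minimum: no neighbor is closer to the query
+ *     }
+ *     current = new NodeReferenceWithDistance(closest.getPrimaryKey(), closest.getVector(), minDistance);
+ * }
+ * // current now holds the nearest node found on this layer
+ * }</pre>
+ *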

+ * The entire process is asynchronous, returning a {@link CompletableFuture} that will complete with the best node + * found in this layer. + * + * @param storageAdapter the storage adapter to fetch nodes from the graph + * @param readTransaction the transaction context for database reads + * @param entryNeighbor the entry point for the search in this layer, typically the result from a search in a higher + * layer + * @param layer the layer number to perform the search in. Must be greater than 0. + * @param queryVector the vector for which to find the nearest neighbor + * + * @return a {@link CompletableFuture} that, upon completion, will hold the {@link NodeReferenceWithDistance} of the nearest + * neighbor found in this layer's greedy search + * + * @throws IllegalStateException if a node that is expected to exist cannot be fetched from the + * {@code storageAdapter} during the search + */ + @Nonnull + private CompletableFuture greedySearchInliningLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + Verify.verify(layer > 0); + final Metric metric = getConfig().getMetric(); + final AtomicReference currentNodeReferenceAtomic = + new AtomicReference<>(entryNeighbor); + + return AsyncUtil.whileTrue(() -> onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, currentNodeReferenceAtomic.get().getPrimaryKey())) + .thenApply(node -> { + if (node == null) { + throw new IllegalStateException("unable to fetch node"); + } + final InliningNode inliningNode = node.asInliningNode(); + final List neighbors = inliningNode.getNeighbors(); + + final NodeReferenceWithDistance currentNodeReference = currentNodeReferenceAtomic.get(); + double minDistance = currentNodeReference.getDistance(); + + NodeReferenceWithVector nearestNeighbor = null; + for (final NodeReferenceWithVector neighbor : neighbors) { + final double distance = + Vector.comparativeDistance(metric, neighbor.getVector(), queryVector); + if (distance < minDistance) { + minDistance = distance; + nearestNeighbor = neighbor; + } + } + + if (nearestNeighbor == null) { + return false; + } + + currentNodeReferenceAtomic.set( + new NodeReferenceWithDistance(nearestNeighbor.getPrimaryKey(), nearestNeighbor.getVector(), + minDistance)); + return true; + }), executor).thenApply(ignored -> currentNodeReferenceAtomic.get()); + } + + /** + * Searches a single layer of the graph to find the nearest neighbors to a query vector. + *

+ * This method implements the greedy search algorithm used in HNSW (Hierarchical Navigable Small World) + * graphs for a specific layer. It begins with a set of entry points and iteratively explores the graph, + * always moving towards nodes that are closer to the {@code queryVector}. + *
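<p>
+ * Stripped of the asynchrony and the node cache, one pass of the loop corresponds to this synchronous
+ * sketch; {@code neighborsOf} and {@code distance} abbreviate the storage fetch and the metric
+ * computation:
+ * <pre>{@code
+ * while (!candidates.isEmpty()) {
+ *     NodeReferenceWithDistance candidate = candidates.poll(); // closest unexpanded candidate
+ *     if (candidate.getDistance() > nearestNeighbors.peek().getDistance()) {
+ *         break; // no remaining candidate can improve the result set
+ *     }
+ *     for (NodeReferenceWithVector neighbor : neighborsOf(candidate)) {
+ *         if (visited.add(neighbor.getPrimaryKey())) {
+ *             double d = distance(neighbor.getVector(), queryVector);
+ *             if (d < nearestNeighbors.peek().getDistance() || nearestNeighbors.size() < efSearch) {
+ *                 NodeReferenceWithDistance discovered =
+ *                         new NodeReferenceWithDistance(neighbor.getPrimaryKey(), neighbor.getVector(), d);
+ *                 candidates.add(discovered);
+ *                 nearestNeighbors.add(discovered);
+ *                 if (nearestNeighbors.size() > efSearch) {
+ *                     nearestNeighbors.poll(); // evict the furthest of the kept results
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * }</pre>
+ *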

+ * It maintains a priority queue of candidates to visit and a result set of the nearest neighbors found so far. + * The size of the dynamic candidate list is controlled by the {@code efSearch} parameter, which balances + * search quality and performance. The entire process is asynchronous, leveraging + * {@link java.util.concurrent.CompletableFuture} + * to handle I/O operations (fetching nodes) without blocking. + * + * @param The type of the node reference, extending {@link NodeReference}. + * @param storageAdapter The storage adapter for accessing node data from the underlying storage. + * @param readTransaction The transaction context for all database read operations. + * @param entryNeighbors A collection of starting nodes for the search in this layer, with their distances + * to the query vector already calculated. + * @param layer The zero-based index of the layer to search. + * @param efSearch The size of the dynamic candidate list. A larger value increases recall at the + * cost of performance. + * @param nodeCache A cache of nodes that have already been fetched from storage to avoid redundant I/O. + * @param queryVector The vector for which to find the nearest neighbors. + * + * @return A {@link java.util.concurrent.CompletableFuture} that, upon completion, will contain a list of the + * best candidate nodes found in this layer, paired with their full node data. + */ + @Nonnull + private CompletableFuture>> searchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Collection entryNeighbors, + final int layer, + final int efSearch, + @Nonnull final Map> nodeCache, + @Nonnull final Vector queryVector) { + final Set visited = Sets.newConcurrentHashSet(NodeReference.primaryKeys(entryNeighbors)); + final Queue candidates = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(entryNeighbors); + final Queue nearestNeighbors = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance).reversed()); + nearestNeighbors.addAll(entryNeighbors); + final Metric metric = getConfig().getMetric(); + + return AsyncUtil.whileTrue(() -> { + if (candidates.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + + final NodeReferenceWithDistance candidate = candidates.poll(); + final NodeReferenceWithDistance furthestNeighbor = Objects.requireNonNull(nearestNeighbors.peek()); + + if (candidate.getDistance() > furthestNeighbor.getDistance()) { + return AsyncUtil.READY_FALSE; + } + + return fetchNodeIfNotCached(storageAdapter, readTransaction, layer, candidate, nodeCache) + .thenApply(candidateNode -> + Iterables.filter(candidateNode.getNeighbors(), + neighbor -> !visited.contains(neighbor.getPrimaryKey()))) + .thenCompose(neighborReferences -> fetchNeighborhood(storageAdapter, readTransaction, + layer, neighborReferences, nodeCache)) + .thenApply(neighborReferences -> { + for (final NodeReferenceWithVector current : neighborReferences) { + visited.add(current.getPrimaryKey()); + final double furthestDistance = + Objects.requireNonNull(nearestNeighbors.peek()).getDistance(); + + final double currentDistance = + Vector.comparativeDistance(metric, current.getVector(), queryVector); + if (currentDistance < furthestDistance || nearestNeighbors.size() < efSearch) { + final NodeReferenceWithDistance currentWithDistance = + new NodeReferenceWithDistance(current.getPrimaryKey(), current.getVector(), + currentDistance); + 
candidates.add(currentWithDistance);
+ nearestNeighbors.add(currentWithDistance);
+ if (nearestNeighbors.size() > efSearch) {
+     nearestNeighbors.poll();
+ }
+ }
+ }
+ return true;
+ });
+ }).thenCompose(ignored ->
+         fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache))
+         .thenApply(searchResult -> {
+             if (logger.isDebugEnabled()) {
+                 logger.debug("searched layer={} for efSearch={} with result={}", layer, efSearch,
+                         searchResult.stream()
+                                 .map(nodeReferenceAndNode ->
+                                         "(primaryKey=" +
+                                         nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() +
+                                         ",distance=" +
+                                         nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")")
+                                 .collect(Collectors.joining(",")));
+             }
+             return searchResult;
+         });
+ }
+
+ /**
+  * Asynchronously fetches a node if it is not already present in the cache.
+  *

+ * This method first attempts to retrieve the node from the provided {@code nodeCache} using the + * primary key of the {@code nodeReference}. If the node is not found in the cache, it is + * fetched from the underlying storage using the {@code storageAdapter}. Once fetched, the node + * is added to the {@code nodeCache} before the future is completed. + *

+ * This is a convenience method that delegates to + * {@link #fetchNodeIfNecessaryAndApply(StorageAdapter, ReadTransaction, int, NodeReference, + * java.util.function.Function, java.util.function.BiFunction)}. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param storageAdapter the storage adapter used to fetch the node from persistent storage + * @param readTransaction the transaction to use for reading from storage + * @param layer the layer index where the node is located + * @param nodeReference the reference to the node to fetch + * @param nodeCache the cache to check for the node and to which the node will be added if fetched + * + * @return a {@link CompletableFuture} that will be completed with the fetched or cached {@link Node} + */ + @Nonnull + private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final NodeReference nodeReference, + @Nonnull final Map> nodeCache) { + return fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, nodeReference, + nR -> nodeCache.get(nR.getPrimaryKey()), + (nR, node) -> { + nodeCache.put(nR.getPrimaryKey(), node); + return node; + }); + } + + /** + * Conditionally fetches a node from storage and applies a function to it. + *

+ * This method first attempts to generate a result by applying the {@code fetchBypassFunction}. + * If this function returns a non-null value, that value is returned immediately in a + * completed {@link CompletableFuture}, and no storage access occurs. This provides an + * optimization path, for example, if the required data is already available in a cache. + *
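<p>
+ * For example, {@link #fetchNodeIfNotCached} is implemented in terms of this method as a cache-aside
+ * lookup, shown here slightly abbreviated:
+ * <pre>{@code
+ * fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, nodeReference,
+ *         nR -> nodeCache.get(nR.getPrimaryKey()),     // bypass: a cache hit avoids the storage read
+ *         (nR, node) -> {
+ *             nodeCache.put(nR.getPrimaryKey(), node); // populate the cache on a miss
+ *             return node;
+ *         });
+ * }</pre>
+ *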

+ * If the bypass function returns {@code null}, the method proceeds to asynchronously fetch the + * node from the given {@code StorageAdapter}. Once the node is retrieved, the + * {@code biMapFunction} is applied to the original {@code nodeReference} and the fetched + * {@code Node} to produce the final result. + * + * @param The type of the input node reference. + * @param The type of the node reference used by the storage adapter. + * @param The type of the result. + * @param storageAdapter The storage adapter used to fetch the node if necessary. + * @param readTransaction The read transaction context for the storage operation. + * @param layer The layer index from which to fetch the node. + * @param nodeReference The reference to the node that may need to be fetched. + * @param fetchBypassFunction A function that provides a potential shortcut. If it returns a + * non-null value, the node fetch is bypassed. + * @param biMapFunction A function to be applied after a successful node fetch, combining the + * original reference and the fetched node to produce the final result. + * + * @return A {@link CompletableFuture} that will complete with the result from either the + * {@code fetchBypassFunction} or the {@code biMapFunction}. + */ + @Nonnull + private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final R nodeReference, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + final U bypass = fetchBypassFunction.apply(nodeReference); + if (bypass != null) { + return CompletableFuture.completedFuture(bypass); + } + + return onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, nodeReference.getPrimaryKey())) + .thenApply(node -> biMapFunction.apply(nodeReference, node)); + } + + /** + * Asynchronously fetches neighborhood nodes and returns them as {@link NodeReferenceWithVector} instances, + * which include the node's vector. + *

+ * This method efficiently retrieves node data by first checking an in-memory {@code nodeCache}. If a node is not + * in the cache, it is fetched from the {@link StorageAdapter}. Fetched nodes are then added to the cache to + * optimize subsequent lookups. It also handles cases where the input {@code neighborReferences} may already + * contain {@link NodeReferenceWithVector} instances, avoiding redundant work. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes from if they are not in the cache + * @param readTransaction the transaction context for database read operations + * @param layer the graph layer from which to fetch the nodes + * @param neighborReferences an iterable of references to the neighbor nodes to be fetched + * @param nodeCache a map serving as an in-memory cache for nodes. This map will be populated with any + * nodes fetched from storage. + * + * @return a {@link CompletableFuture} that, upon completion, will contain a list of + * {@link NodeReferenceWithVector} objects for the specified neighbors + */ + @Nonnull + private CompletableFuture> fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable neighborReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, neighborReferences, + neighborReference -> { + if (neighborReference instanceof NodeReferenceWithVector) { + return (NodeReferenceWithVector)neighborReference; + } + final Node neighborNode = nodeCache.get(neighborReference.getPrimaryKey()); + if (neighborNode == null) { + return null; + } + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }, + (neighborReference, neighborNode) -> { + nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }); + } + + /** + * Fetches a collection of nodes, attempting to retrieve them from a cache first before + * accessing the underlying storage. + *

+ * This method iterates through the provided {@code nodeReferences}. For each reference, it + * first checks the {@code nodeCache}. If the corresponding {@link Node} is found, it is + * used directly. If not, the node is fetched from the {@link StorageAdapter}. Any nodes + * fetched from storage are then added to the {@code nodeCache} to optimize subsequent lookups. + * The entire operation is performed asynchronously. + * + * @param The type of the node reference, which must extend {@link NodeReference}. + * @param storageAdapter The storage adapter used to fetch nodes from storage if they are not in the cache. + * @param readTransaction The transaction context for the read operation. + * @param layer The layer from which to fetch the nodes. + * @param nodeReferences An {@link Iterable} of {@link NodeReferenceWithDistance} objects identifying the nodes to + * be fetched. + * @param nodeCache A map used as a cache. It is checked for existing nodes and updated with any newly fetched + * nodes. + * + * @return A {@link CompletableFuture} which will complete with a {@link List} of + * {@link NodeReferenceAndNode} objects, pairing each requested reference with its corresponding node. + */ + @Nonnull + private CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, nodeReferences, + nodeReference -> { + final Node node = nodeCache.get(nodeReference.getPrimaryKey()); + if (node == null) { + return null; + } + return new NodeReferenceAndNode<>(nodeReference, node); + }, + (nodeReferenceWithDistance, node) -> { + nodeCache.put(nodeReferenceWithDistance.getPrimaryKey(), node); + return new NodeReferenceAndNode<>(nodeReferenceWithDistance, node); + }); + } + + /** + * Asynchronously fetches a collection of nodes from storage and applies a function to each. + *

+ * For each {@link NodeReference} in the provided iterable, this method concurrently fetches the corresponding + * {@code Node} using the given {@link StorageAdapter}. The logic delegates to + * {@code fetchNodeIfNecessaryAndApply}, which determines whether a full node fetch is required. + * If a node is fetched from storage, the {@code biMapFunction} is applied. If the fetch is bypassed + * (e.g., because the reference itself contains sufficient information), the {@code fetchBypassFunction} is used + * instead. + * + * @param The type of the node references to be processed, extending {@link NodeReference}. + * @param The type of the key references within the nodes, extending {@link NodeReference}. + * @param The type of the result after applying one of the mapping functions. + * @param storageAdapter The {@link StorageAdapter} used to fetch nodes from the underlying storage. + * @param readTransaction The {@link ReadTransaction} context for the read operations. + * @param layer The layer index from which the nodes are being fetched. + * @param nodeReferences An {@link Iterable} of {@link NodeReference}s for the nodes to be fetched and processed. + * @param fetchBypassFunction The function to apply to a node reference when the actual node fetch is bypassed, + * mapping the reference directly to a result of type {@code U}. + * @param biMapFunction The function to apply when a node is successfully fetched, mapping the original + * reference and the fetched {@link Node} to a result of type {@code U}. + * + * @return A {@link CompletableFuture} that, upon completion, will hold a {@link java.util.List} of results + * of type {@code U}, corresponding to each processed node reference. + */ + @Nonnull + private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + return forEach(nodeReferences, + currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, + currentNeighborReference, fetchBypassFunction, biMapFunction), MAX_CONCURRENT_NODE_READS, + getExecutor()); + } + + /** + * Asynchronously inserts a node reference and its corresponding vector into the index. + *

+ * This is a convenience method that extracts the primary key and vector from the + * provided {@link NodeReferenceWithVector} and delegates to the + * {@link #insert(Transaction, Tuple, Vector)} method. + * + * @param transaction the transaction context for the operation. Must not be {@code null}. + * @param nodeReferenceWithVector a container object holding the primary key of the node + * and its vector representation. Must not be {@code null}. + * + * @return a {@link CompletableFuture} that will complete when the insertion operation is finished. + */ + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final NodeReferenceWithVector nodeReferenceWithVector) { + return insert(transaction, nodeReferenceWithVector.getPrimaryKey(), nodeReferenceWithVector.getVector()); + } + + /** + * Inserts a new vector with its associated primary key into the HNSW graph. + *

+ * The method first determines a random layer for the new node, called the {@code insertionLayer}. + * It then traverses the graph from the entry point downwards, greedily searching for the nearest + * neighbors to the {@code newVector} at each layer. This search identifies the optimal + * connection points for the new node. + *
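<p>
+ * In the HNSW paper, the insertion layer is drawn from an exponentially decaying distribution. A typical
+ * way to implement such a draw is sketched below; it illustrates the intent and is not necessarily the
+ * exact code used here:
+ * <pre>{@code
+ * // normalization factor mL = 1 / ln(M); higher layers become exponentially less likely
+ * double lambda = 1.0 / Math.log(config.getM());
+ * // 1.0 - nextDouble() lies in (0.0, 1.0], which keeps Math.log away from zero
+ * int insertionLayer = (int) Math.floor(-Math.log(1.0 - random.nextDouble()) * lambda);
+ * }</pre>
+ *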

+ * Once the nearest neighbors are found, the new node is linked into the graph structure at all + * layers up to its {@code insertionLayer}. Special handling is included for inserting the + * first-ever node into the graph or when a new node's layer is higher than any existing node, + * which updates the graph's entry point. All operations are performed asynchronously. + * + * @param transaction the {@link Transaction} context for all database operations + * @param newPrimaryKey the unique {@link Tuple} primary key for the new node being inserted + * @param newVector the {@link Vector} data to be inserted into the graph + * + * @return a {@link CompletableFuture} that completes when the insertion operation is finished + */ + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + final Metric metric = getConfig().getMetric(); + + final int insertionLayer = insertionLayer(getConfig().getRandom()); + if (logger.isDebugEnabled()) { + logger.debug("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer); + } + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenApply(entryNodeReference -> { + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, -1); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer); + } + } else { + final int lMax = entryNodeReference.getLayer(); + if (insertionLayer > lMax) { + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, lMax); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer); + } + } + } + return entryNodeReference; + }).thenCompose(entryNodeReference -> { + if (entryNodeReference == null) { + return AsyncUtil.DONE; + } + + final int lMax = entryNodeReference.getLayer(); + if (logger.isDebugEnabled()) { + logger.debug("entry node with key {} at layer {}", entryNodeReference.getPrimaryKey(), + lMax); + } + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), newVector)); + return forLoop(lMax, initialNodeReference, + layer -> layer > insertionLayer, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, newVector); + }, executor) + .thenCompose(nodeReference -> + insertIntoLayers(transaction, newPrimaryKey, newVector, nodeReference, + lMax, insertionLayer)); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + /** + * Inserts a batch of nodes into the HNSW graph asynchronously. + * + *

+ * This method orchestrates the batch insertion of nodes into the HNSW graph structure.
+ * For each node in the input {@code batch}, it first assigns a random layer based on the configured
+ * probability distribution. The batch is then sorted in descending order of these assigned layers to
+ * ensure higher-layer nodes are processed first, which can optimize subsequent insertions by providing
+ * better entry points.
+ * <p>
+ * The insertion logic proceeds in two main asynchronous stages:
+ * <ol>
+ *     <li>Search Phase: For each node to be inserted, the method concurrently performs a greedy search
+ *     from the graph's main entry point down to the node's target layer. This identifies the nearest
+ *     neighbors at each level, which will serve as entry points for the insertion phase.</li>
+ *     <li>Insertion Phase: The method then iterates through the nodes and inserts each one into the
+ *     graph from its target layer downwards, connecting it to its nearest neighbors. If a node's
+ *     assigned layer is higher than the current maximum layer of the graph, it becomes the new main
+ *     entry point.</li>
+ * </ol>
+ * All underlying storage operations are performed within the context of the provided {@link Transaction}.
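+ * <p>
+ * A typical invocation is sketched below; the batch contents are illustrative, and a
+ * {@code NodeReferenceWithVector} constructor taking a primary key and a vector is assumed:
+ * <pre>{@code
+ * List<NodeReferenceWithVector> batch = List.of(
+ *         new NodeReferenceWithVector(Tuple.from("pk1"), vector1),
+ *         new NodeReferenceWithVector(Tuple.from("pk2"), vector2));
+ * db.runAsync(tx -> hnsw.insertBatch(tx, batch)).join();
+ * }</pre>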

+ * + * @param transaction the transaction to use for all storage operations; must not be {@code null} + * @param batch a {@code List} of {@link NodeReferenceWithVector} objects to insert; must not be {@code null} + * + * @return a {@link CompletableFuture} that completes with {@code null} when the entire batch has been inserted + */ + @Nonnull + public CompletableFuture insertBatch(@Nonnull final Transaction transaction, + @Nonnull List batch) { + final Metric metric = getConfig().getMetric(); + + // determine the layer each item should be inserted at + final Random random = getConfig().getRandom(); + final List batchWithLayers = Lists.newArrayListWithCapacity(batch.size()); + for (final NodeReferenceWithVector current : batch) { + batchWithLayers.add(new NodeReferenceWithLayer(current.getPrimaryKey(), current.getVector(), + insertionLayer(random))); + } + // sort the layers in reverse order + batchWithLayers.sort(Comparator.comparing(NodeReferenceWithLayer::getLayer).reversed()); + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenCompose(entryNodeReference -> { + final int lMax = entryNodeReference == null ? -1 : entryNodeReference.getLayer(); + + return forEach(batchWithLayers, + item -> { + if (lMax == -1) { + return CompletableFuture.completedFuture(null); + } + + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), itemVector)); + + return forLoop(lMax, initialNodeReference, + layer -> layer > itemL, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, itemVector); + }, executor); + }, MAX_CONCURRENT_SEARCHES, getExecutor()) + .thenCompose(searchEntryReferences -> + forLoop(0, entryNodeReference, + index -> index < batchWithLayers.size(), + index -> index + 1, + (index, currentEntryNodeReference) -> { + final NodeReferenceWithLayer item = batchWithLayers.get(index); + final Tuple itemPrimaryKey = item.getPrimaryKey(); + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final EntryNodeReference newEntryNodeReference; + final int currentLMax; + + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, -1); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL); + } + + return CompletableFuture.completedFuture(newEntryNodeReference); + } else { + currentLMax = currentEntryNodeReference.getLayer(); + if (itemL > currentLMax) { + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, lMax); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL); + } + } else 
{ + newEntryNodeReference = entryNodeReference; + } + } + + if (logger.isDebugEnabled()) { + logger.debug("entry node with key {} at layer {}", + currentEntryNodeReference.getPrimaryKey(), currentLMax); + } + + final var currentSearchEntry = + searchEntryReferences.get(index); + + return insertIntoLayers(transaction, itemPrimaryKey, itemVector, currentSearchEntry, + lMax, itemL).thenApply(ignored -> newEntryNodeReference); + }, getExecutor())); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + /** + * Inserts a new vector into the HNSW graph across multiple layers, starting from a given entry point. + *

+ * This method implements the second phase of the HNSW insertion algorithm. It begins at a starting layer, which is + * the minimum of the graph's maximum layer ({@code lMax}) and the new node's randomly assigned + * {@code insertionLayer}. It then iterates downwards to layer 0. In each layer, it invokes + * {@link #insertIntoLayer(StorageAdapter, Transaction, List, int, Tuple, Vector)} to perform the search and + * connect the new node. The set of nearest neighbors found at layer {@code L} serves as the entry points for the + * search at layer {@code L-1}. + *

+ * + * @param transaction the transaction to use for database operations + * @param newPrimaryKey the primary key of the new node being inserted + * @param newVector the vector data of the new node + * @param nodeReference the initial entry point for the search, typically the nearest neighbor found in the highest + * layer + * @param lMax the maximum layer number in the HNSW graph + * @param insertionLayer the randomly determined layer for the new node. The node will be inserted into all layers + * from this layer down to 0. + * + * @return a {@link CompletableFuture} that completes when the new node has been successfully inserted into all + * its designated layers + */ + @Nonnull + private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector, + @Nonnull final NodeReferenceWithDistance nodeReference, + final int lMax, + final int insertionLayer) { + if (logger.isDebugEnabled()) { + logger.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey()); + } + return MoreAsyncUtil.>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference), + layer -> layer >= 0, + layer -> layer - 1, + (layer, previousNodeReferences) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return insertIntoLayer(storageAdapter, transaction, + previousNodeReferences, layer, newPrimaryKey, newVector); + }, executor).thenCompose(ignored -> AsyncUtil.DONE); + } + + /** + * Inserts a new node into a specified layer of the HNSW graph. + *

+ * This method orchestrates the complete insertion process for a single layer. It begins by performing a search + * within the given layer, starting from the provided {@code nearestNeighbors} as entry points, to find a set of + * candidate neighbors for the new node. From this candidate set, it selects the best connections based on the + * graph's parameters (M). + *

+ *

+ * After selecting the neighbors, it creates the new node and links it to them. It then reciprocally updates + * the selected neighbors to link back to the new node. If adding this new link causes a neighbor to exceed its + * maximum allowed connections, its connections are pruned. All changes, including the new node and the updated + * neighbors, are persisted to storage within the given transaction. + *

+ *

+ * The operation is asynchronous and returns a {@link CompletableFuture}. The future completes with the list of + * nodes found during the initial search phase, which are then used as the entry points for insertion into the + * next lower layer. + *

+ * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter for reading from and writing to the graph + * @param transaction the transaction context for the database operations + * @param nearestNeighbors the list of nearest neighbors from the layer above, used as entry points for the search + * in this layer + * @param layer the layer number to insert the new node into + * @param newPrimaryKey the primary key of the new node to be inserted + * @param newVector the vector associated with the new node + * + * @return a {@code CompletableFuture} that completes with a list of the nearest neighbors found during the + * initial search phase. This list serves as the entry point for insertion into the next lower layer + * (i.e., {@code layer - 1}). + */ + @Nonnull + private CompletableFuture> insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final List nearestNeighbors, + int layer, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + if (logger.isDebugEnabled()) { + logger.debug("begin insert key={} at layer={}", newPrimaryKey, layer); + } + final Map> nodeCache = Maps.newConcurrentMap(); + + return searchLayer(storageAdapter, transaction, + nearestNeighbors, layer, config.getEfConstruction(), nodeCache, newVector) + .thenCompose(searchResult -> { + final List references = NodeReferenceAndNode.getReferences(searchResult); + + return selectNeighbors(storageAdapter, transaction, searchResult, layer, getConfig().getM(), + getConfig().isExtendCandidates(), nodeCache, newVector) + .thenCompose(selectedNeighbors -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node newNode = + nodeFactory.create(newPrimaryKey, newVector, + NodeReferenceAndNode.getReferences(selectedNeighbors)); + + final NeighborsChangeSet newNodeChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + newNode.getNeighbors()); + + storageAdapter.writeNode(transaction, newNode, layer, newNodeChangeSet); + + // create change sets for each selected neighbor and insert new node into them + final Map> neighborChangeSetMap = + Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { + final NeighborsChangeSet baseSet = + new BaseNeighborsChangeSet<>(selectedNeighbor.getNode().getNeighbors()); + final NeighborsChangeSet insertSet = + new InsertNeighborsChangeSet<>(baseSet, ImmutableList.of(newNode.getSelfReference(newVector))); + neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), + insertSet); + } + + final int currentMMax = layer == 0 ? 
getConfig().getMMax0() : getConfig().getMMax(); + return forEach(selectedNeighbors, + selectedNeighbor -> { + final Node selectedNeighborNode = selectedNeighbor.getNode(); + final NeighborsChangeSet changeSet = + Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + selectedNeighbor, layer, currentMMax, changeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return changeSet; + } + return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); + }); + }, MAX_CONCURRENT_NEIGHBOR_FETCHES, getExecutor()) + .thenApply(changeSets -> { + for (int i = 0; i < selectedNeighbors.size(); i++) { + final NodeReferenceAndNode selectedNeighbor = selectedNeighbors.get(i); + final NeighborsChangeSet changeSet = changeSets.get(i); + storageAdapter.writeNode(transaction, selectedNeighbor.getNode(), + layer, changeSet); + } + return ImmutableList.copyOf(references); + }); + }); + }).thenApply(nodeReferencesWithDistances -> { + if (logger.isDebugEnabled()) { + logger.debug("end insert key={} at layer={}", newPrimaryKey, layer); + } + return nodeReferencesWithDistances; + }); + } + + /** + * Calculates the delta between a current set of neighbors and a new set, producing a + * {@link NeighborsChangeSet} that represents the required insertions and deletions. + *
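<p>
+ * In set terms: with {@code before} the merged neighbors of the {@code beforeChangeSet} and
+ * {@code after} the desired neighbors, both keyed by primary key, the result wraps the base change set
+ * in a {@link DeleteNeighborsChangeSet} for {@code before \ after} and an
+ * {@link InsertNeighborsChangeSet} for {@code after \ before}, each layer only being added when its set
+ * is non-empty.
+ *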

+ * This method compares the neighbors present in the initial {@code beforeChangeSet} with + * the provided {@code afterNeighbors}. It identifies which neighbors from the "before" state + * are missing in the "after" state (to be deleted) and which new neighbors are present in the + * "after" state but not in the "before" state (to be inserted). It then constructs a new + * {@code NeighborsChangeSet} by wrapping the original one with {@link DeleteNeighborsChangeSet} + * and {@link InsertNeighborsChangeSet} as needed. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param beforeChangeSet the change set representing the state of neighbors before the update. + * This is used as the base for calculating changes. Must not be null. + * @param afterNeighbors an iterable collection of the desired neighbors after the update. + * Must not be null. + * + * @return a new {@code NeighborsChangeSet} that includes the necessary deletion and insertion + * operations to transform the neighbors from the "before" state to the "after" state. + */ + private NeighborsChangeSet resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, + @Nonnull final Iterable> afterNeighbors) { + final Map beforeNeighborsMap = Maps.newLinkedHashMap(); + for (final N n : beforeChangeSet.merge()) { + beforeNeighborsMap.put(n.getPrimaryKey(), n); + } + + final Map afterNeighborsMap = Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + + afterNeighborsMap.put(nodeReferenceWithDistance.getPrimaryKey(), + nodeReferenceAndNode.getNode().getSelfReference(nodeReferenceWithDistance.getVector())); + } + + final ImmutableList.Builder toBeDeletedBuilder = ImmutableList.builder(); + for (final Map.Entry beforeNeighborEntry : beforeNeighborsMap.entrySet()) { + if (!afterNeighborsMap.containsKey(beforeNeighborEntry.getKey())) { + toBeDeletedBuilder.add(beforeNeighborEntry.getValue().getPrimaryKey()); + } + } + final List toBeDeleted = toBeDeletedBuilder.build(); + + final ImmutableList.Builder toBeInsertedBuilder = ImmutableList.builder(); + for (final Map.Entry afterNeighborEntry : afterNeighborsMap.entrySet()) { + if (!beforeNeighborsMap.containsKey(afterNeighborEntry.getKey())) { + toBeInsertedBuilder.add(afterNeighborEntry.getValue()); + } + } + final List toBeInserted = toBeInsertedBuilder.build(); + + NeighborsChangeSet changeSet = beforeChangeSet; + + if (!toBeDeleted.isEmpty()) { + changeSet = new DeleteNeighborsChangeSet<>(changeSet, toBeDeleted); + } + if (!toBeInserted.isEmpty()) { + changeSet = new InsertNeighborsChangeSet<>(changeSet, toBeInserted); + } + return changeSet; + } + + /** + * Prunes the neighborhood of a given node if its number of connections exceeds the maximum allowed ({@code mMax}). + *

+ * This is a maintenance operation for the HNSW graph. When new nodes are added, an existing node's neighborhood + * might temporarily grow beyond its limit. This method identifies such cases and trims the neighborhood back down + * to the {@code mMax} best connections, based on the configured distance metric. If the neighborhood size is + * already within the limit, this method does nothing. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes from the database + * @param transaction the transaction context for database operations + * @param selectedNeighbor the node whose neighborhood is being considered for pruning + * @param layer the graph layer on which the operation is performed + * @param mMax the maximum number of neighbors a node is allowed to have on this layer + * @param neighborChangeSet a set of pending changes to the neighborhood that must be included in the pruning + * calculation + * @param nodeCache a cache of nodes to avoid redundant database fetches + * + * @return a {@link CompletableFuture} which completes with a list of the newly selected neighbors for the pruned node. + * If no pruning was necessary, it completes with {@code null}. + */ + @Nonnull + private CompletableFuture>> pruneNeighborsIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final NodeReferenceAndNode selectedNeighbor, + int layer, + int mMax, + @Nonnull final NeighborsChangeSet neighborChangeSet, + @Nonnull final Map> nodeCache) { + final Metric metric = getConfig().getMetric(); + final Node selectedNeighborNode = selectedNeighbor.getNode(); + if (selectedNeighborNode.getNeighbors().size() < mMax) { + return CompletableFuture.completedFuture(null); + } else { + if (logger.isDebugEnabled()) { + logger.debug("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}", + selectedNeighborNode.getPrimaryKey(), selectedNeighborNode.getNeighbors().size(), mMax); + } + return fetchNeighborhood(storageAdapter, transaction, layer, neighborChangeSet.merge(), nodeCache) + .thenCompose(nodeReferenceWithVectors -> { + final ImmutableList.Builder nodeReferencesWithDistancesBuilder = + ImmutableList.builder(); + for (final NodeReferenceWithVector nodeReferenceWithVector : nodeReferenceWithVectors) { + final var vector = nodeReferenceWithVector.getVector(); + final double distance = + Vector.comparativeDistance(metric, vector, + selectedNeighbor.getNodeReferenceWithDistance().getVector()); + nodeReferencesWithDistancesBuilder.add( + new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(), + vector, distance)); + } + return fetchSomeNodesIfNotCached(storageAdapter, transaction, layer, + nodeReferencesWithDistancesBuilder.build(), nodeCache); + }) + .thenCompose(nodeReferencesAndNodes -> + selectNeighbors(storageAdapter, transaction, + nodeReferencesAndNodes, layer, + mMax, false, nodeCache, + selectedNeighbor.getNodeReferenceWithDistance().getVector())); + } + } + + /** + * Selects the {@code m} best neighbors for a new node from a set of candidates using the HNSW selection heuristic. + *
+ * This method implements the core logic for neighbor selection within a layer of the HNSW graph. It starts with an + * initial set of candidates ({@code nearestNeighbors}), which can be optionally extended by fetching their own + * neighbors. + * It then iteratively refines this set using a greedy best-first search. + *
+ * The selection heuristic ensures diversity among neighbors. A candidate is added to the result set only if it is + * closer to the query {@code vector} than to any node already in the result set. This prevents selecting neighbors + * that are clustered together. If the {@code keepPrunedConnections} configuration is enabled, candidates that are + * pruned by this heuristic are kept and may be added at the end if the result set is not yet full. + *
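+ * As a sketch (simplified from the implementation below; {@code candidate} stands for the candidate currently
+ * being considered), the acceptance test at the heart of the heuristic reads:
+ * <pre>{@code
+ * boolean shouldSelect = true;
+ * for (NodeReferenceWithDistance alreadySelected : selected) {
+ *     if (Vector.comparativeDistance(metric, candidate.getVector(),
+ *             alreadySelected.getVector()) < candidate.getDistance()) {
+ *         shouldSelect = false; // closer to an already selected neighbor than to the query vector
+ *         break;
+ *     }
+ * }
+ * }</pre>
+ * <p>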
+ * The process is asynchronous and returns a {@link CompletableFuture} that will eventually contain the list of
+ * selected neighbors with their full node data.
+ *
+ * @param <N> the type of the node reference, extending {@link NodeReference}
+ * @param storageAdapter the storage adapter to fetch nodes and their neighbors
+ * @param readTransaction the transaction for performing database reads
+ * @param nearestNeighbors the initial pool of candidate neighbors, typically from a search in a higher layer
+ * @param layer the layer in the HNSW graph where the selection is being performed
+ * @param m the maximum number of neighbors to select
+ * @param isExtendCandidates a flag indicating whether to extend the initial candidate pool by fetching the
+ *        neighbors of the {@code nearestNeighbors}
+ * @param nodeCache a cache of nodes to avoid redundant storage lookups
+ * @param vector the query vector for which neighbors are being selected
+ *
+ * @return a {@link CompletableFuture} which will complete with a list of the selected neighbors,
+ *         each represented as a {@link NodeReferenceAndNode}
+ */
+ private <N extends NodeReference> CompletableFuture<List<NodeReferenceAndNode<N>>> selectNeighbors(@Nonnull final StorageAdapter<N> storageAdapter,
+                                                                                                    @Nonnull final ReadTransaction readTransaction,
+                                                                                                    @Nonnull final Iterable<NodeReferenceAndNode<N>> nearestNeighbors,
+                                                                                                    final int layer,
+                                                                                                    final int m,
+                                                                                                    final boolean isExtendCandidates,
+                                                                                                    @Nonnull final Map<Tuple, Node<N>> nodeCache,
+                                                                                                    @Nonnull final Vector<Half> vector) {
+     return extendCandidatesIfNecessary(storageAdapter, readTransaction, nearestNeighbors, layer, isExtendCandidates, nodeCache, vector)
+             .thenApply(extendedCandidates -> {
+                 final List<NodeReferenceWithDistance> selected = Lists.newArrayListWithExpectedSize(m);
+                 final Queue<NodeReferenceWithDistance> candidates =
+                         new PriorityBlockingQueue<>(config.getM(),
+                                 Comparator.comparing(NodeReferenceWithDistance::getDistance));
+                 candidates.addAll(extendedCandidates);
+                 final Queue<NodeReferenceWithDistance> discardedCandidates =
+                         getConfig().isKeepPrunedConnections()
+                         ? new PriorityBlockingQueue<>(config.getM(),
+                                 Comparator.comparing(NodeReferenceWithDistance::getDistance))
+                         : null;
+
+                 final Metric metric = getConfig().getMetric();
+
+                 while (!candidates.isEmpty() && selected.size() < m) {
+                     final NodeReferenceWithDistance nearestCandidate = candidates.poll();
+                     boolean shouldSelect = true;
+                     for (final NodeReferenceWithDistance alreadySelected : selected) {
+                         if (Vector.comparativeDistance(metric, nearestCandidate.getVector(),
+                                 alreadySelected.getVector()) < nearestCandidate.getDistance()) {
+                             shouldSelect = false;
+                             break;
+                         }
+                     }
+                     if (shouldSelect) {
+                         selected.add(nearestCandidate);
+                     } else if (discardedCandidates != null) {
+                         discardedCandidates.add(nearestCandidate);
+                     }
+                 }
+
+                 if (discardedCandidates != null) { // isKeepPrunedConnections is set to true
+                     while (!discardedCandidates.isEmpty() && selected.size() < m) {
+                         selected.add(discardedCandidates.poll());
+                     }
+                 }
+
+                 return ImmutableList.copyOf(selected);
+             }).thenCompose(selectedNeighbors ->
+                     fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, selectedNeighbors, nodeCache))
+             .thenApply(selectedNeighbors -> {
+                 if (logger.isDebugEnabled()) {
+                     logger.debug("selected neighbors={}",
+                             selectedNeighbors.stream()
+                                     .map(selectedNeighbor ->
+                                             "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() +
+                                             ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")")
+                                     .collect(Collectors.joining(",")));
+                 }
+                 return selectedNeighbors;
+             });
+ }
+
+ /**
+ * Conditionally extends a set of candidate nodes by fetching and evaluating their neighbors.
+ * <p>
+ * If {@code isExtendCandidates} is {@code true}, this method gathers the neighbors of the provided + * {@code candidates}, fetches their full node data, and calculates their distance to the given + * {@code vector}. The resulting list will contain both the original candidates and their newly + * evaluated neighbors. + *
+ * If {@code isExtendCandidates} is {@code false}, the method simply returns a list containing + * only the original candidates. This operation is asynchronous and returns a {@link CompletableFuture}. + * + * @param the type of the {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access node data from storage + * @param readTransaction the active {@link ReadTransaction} for database access + * @param candidates an {@link Iterable} of initial candidate nodes, which have already been evaluated + * @param layer the graph layer from which to fetch nodes + * @param isExtendCandidates a boolean flag; if {@code true}, the candidate set is extended with neighbors + * @param nodeCache a cache mapping primary keys to {@link Node} objects to avoid redundant fetches + * @param vector the query vector used to calculate distances for any new neighbor nodes + * + * @return a {@link CompletableFuture} which will complete with a list of {@link NodeReferenceWithDistance}, + * containing the original candidates and potentially their neighbors + */ + private CompletableFuture> + extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> candidates, + int layer, + boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { + if (isExtendCandidates) { + final Metric metric = getConfig().getMetric(); + + final Set candidatesSeen = Sets.newConcurrentHashSet(); + for (final NodeReferenceAndNode candidate : candidates) { + candidatesSeen.add(candidate.getNode().getPrimaryKey()); + } + + final ImmutableList.Builder neighborsOfCandidatesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + for (final N neighbor : candidate.getNode().getNeighbors()) { + final Tuple neighborPrimaryKey = neighbor.getPrimaryKey(); + if (!candidatesSeen.contains(neighborPrimaryKey)) { + candidatesSeen.add(neighborPrimaryKey); + neighborsOfCandidatesBuilder.add(neighbor); + } + } + } + + final Iterable neighborsOfCandidates = neighborsOfCandidatesBuilder.build(); + + return fetchNeighborhood(storageAdapter, readTransaction, layer, neighborsOfCandidates, nodeCache) + .thenApply(withVectors -> { + final ImmutableList.Builder extendedCandidatesBuilder = + ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + for (final NodeReferenceWithVector withVector : withVectors) { + final double distance = Vector.comparativeDistance(metric, withVector.getVector(), vector); + extendedCandidatesBuilder.add(new NodeReferenceWithDistance(withVector.getPrimaryKey(), + withVector.getVector(), distance)); + } + return extendedCandidatesBuilder.build(); + }); + } else { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + resultBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + return CompletableFuture.completedFuture(resultBuilder.build()); + } + } + + /** + * Writes lonely nodes for a given key across a specified range of layers. + *
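+ * For example (a sketch; the layer bounds are illustrative), the following call writes lonely nodes on
+ * layers 5, 4 and 3, per the iteration described below:
+ * <pre>{@code
+ * writeLonelyNodes(transaction, primaryKey, vector, 5, 2);
+ * }</pre>
+ * <p>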
+ * A "lonely node" is a node in the layered structure that does not have a right + * sibling. This method iterates downwards from the {@code highestLayerInclusive} + * to the {@code lowestLayerExclusive}. For each layer in this range, it + * retrieves the appropriate {@link StorageAdapter} and calls + * {@link #writeLonelyNodeOnLayer} to persist the node's information. + * + * @param transaction the transaction to use for writing to the database + * @param primaryKey the primary key of the record for which lonely nodes are being written + * @param vector the search path vector that was followed to find this key + * @param highestLayerInclusive the highest layer (inclusive) to begin writing lonely nodes on + * @param lowestLayerExclusive the lowest layer (exclusive) at which to stop writing lonely nodes + */ + private void writeLonelyNodes(@Nonnull final Transaction transaction, + @Nonnull final Tuple primaryKey, + @Nonnull final Vector vector, + final int highestLayerInclusive, + final int lowestLayerExclusive) { + for (int layer = highestLayerInclusive; layer > lowestLayerExclusive; layer --) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + writeLonelyNodeOnLayer(storageAdapter, transaction, layer, primaryKey, vector); + } + } + + /** + * Writes a new, isolated ('lonely') node to a specified layer within the graph. + *
+ * This method uses the provided {@link StorageAdapter} to create a new node with the + * given primary key and vector but with an empty set of neighbors. The write + * operation is performed as part of the given {@link Transaction}. This is typically + * used to insert the very first node into an empty graph layer. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access the data store and create nodes; must not be null + * @param transaction the {@link Transaction} context for the write operation; must not be null + * @param layer the layer index where the new node will be written + * @param primaryKey the primary key for the new node; must not be null + * @param vector the vector data for the new node; must not be null + */ + private void writeLonelyNodeOnLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final Vector vector) { + storageAdapter.writeNode(transaction, + storageAdapter.getNodeFactory() + .create(primaryKey, vector, ImmutableList.of()), layer, + new BaseNeighborsChangeSet<>(ImmutableList.of())); + if (logger.isDebugEnabled()) { + logger.debug("written lonely node at key={} on layer={}", primaryKey, layer); + } + } + + /** + * Scans all nodes within a given layer of the database. + *
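+ * A typical invocation of the batched scan described below might look like this (a sketch; the batch size is
+ * illustrative and {@code hnsw} is an instance of this class):
+ * <pre>{@code
+ * AtomicLong nodeCount = new AtomicLong();
+ * hnsw.scanLayer(db, 0, 1_000, node -> nodeCount.incrementAndGet());
+ * }</pre>
+ * <p>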
+ * The scan is performed transactionally in batches to avoid loading the entire layer + * into memory at once. Each discovered node is passed to the provided {@link Consumer} + * for processing. The operation continues fetching batches until all nodes in the + * specified layer have been processed. + * + * @param db the non-null {@link Database} instance to run the scan against. + * @param layer the specific layer index to scan. + * @param batchSize the number of nodes to retrieve and process in each batch. + * @param nodeConsumer the non-null {@link Consumer} that will accept each {@link Node} + * found in the layer. + */ + public void scanLayer(@Nonnull final Database db, + final int layer, + final int batchSize, + @Nonnull final Consumer> nodeConsumer) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + final AtomicReference lastPrimaryKeyAtomic = new AtomicReference<>(); + Tuple newPrimaryKey; + do { + final Tuple lastPrimaryKey = lastPrimaryKeyAtomic.get(); + lastPrimaryKeyAtomic.set(null); + newPrimaryKey = db.run(tr -> { + Streams.stream(storageAdapter.scanLayer(tr, layer, lastPrimaryKey, batchSize)) + .forEach(node -> { + nodeConsumer.accept(node); + lastPrimaryKeyAtomic.set(node.getPrimaryKey()); + }); + return lastPrimaryKeyAtomic.get(); + }, executor); + } while (newPrimaryKey != null); + } + + /** + * Gets the appropriate storage adapter for a given layer. + *
+ * This method selects a {@link StorageAdapter} implementation based on the layer number and the configuration:
+ * if inlining is enabled in the configuration and the layer is greater than 0, an
+ * {@code InliningStorageAdapter} is used; otherwise, a {@code CompactStorageAdapter} is used. In particular,
+ * layer 0 is always stored in the compact representation.
+ *
+ * @param layer the layer number for which to get the storage adapter
+ *
+ * @return a non-null {@link StorageAdapter} instance for the given layer
+ */
+ @Nonnull
+ private StorageAdapter<? extends NodeReference> getStorageAdapterForLayer(final int layer) {
+     return config.isUseInlining() && layer > 0
+            ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(),
+                    getOnReadListener())
+            : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(),
+                    getOnReadListener());
+ }
+
+ /**
+ * Calculates a random layer for a new element to be inserted.
+ * <p>
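+ * As a numeric illustration of the distribution described below (assuming {@code getM() == 16}):
+ * {@code lambda = 1 / ln(16)} is about {@code 0.36}; a uniform draw of {@code u = 0.5} yields
+ * {@code floor(-ln(0.5) * 0.36) = floor(0.25) = 0}, while {@code u = 0.01} yields
+ * {@code floor(-ln(0.01) * 0.36) = floor(1.66) = 1}, so the vast majority of insertions land on layer 0.
+ * <p>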
+ * The layer is selected according to a logarithmic distribution, which ensures that + * the probability of choosing a higher layer decreases exponentially. This is + * achieved by applying the inverse transform sampling method. The specific formula + * is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random + * number and {@code lambda} is a normalization factor derived from a system + * configuration parameter {@code M}. + * + * @param random the {@link Random} object used for generating a random number. + * It must not be null. + * + * @return a non-negative integer representing the randomly selected layer. + */ + private int insertionLayer(@Nonnull final Random random) { + double lambda = 1.0 / Math.log(getConfig().getM()); + double u = 1.0 - random.nextDouble(); // Avoid log(0) + return (int) Math.floor(-Math.log(u) * lambda); + } + + /** + * Logs a message at the INFO level, using a consumer for lazy evaluation. + *
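+ * A typical invocation (illustrative) defers the message construction until the level check has passed:
+ * <pre>{@code
+ * info(logger -> logger.info("inserted node with key={} on layer={}", primaryKey, layer));
+ * }</pre>
+ * <p>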
+ * This approach avoids the cost of constructing the log message if the INFO + * level is disabled. The provided {@link java.util.function.Consumer} will be + * executed only when {@code logger.isInfoEnabled()} returns {@code true}. + * + * @param loggerConsumer the {@link java.util.function.Consumer} that will be + * accepted if logging is enabled. It receives the + * {@code Logger} instance and must not be null. + */ + @SuppressWarnings("PMD.UnusedPrivateMethod") + private void info(@Nonnull final Consumer loggerConsumer) { + if (logger.isInfoEnabled()) { + loggerConsumer.accept(logger); + } + } + + private static class NodeReferenceWithLayer extends NodeReferenceWithVector { + private final int layer; + + public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + public int getLayer() { + return layer; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithLayer)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((NodeReferenceWithLayer)o).layer; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java new file mode 100644 index 0000000000..4921f1280d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java @@ -0,0 +1,78 @@ +/* + * HNSWHelpers.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; + +/** + * Some helper methods for {@link Node}s. + */ +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSWHelpers { + private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); + + /** + * This is a utility class and is not intended to be instantiated. + */ + private HNSWHelpers() { + // nothing + } + + /** + * Helper method to format bytes as hex strings for logging and debugging. + * @param bytes an array of bytes + * @return a {@link String} containing the hexadecimal representation of the byte array passed in + */ + @Nonnull + public static String bytesToHex(byte[] bytes) { + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return "0x" + new String(hexChars).replaceFirst("^0+(?!$)", ""); + } + + /** + * Returns a {@code Half} instance representing the specified {@code double} value, rounded to the nearest + * representable half-precision float value. 
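+ * <p>
+ * For instance (a sketch; the printed value assumes standard half-precision rounding):
+ * <pre>{@code
+ * Half h = HNSWHelpers.halfValueOf(0.1d);
+ * // h is the half-precision value closest to 0.1, approximately 0.0999755859375
+ * }</pre>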
+ * @param d the {@code double} value to be converted. + * @return a non-null {@link Half} instance representing {@code d}. + */ + @Nonnull + public static Half halfValueOf(final double d) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(d))); + } + + /** + * Returns a {@code Half} instance representing the specified {@code float} value, rounded to the nearest + * representable half-precision float value. + * @param f the {@code float} value to be converted. + * @return a non-null {@link Half} instance representing {@code f}. + */ + @Nonnull + public static Half halfValueOf(final float f) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(f))); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java new file mode 100644 index 0000000000..c8161b825c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -0,0 +1,146 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Represents a specific type of node within a graph structure that is used to represent nodes in an HNSW structure. + *
+ * This node extends {@link AbstractNode}, does not store its own vector and instead specifically manages neighbors + * of type {@link NodeReferenceWithVector} (which do store a vector each). + * It provides a concrete implementation for an "inlining" node, distinguishing it from other node types such as + * {@link CompactNode}. + */ +public class InliningNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + public Node create(@Nonnull final Tuple primaryKey, + @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new InliningNode(primaryKey, (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.INLINING; + } + }; + + /** + * Constructs a new {@code InliningNode} with a specified primary key and a list of its neighbors. + *
+ * This constructor initializes the node by calling the constructor of its superclass, + * passing the primary key and neighbor list. + * + * @param primaryKey the non-null primary key of the node, represented by a {@link Tuple}. + * @param neighbors the non-null list of neighbors for this node, where each neighbor + * is a {@link NodeReferenceWithVector}. + */ + public InliningNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + } + + /** + * Gets a reference to this node. + * + * @param vector the vector to be associated with the node reference. Despite the + * {@code @Nullable} annotation, this parameter must not be null. + * + * @return a new {@link NodeReferenceWithVector} instance containing the node's + * primary key and the provided vector; will never be null. + * + * @throws NullPointerException if the provided {@code vector} is null. + */ + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public NodeReferenceWithVector getSelfReference(@Nullable final Vector vector) { + return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); + } + + /** + * Gets the kind of this node. + * @return the non-null {@link NodeKind} of this node, which is always + * {@code NodeKind.INLINING}. + */ + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.INLINING; + } + + /** + * Casts this node to a {@link CompactNode}. + *
+ * This implementation always throws an exception because this specific node type + * cannot be represented as a compact node. + * @return this node as a {@link CompactNode}, never {@code null} + * @throws IllegalStateException always, as this node is not a compact node + */ + @Nonnull + @Override + public CompactNode asCompactNode() { + throw new IllegalStateException("this is not a compact node"); + } + + /** + * Returns this object as an {@link InliningNode}. + *
+ * As this class is already an instance of {@code InliningNode}, this method simply returns {@code this}. + * @return this object, which is guaranteed to be an {@code InliningNode} and never {@code null}. + */ + @Nonnull + @Override + public InliningNode asInliningNode() { + return this; + } + + /** + * Returns the singleton factory instance used to create {@link NodeReferenceWithVector} objects. + *
+ * This method provides a standard way to obtain the factory, ensuring that a single, shared instance is used + * throughout the application. + * + * @return the singleton {@link NodeFactory} instance, never {@code null}. + */ + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "I[primaryKey=" + getPrimaryKey() + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java new file mode 100644 index 0000000000..58d8795777 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -0,0 +1,345 @@ +/* + * InliningStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * An implementation of {@link StorageAdapter} for an HNSW graph that stores node vectors "in-line" with the node's + * neighbor information. + *
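+ * Concretely, the storage model described below results in one key-value pair per neighbor (a sketch; cf.
+ * {@code getNeighborKey} and {@code writeNeighbor}):
+ * <pre>{@code
+ * key   = dataSubspace.pack(Tuple.from(layer, nodePrimaryKey, neighborPrimaryKey))
+ * value = StorageAdapter.tupleFromVector(neighborVector).pack()
+ * }</pre>
+ * <p>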
+ * In this storage model, each key-value pair in the database represents a single neighbor relationship. The key + * contains the primary keys of both the source node and the neighbor node, while the value contains the neighbor's + * vector. This contrasts with a "compact" storage model where a compact node represents a vector and all of its + * neighbors. This adapter is responsible for serializing and deserializing these structures to and from the underlying + * key-value store. + * + * @see StorageAdapter + * @see HNSW + */ +class InliningStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + /** + * Constructs a new {@code InliningStorageAdapter} with the given configuration and components. + *
+ * This constructor initializes the storage adapter by passing all necessary components + * to its superclass. + * + * @param config the HNSW configuration to use for the graph + * @param nodeFactory the factory to create new {@link NodeReferenceWithVector} instances + * @param subspace the subspace where the HNSW graph data is stored + * @param onWriteListener the listener to be notified on write operations + * @param onReadListener the listener to be notified on read operations + */ + public InliningStorageAdapter(@Nonnull final HNSW.Config config, + @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + /** + * Throws {@link IllegalStateException} because an inlining storage adapter cannot be converted to a compact one. + *
+ * This operation is fundamentally not supported for this type of adapter. An inlining adapter stores data directly + * within a parent structure, which is incompatible with the standalone nature of a compact storage format. + * @return This method never returns a value as it always throws an exception. + * @throws IllegalStateException always, as this operation is not supported. + */ + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + throw new IllegalStateException("cannot call this method on an inlining storage adapter"); + } + + /** + * Returns this object instance as a {@code StorageAdapter} that supports inlining. + *
+ * This implementation returns the current instance ({@code this}) because the class itself is designed to handle + * inlining directly, thus no separate adapter object is needed. + * @return a non-null reference to this object as an {@link StorageAdapter} for inlining. + */ + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + return this; + } + + /** + * Asynchronously fetches a single node from a given layer by its primary key. + *
+ * This internal method constructs a prefix key based on the {@code layer} and {@code primaryKey}. + * It then performs an asynchronous range scan to retrieve all key-value pairs associated with that prefix. + * Finally, it reconstructs the complete {@link Node} object from the collected raw data using + * the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for reading from the database + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a {@link CompletableFuture} that will complete with the fetched {@link Node} containing + * {@link NodeReferenceWithVector}s + */ + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] rangeKey = getNodeKey(layer, primaryKey); + + return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey), + ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor()) + .thenApply(keyValues -> nodeFromRaw(layer, primaryKey, keyValues)); + } + + /** + * Constructs a {@code Node} from its raw key-value representation from storage. + *
+ * This method is responsible for deserializing a node and its neighbors. It processes a list of {@code KeyValue} + * pairs, where each pair represents a neighbor of the node being constructed. Each neighbor is converted from its + * raw form into a {@link NodeReferenceWithVector} by calling the {@link #neighborFromRaw(int, byte[], byte[])} + * method. + *
+ * Once the node is created with its primary key and list of neighbors, it notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param layer the layer in the graph where this node exists + * @param primaryKey the primary key that uniquely identifies the node + * @param keyValues a list of {@code KeyValue} pairs representing the raw data of the node's neighbors + * + * @return a non-null, fully constructed {@code Node} object with its neighbors + */ + @Nonnull + private Node nodeFromRaw(final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final List keyValues) { + final OnReadListener onReadListener = getOnReadListener(); + + final ImmutableList.Builder nodeReferencesWithVectorBuilder = ImmutableList.builder(); + for (final KeyValue keyValue : keyValues) { + nodeReferencesWithVectorBuilder.add(neighborFromRaw(layer, keyValue.getKey(), keyValue.getValue())); + } + + final Node node = + getNodeFactory().create(primaryKey, null, nodeReferencesWithVectorBuilder.build()); + onReadListener.onNodeRead(layer, node); + return node; + } + + /** + * Constructs a {@code NodeReferenceWithVector} from raw key and value byte arrays retrieved from storage. + *
+ * This helper method deserializes a neighbor's data. It unpacks the provided {@code key} to extract the neighbor's + * primary key and unpacks the {@code value} to extract the neighbor's vector. It also notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param layer the layer of the graph where the neighbor node is located. + * @param key the raw byte array key from the database, which contains the neighbor's primary key. + * @param value the raw byte array value from the database, which represents the neighbor's vector. + * @return a new {@link NodeReferenceWithVector} instance representing the deserialized neighbor. + * @throws IllegalArgumentException if the key or value byte arrays are malformed and cannot be unpacked. + */ + @Nonnull + private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull byte[] key, final byte[] value) { + final OnReadListener onReadListener = getOnReadListener(); + + onReadListener.onKeyValueRead(layer, key, value); + final Tuple neighborKeyTuple = getDataSubspace().unpack(key); + final Tuple neighborValueTuple = Tuple.fromBytes(value); + + final Tuple neighborPrimaryKey = neighborKeyTuple.getNestedTuple(2); // neighbor primary key + final Vector neighborVector = StorageAdapter.vectorFromTuple(neighborValueTuple); // the entire value is the vector + return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector); + } + + /** + * Writes a given node and its neighbor changes to the specified layer within a transaction. + *
+ * This implementation first converts the provided {@link Node} to an {@link InliningNode}. It then delegates the + * writing of neighbor modifications to the {@link NeighborsChangeSet#writeDelta} method. After the changes are + * written, it notifies the registered {@code OnWriteListener} that the node has been processed + * via {@code getOnWriteListener().onNodeWritten()}. + * + * @param transaction the transaction context for the write operation; must not be null + * @param node the node to be written, which is expected to be an + * {@code InliningNode}; must not be null + * @param layer the layer index where the node and its neighbor changes should be written + * @param neighborsChangeSet the set of changes to the node's neighbors to be + * persisted; must not be null + */ + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final InliningNode inliningNode = node.asInliningNode(); + + neighborsChangeSet.writeDelta(this, transaction, layer, inliningNode, t -> true); + getOnWriteListener().onNodeWritten(layer, node); + } + + /** + * Constructs the raw database key for a node based on its layer and primary key. + *
+ * This key is created by packing a tuple containing the specified {@code layer} and the node's {@code primaryKey} + * within the data subspace. The resulting byte array is suitable for use in direct database lookups and preserves + * the sort order of the components. + * + * @param layer the layer index where the node resides + * @param primaryKey the primary key that uniquely identifies the node within its layer, + * encapsulated in a {@link Tuple} + * + * @return a byte array representing the packed key for the specified node + */ + @Nonnull + private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { + return getDataSubspace().pack(Tuple.from(layer, primaryKey)); + } + + /** + * Writes a neighbor for a given node to the underlying storage within a specific transaction. + *
+ * This method serializes the neighbor's vector and constructs a unique key based on the layer, the source + * {@code node}, and the neighbor's primary key. It then persists this key-value pair using the provided + * {@link Transaction}. After a successful write, it notifies any registered listeners. + * + * @param transaction the {@link Transaction} to use for the write operation + * @param layer the layer index where the node and its neighbor reside + * @param node the source {@link Node} for which the neighbor is being written + * @param neighbor the {@link NodeReferenceWithVector} representing the neighbor to persist + */ + public void writeNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Node node, @Nonnull final NodeReferenceWithVector neighbor) { + final byte[] neighborKey = getNeighborKey(layer, node, neighbor.getPrimaryKey()); + final byte[] value = StorageAdapter.tupleFromVector(neighbor.getVector()).pack(); + transaction.set(neighborKey, + value); + getOnWriteListener().onNeighborWritten(layer, node, neighbor); + getOnWriteListener().onKeyValueWritten(layer, neighborKey, value); + } + + /** + * Deletes a neighbor edge from a given node within a specific layer. + *
+ * This operation removes the key-value pair representing the neighbor relationship from the database within the + * given {@link Transaction}. It also notifies the {@code onWriteListener} about the deletion. + * + * @param transaction the transaction in which to perform the deletion + * @param layer the layer of the graph where the node resides + * @param node the node from which the neighbor edge is removed + * @param neighborPrimaryKey the primary key of the neighbor node to be deleted + */ + public void deleteNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Node node, @Nonnull final Tuple neighborPrimaryKey) { + transaction.clear(getNeighborKey(layer, node, neighborPrimaryKey)); + getOnWriteListener().onNeighborDeleted(layer, node, neighborPrimaryKey); + } + + /** + * Constructs the key for a specific neighbor of a node within a given layer. + *
+ * This key is used to uniquely identify and store the neighbor relationship in the underlying data store. It is + * formed by packing a {@link Tuple} containing the {@code layer}, the primary key of the source {@code node}, and + * the {@code neighborPrimaryKey}. + * + * @param layer the layer of the graph where the node and its neighbor reside + * @param node the non-null source node for which the neighbor key is being generated + * @param neighborPrimaryKey the non-null primary key of the neighbor node + * @return a non-null byte array representing the packed key for the neighbor relationship + */ + @Nonnull + private byte[] getNeighborKey(final int layer, + @Nonnull final Node node, + @Nonnull final Tuple neighborPrimaryKey) { + return getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey(), neighborPrimaryKey)); + } + + /** + * Scans a specific layer of the graph, reconstructing nodes and their neighbors from the underlying key-value + * store. + *
+ * This method reads raw {@link com.apple.foundationdb.KeyValue} records from the database within a given layer.
+ * It groups adjacent records that belong to the same parent node and uses a {@link NodeFactory} to construct
+ * {@link Node} objects. The method supports pagination through the {@code lastPrimaryKey} parameter, allowing for
+ * incremental scanning of large layers.
+ *
+ * @param readTransaction the transaction to use for reading data
+ * @param layer the layer of the graph to scan
+ * @param lastPrimaryKey the primary key of the last node read in a previous scan, used for pagination.
+ *        If {@code null}, the scan starts from the beginning of the layer.
+ * @param maxNumRead the maximum number of raw key-value records to read from the database
+ *
+ * @return an {@code Iterable} of {@link Node} objects reconstructed from the scanned layer. Each node contains
+ *         its neighbors within that layer.
+ */
+ @Nonnull
+ @Override
+ public Iterable<Node<NodeReferenceWithVector>> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer,
+                                                          @Nullable final Tuple lastPrimaryKey, int maxNumRead) {
+     final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer));
+     final Range range =
+             lastPrimaryKey == null
+             ? Range.startsWith(layerPrefix)
+             : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))),
+                     ByteArrayUtil.strinc(layerPrefix));
+     final AsyncIterable<KeyValue> itemsIterable =
+             readTransaction.getRange(range,
+                     maxNumRead, false, StreamingMode.ITERATOR);
+     int numRead = 0;
+     Tuple nodePrimaryKey = null;
+     ImmutableList.Builder<Node<NodeReferenceWithVector>> nodeBuilder = ImmutableList.builder();
+     ImmutableList.Builder<NodeReferenceWithVector> neighborsBuilder = ImmutableList.builder();
+     for (final KeyValue item : itemsIterable) {
+         final NodeReferenceWithVector neighbor =
+                 neighborFromRaw(layer, item.getKey(), item.getValue());
+         // the node this record belongs to is encoded in the key as (layer, nodePrimaryKey, neighborPrimaryKey)
+         final Tuple currentNodePrimaryKey = getDataSubspace().unpack(item.getKey()).getNestedTuple(1);
+         if (nodePrimaryKey == null) {
+             nodePrimaryKey = currentNodePrimaryKey;
+         } else if (!nodePrimaryKey.equals(currentNodePrimaryKey)) {
+             // a new node starts; finish the previous node and reset the neighbor accumulator
+             nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build()));
+             neighborsBuilder = ImmutableList.builder();
+             nodePrimaryKey = currentNodePrimaryKey;
+         }
+         neighborsBuilder.add(neighbor);
+         numRead++;
+     }
+
+     // if the scan was not cut short by maxNumRead, the last accumulated node is complete as well
+     if (numRead > 0 && numRead < maxNumRead) {
+         nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build()));
+     }
+
+     return nodeBuilder.build();
+ }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java
new file mode 100644
index 0000000000..f9894ccebd
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java
@@ -0,0 +1,132 @@
+/*
+ * InsertNeighborsChangeSet.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +/** + * Represents an immutable change set for the neighbors of a node in the HNSW graph, specifically + * capturing the insertion of new neighbors. + *
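+ * A sketch of the layering described below, as used by {@code HNSW#resolveChangeSetFromNewNeighbors}:
+ * <pre>{@code
+ * NeighborsChangeSet<N> changeSet = new BaseNeighborsChangeSet<>(currentNeighbors);
+ * changeSet = new DeleteNeighborsChangeSet<>(changeSet, toBeDeleted);
+ * changeSet = new InsertNeighborsChangeSet<>(changeSet, toBeInserted);
+ * // changeSet.merge() now yields (currentNeighbors minus toBeDeleted) plus toBeInserted
+ * }</pre>
+ * <p>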
+ * This class layers new neighbors on top of a parent {@link NeighborsChangeSet}, allowing for a + * layered representation of modifications. The changes are not applied to the database until + * {@link #writeDelta} is called. + * + * @param the type of the node reference, which must extend {@link NodeReference} + */ +class InsertNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(InsertNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Map insertedNeighborsMap; + + /** + * Creates a new {@code InsertNeighborsChangeSet}. + *
+ * This constructor initializes the change set with its parent and a list of neighbors + * to be inserted. It internally builds an immutable map of the inserted neighbors, + * keyed by their primary key for efficient lookups. + * + * @param parent the parent {@link NeighborsChangeSet} on which this insertion is based. + * @param insertedNeighbors the list of neighbors to be inserted. + */ + public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final List insertedNeighbors) { + this.parent = parent; + final ImmutableMap.Builder insertedNeighborsMapBuilder = ImmutableMap.builder(); + for (final N insertedNeighbor : insertedNeighbors) { + insertedNeighborsMapBuilder.put(insertedNeighbor.getPrimaryKey(), insertedNeighbor); + } + + this.insertedNeighborsMap = insertedNeighborsMapBuilder.build(); + } + + /** + * Gets the parent {@code NeighborsChangeSet} from which this change set was derived. + * @return the parent {@link NeighborsChangeSet}, which is never {@code null}. + */ + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + /** + * Merges the neighbors from this level of the hierarchy with all neighbors from parent levels. + *
+ * This is achieved by creating a combined view that includes the results of the parent's {@code #merge()} call and + * the neighbors that have been inserted at the current level. The resulting {@code Iterable} provides a complete + * set of neighbors from this node and all its ancestors. + * @return a non-null {@code Iterable} containing all neighbors from this node and its ancestors. + */ + @Nonnull + @Override + public Iterable merge() { + return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); + } + + /** + * Writes the delta of this layer to the specified storage adapter. + *
+ * This implementation first delegates to the parent to write its delta, but excludes any neighbors that have been + * newly inserted in the current context (i.e., those in {@code insertedNeighborsMap}). It then iterates through its + * own newly inserted neighbors. For each neighbor that satisfies the given {@code tuplePredicate}, it writes the + * neighbor relationship to storage via the {@link InliningStorageAdapter}. + * + * @param storageAdapter the storage adapter to write to; must not be null + * @param transaction the transaction context for the write operation; must not be null + * @param layer the layer index to write the data to + * @param node the source node for which the neighbor delta is being written; must not be null + * @param tuplePredicate a predicate to filter which neighbor tuples should be written; must not be null + */ + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !insertedNeighborsMap.containsKey(tuple))); + + for (final Map.Entry entry : insertedNeighborsMap.entrySet()) { + final Tuple primaryKey = entry.getKey(); + if (tuplePredicate.test(primaryKey)) { + storageAdapter.writeNeighbor(transaction, layer, node.asInliningNode(), + entry.getValue().asNodeReferenceWithVector()); + if (logger.isDebugEnabled()) { + logger.debug("inserted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + primaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java new file mode 100644 index 0000000000..a49457677f --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -0,0 +1,242 @@ +/* + * Metric.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; + +/** + * Defines a metric for measuring the distance or similarity between n-dimensional vectors. + *
+ * This interface provides a contract for various distance calculation algorithms, such as Euclidean, Manhattan, + * and Cosine distance. Implementations of this interface can be used in algorithms that require a metric for + * comparing data vectors, like clustering or nearest neighbor searches. + */ +public interface Metric { + /** + * Calculates a distance between two n-dimensional vectors. + *
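+ * For example (illustrative):
+ * <pre>{@code
+ * Metric metric = new Metric.EuclideanMetric();
+ * double d = metric.distance(new double[] {0.0, 0.0}, new double[] {3.0, 4.0}); // 5.0
+ * }</pre>
+ * <p>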
+ * The two vectors are represented as arrays of {@code double} and must be of the
+ * same length (i.e., have the same number of dimensions).
+ *
+ * @param vector1 the first vector. Must not be null.
+ * @param vector2 the second vector. Must not be null and must have the same
+ *        length as {@code vector1}.
+ *
+ * @return the calculated distance as a {@code double}.
+ *
+ * @throws IllegalArgumentException if either vector is null, if the vectors have different lengths, or if
+ *         they are empty.
+ */
+ double distance(@Nonnull double[] vector1, @Nonnull double[] vector2);
+
+ /**
+ * Calculates a comparative distance between two vectors. The comparative distance is used in contexts such as
+ * ranking where the caller needs to "compare" two distances. In contrast to a true metric, the distances computed
+ * by this method do not need to satisfy proper metric invariants: the distance may be negative, and it need not
+ * satisfy the triangle inequality.
+ * <p>
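+ * For instance (illustrative), the dot product metric may rank via negative values:
+ * <pre>{@code
+ * Metric dot = new Metric.DotProductMetric();
+ * double d = dot.comparativeDistance(new double[] {1.0, 0.0}, new double[] {2.0, 0.0}); // -2.0
+ * }</pre>
+ * <p>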
+ * For most metrics, this method is simply an alias for {@link #distance(double[], double[])}. This does not hold
+ * for e.g. {@link DotProductMetric}, where the comparative distance is the negated dot product.
+ *
+ * @param vector1 the first vector, represented as an array of {@code double}.
+ * @param vector2 the second vector, represented as an array of {@code double}.
+ *
+ * @return the distance between the two vectors.
+ */
+ default double comparativeDistance(@Nonnull double[] vector1, @Nonnull double[] vector2) {
+     return distance(vector1, vector2);
+ }
+
+ /**
+ * A helper method to validate that vectors can be compared.
+ * @param vector1 the first vector
+ * @param vector2 the second vector
+ * @throws IllegalArgumentException if either vector is {@code null}, if the vectors have different lengths, or
+ *         if they are empty
+ */
+ private static void validate(double[] vector1, double[] vector2) {
+     if (vector1 == null || vector2 == null) {
+         throw new IllegalArgumentException("Vectors cannot be null");
+     }
+     if (vector1.length != vector2.length) {
+         throw new IllegalArgumentException(
+                 "Vectors must have the same dimensionality. Got " + vector1.length + " and " + vector2.length
+         );
+     }
+     if (vector1.length == 0) {
+         throw new IllegalArgumentException("Vectors cannot be empty.");
+     }
+ }
+
+ /**
+ * Represents the Manhattan distance metric.
+ * <p>
+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + */ + class ManhattanMetric implements Metric { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + Metric.validate(vector1, vector2); + + double sumOfAbsDiffs = 0.0; + for (int i = 0; i < vector1.length; i++) { + sumOfAbsDiffs += Math.abs(vector1[i] - vector2[i]); + } + return sumOfAbsDiffs; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + /** + * Represents the Euclidean distance metric. + *
+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + */ + class EuclideanMetric implements Metric { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + Metric.validate(vector1, vector2); + + return Math.sqrt(EuclideanSquareMetric.distanceInternal(vector1, vector2)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + /** + * Represents the squared Euclidean distance metric. + *
+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *
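+ * For example (illustrative), the square metric preserves the ordering of the Euclidean metric:
+ * <pre>{@code
+ * double[] a = {0.0, 0.0};
+ * double[] b = {3.0, 4.0};
+ * new Metric.EuclideanMetric().distance(a, b);       // 5.0
+ * new Metric.EuclideanSquareMetric().distance(a, b); // 25.0
+ * }</pre>
+ * <p>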
+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + */ + class EuclideanSquareMetric implements Metric { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + Metric.validate(vector1, vector2); + return distanceInternal(vector1, vector2); + } + + private static double distanceInternal(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + double sumOfSquares = 0.0d; + for (int i = 0; i < vector1.length; i++) { + double diff = vector1[i] - vector2[i]; + sumOfSquares += diff * diff; + } + return sumOfSquares; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + /** + * Represents the Cosine distance metric. + *
+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} that corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0} while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * @see Metric.CosineMetric + */ + class CosineMetric implements Metric { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + Metric.validate(vector1, vector2); + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + normA += vector1[i] * vector1[i]; + normB += vector2[i] * vector2[i]; + } + + // Handle the case of zero-vectors to avoid division by zero + if (normA == 0.0 || normB == 0.0) { + return Double.POSITIVE_INFINITY; + } + + return 1.0d - dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + /** + * Dot product similarity. + *
+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can + * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance + * only allows {@link Metric#comparativeDistance(double[], double[])} to be called. + * + * @see Dot Product + * @see DotProductMetric + */ + class DotProductMetric implements Metric { + @Override + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + throw new UnsupportedOperationException("dot product metric is not a true metric and can only be used for ranking"); + } + + @Override + public double comparativeDistance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { + Metric.validate(vector1, vector2); + + double product = 0.0d; + for (int i = 0; i < vector1.length; i++) { + product += vector1[i] * vector2[i]; + } + return -product; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java new file mode 100644 index 0000000000..0af9cf7af2 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java @@ -0,0 +1,111 @@ +/* + * Metrics.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; + +/** + * Represents various distance calculation strategies (metrics) for vectors. + *
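+ * Typical usage of the constants described below (illustrative):
+ * <pre>{@code
+ * Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
+ * double d = metric.distance(new double[] {1.0, 2.0}, new double[] {4.0, 6.0}); // 5.0
+ * }</pre>
+ * <p>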
+ * Each enum constant holds a specific metric implementation, providing a type-safe way to calculate the distance + * between two points in a multidimensional space. + * + * @see Metric + */ +public enum Metrics { + /** + * Represents the Manhattan distance metric, implemented by {@link Metric.ManhattanMetric}. + *
+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + * @see Metric.ManhattanMetric + */ + MANHATTAN_METRIC(new Metric.ManhattanMetric()), + + /** + * Represents the Euclidean distance metric, implemented by {@link Metric.EuclideanMetric}. + *
+ * <p>
+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + * @see Metric.EuclideanMetric + */ + EUCLIDEAN_METRIC(new Metric.EuclideanMetric()), + + /** + * Represents the squared Euclidean distance metric, implemented by {@link Metric.EuclideanSquareMetric}. + *
+ * <p>
+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *
+ * <p>
+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + * @see Metric.EuclideanSquareMetric + */ + EUCLIDEAN_SQUARE_METRIC(new Metric.EuclideanSquareMetric()), + + /** + * Represents the Cosine distance metric, implemented by {@link Metric.CosineMetric}. + *
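As a worked illustration of the ordering argument above (a sketch, not part of the patch; it assumes the Euclidean metric classes from this hunk):

```java
double[] q = {0.0, 0.0};
double[] p = {3.0, 4.0};

new Metric.EuclideanMetric().distance(q, p);       // sqrt(3^2 + 4^2) = 5.0
new Metric.EuclideanSquareMetric().distance(q, p); // 3^2 + 4^2 = 25.0
// For any candidates a and b: d(q, a) < d(q, b) iff d^2(q, a) < d^2(q, b),
// so nearest-neighbor rankings agree while the square root is avoided.
```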
+ * <p>
+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} and corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0}, while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * @see Metric.CosineMetric + */ + COSINE_METRIC(new Metric.CosineMetric()), + + /** + * Dot product similarity, implemented by {@link Metric.DotProductMetric}. + *
+ * <p>
+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can + * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance + * only allows {@link Metric#comparativeDistance(double[], double[])} to be called. + * + * @see Dot Product + * @see Metric.DotProductMetric + */ + DOT_PRODUCT_METRIC(new Metric.DotProductMetric()); + + @Nonnull + private final Metric metric; + + /** + * Constructs a new Metrics instance with the specified metric. + * @param metric the metric to be associated with this Metrics instance; must not be null. + */ + Metrics(@Nonnull final Metric metric) { + this.metric = metric; + } + + /** + * Gets the {@code Metric} associated with this instance. + * @return the non-null {@link Metric} for this instance + */ + @Nonnull + public Metric getMetric() { + return metric; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java new file mode 100644 index 0000000000..2eb02e74e3 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -0,0 +1,80 @@ +/* + * NeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.function.Predicate; + +/** + * Represents a set of changes to the neighbors of a node within an HNSW graph. + *
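A brief sketch of how the enum completed above is meant to be used for metric selection (hypothetical call site, not part of the patch):

```java
// Dispatch through the enum rather than instantiating metric classes directly.
Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
double d = metric.distance(new double[] {1.0, 2.0}, new double[] {4.0, 6.0}); // 5.0
```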
+ * <p>
+ * Implementations of this interface manage modifications, such as additions or removals of neighbors, often in a + * layered fashion. This allows for composing changes before they are committed to storage. The {@link #getParent()} + * method returns the next element in this layered structure while {@link #merge()} consolidates changes into + * a final neighbor list. + * + * @param <N> the type of the node reference, which must extend {@link NodeReference} + */ +interface NeighborsChangeSet<N extends NodeReference> { + /** + * Gets the parent change set from which this change set was derived. + *
+ * <p>
+ * Change sets can be layered, forming a chain of modifications. + * This method allows for traversing up this chain to the preceding set of changes. + * + * @return the parent {@code NeighborsChangeSet}, or {@code null} if this change set + * is the root of the chain and has no parent. + */ + @Nullable + NeighborsChangeSet<N> getParent(); + + /** + * Merges multiple internal sequences into a single, consolidated iterable sequence. + *
+ * <p>
+ * This method combines distinct internal changesets into one continuous stream of neighbors. The specific order + * of the merged elements depends on the implementation. + * + * @return a non-null {@code Iterable} containing the merged sequence of elements. + */ + @Nonnull + Iterable merge(); + + /** + * Writes the neighbor delta for a given {@link Node} to the specified storage layer. + *
+ * <p>
+ * This method processes the provided {@code node} and writes only the records that match the given + * {@code primaryKeyPredicate} to the storage system via the {@link InliningStorageAdapter}. The entire operation + * is performed within the context of the supplied {@link Transaction}. + * + * @param storageAdapter the storage adapter to which the delta will be written; must not be null + * @param transaction the transaction context for the write operation; must not be null + * @param layer the specific storage layer to write the delta to + * @param node the source node containing the data to be written; must not be null + * @param primaryKeyPredicate a predicate to filter records by their primary key. Only records + * for which the predicate returns {@code true} will be written. Must not be null. + */ + void writeDelta(@Nonnull InliningStorageAdapter storageAdapter, @Nonnull Transaction transaction, int layer, + @Nonnull Node node, @Nonnull Predicate primaryKeyPredicate); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java new file mode 100644 index 0000000000..88d10480ce --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java @@ -0,0 +1,110 @@ +/* + * Node.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * Represents a node within an HNSW (Hierarchical Navigable Small World) structure. + *
+ * <p>
+ * A node corresponds to a data point (vector) in the structure and maintains a list of its neighbors. + * This interface defines the common contract for different node representations, such as {@link CompactNode} + * and {@link InliningNode}. + *
+ * <p>
+ * + * @param <N> the type of reference used to point to other nodes, which must extend {@link NodeReference} + */ +public interface Node<N extends NodeReference> { + /** + * Gets the primary key for this object. + *
+ * <p>
+ * The primary key is represented as a {@link Tuple} and uniquely identifies + * the object within its storage context. This method is guaranteed to not + * return a null value. + * + * @return the primary key as a {@code Tuple}, which is never {@code null} + */ + @Nonnull + Tuple getPrimaryKey(); + + /** + * Returns a self-reference to this object, enabling fluent method chaining. This allows callers to create node references + * that contain a vector and are independent of the storage implementation. + * @param vector the vector of {@code Half} objects to process. This parameter + * is optional and can be {@code null}. + * + * @return a non-null reference to this object ({@code this}) for further + * method calls. + */ + @Nonnull + N getSelfReference(@Nullable Vector vector); + + /** + * Gets the list of neighboring nodes. + *
+ * <p>
+ * This method is guaranteed to not return {@code null}. If there are no neighbors, an empty list is returned. + * + * @return a non-null list of neighboring nodes. + */ + @Nonnull + List getNeighbors(); + + /** + * Gets the neighbor at the specified index. + *
+ * <p>
+ * This method provides access to the neighbors of a particular node or element, identified by a zero-based index. + * @param index the zero-based index of the neighbor to retrieve. + * @return the neighbor at the specified index; this method will never return {@code null}. + */ + @Nonnull + N getNeighbor(int index); + + /** + * Return the kind of the node, i.e. {@link NodeKind#COMPACT} or {@link NodeKind#INLINING}. + * @return the kind of this node as a {@link NodeKind} + */ + @Nonnull + NodeKind getKind(); + + /** + * Converts this node into its {@link CompactNode} representation. + *
+ * <p>
+ * A {@code CompactNode} is a space-efficient implementation of {@code Node}. This method provides the + * conversion logic to transform the current object into that compact form. + * + * @return a non-null {@link CompactNode} representing the current node. + */ + @Nonnull + CompactNode asCompactNode(); + + /** + * Converts this node into its {@link InliningNode} representation. + * @return this object cast to an {@link InliningNode}; never {@code null}. + * @throws ClassCastException if this object is not actually an instance of + * {@link InliningNode}. + */ + @Nonnull + InliningNode asInliningNode(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java new file mode 100644 index 0000000000..814a8d9030 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -0,0 +1,64 @@ +/* + * NodeFactory.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * A factory interface for creating {@link Node} instances within a Hierarchical Navigable Small World (HNSW) graph. + *
+ * <p>
+ * Implementations of this interface define how nodes are constructed, allowing for different node types + * or storage strategies within the HNSW structure. + * + * @param <N> the type of {@link NodeReference} used to refer to nodes in the graph + */ +public interface NodeFactory<N extends NodeReference> { + /** + * Creates a new node with the specified properties. + *
+ * <p>
+ * This method is responsible for instantiating a {@code Node} object, initializing it + * with a primary key, an optional feature vector, and a list of its initial neighbors. + * + * @param primaryKey the {@link Tuple} representing the unique primary key for the new node. Must not be + * {@code null}. + * @param vector the optional feature {@link Vector} associated with the node, which can be used for similarity + * calculations. May be {@code null} if the node does not encode a vector (see {@link CompactNode} versus + * {@link InliningNode}). + * @param neighbors the list of initial {@link NodeReference}s for the new node, + * establishing its initial connections in the graph. Must not be {@code null}. + * + * @return a new, non-null {@link Node} instance configured with the provided parameters. + */ + @Nonnull + Node<N> create(@Nonnull Tuple primaryKey, @Nullable Vector<Half> vector, + @Nonnull List<N> neighbors); + + /** + * Gets the kind of this node. + * @return the kind of this node, never {@code null}. + */ + @Nonnull + NodeKind getNodeKind(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java new file mode 100644 index 0000000000..de7aeb6572 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java @@ -0,0 +1,88 @@ +/* + * NodeKind.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; + +/** + * Represents the different kinds of nodes, each associated with a unique byte value for serialization and + * deserialization. + */ +public enum NodeKind { + /** + * Compact node. Serialization and deserialization is implemented in {@link CompactNode}. + *
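A hypothetical call site for the factory method described above; the {@code storageAdapter} handle and the concrete type parameter are assumptions, not part of the patch:

```java
// Sketch: creating a detached node through a NodeFactory obtained elsewhere.
NodeFactory<NodeReference> nodeFactory = storageAdapter.getNodeFactory();
Node<NodeReference> node = nodeFactory.create(
        Tuple.from("pk", 1L),   // primary key
        null,                   // vector may be null depending on the node kind
        List.of());             // no neighbors yet
```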
+ * <p>
+ * Compact nodes store their own vector and their neighbors-list only contains the primary key for each neighbor. + */ + COMPACT((byte)0x00), + + /** + * Inlining node. Serialization and deserialization is implemented in {@link InliningNode}. + *
+ * <p>
+ * Inlining nodes do not store their own vector and their neighbors-list contains both the primary key and the + * neighbor vector for each neighbor. Each neighbor is stored in its own key/value pair. + */ + INLINING((byte)0x01); + + private final byte serialized; + + /** + * Constructs a new {@code NodeKind} instance with its serialized representation. + * @param serialized the byte value used for serialization + */ + NodeKind(final byte serialized) { + this.serialized = serialized; + } + + /** + * Gets the serialized byte value. + * @return the serialized byte value + */ + public byte getSerialized() { + return serialized; + } + + /** + * Deserializes a byte into the corresponding {@link NodeKind}. + * @param serializedNodeKind the byte representation of the node kind. + * @return the corresponding {@link NodeKind}, never {@code null}. + * @throws IllegalArgumentException if the {@code serializedNodeKind} does not + * correspond to a known node kind. + */ + @Nonnull + static NodeKind fromSerializedNodeKind(byte serializedNodeKind) { + final NodeKind nodeKind; + switch (serializedNodeKind) { + case 0x00: + nodeKind = NodeKind.COMPACT; + break; + case 0x01: + nodeKind = NodeKind.INLINING; + break; + default: + throw new IllegalArgumentException("unknown node kind"); + } + Verify.verify(nodeKind.getSerialized() == serializedNodeKind); + return nodeKind; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java new file mode 100644 index 0000000000..a302607a2c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java @@ -0,0 +1,117 @@ +/* + * NodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.Streams; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents a reference to a node, uniquely identified by its primary key. It provides fundamental operations such as + * equality comparison, hashing, and string representation based on this key. It also serves as a base class for more + * specialized node references. + */ +public class NodeReference { + @Nonnull + private final Tuple primaryKey; + + /** + * Constructs a new {@code NodeReference} with the specified primary key. + * @param primaryKey the primary key of the node to reference; must not be {@code null}. + */ + public NodeReference(@Nonnull final Tuple primaryKey) { + this.primaryKey = primaryKey; + } + + /** + * Gets the primary key for this object. + * @return the primary key as a {@code Tuple} object, which is guaranteed to be non-null.
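The byte tags above round-trip through {@code fromSerializedNodeKind}; a quick sketch (illustrative only, callable as written from within the package):

```java
byte tag = NodeKind.INLINING.getSerialized();          // 0x01
NodeKind kind = NodeKind.fromSerializedNodeKind(tag);  // NodeKind.INLINING
// An unknown tag such as 0x7F would raise IllegalArgumentException.
```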
+ */ + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + /** + * Casts this object to a {@link NodeReferenceWithVector}. + *
+ * <p>
+ * This method is intended to be used on subclasses that actually represent a node reference with a vector. For this + * base class or specific implementation, it is not a valid operation. + * @return this instance cast as a {@code NodeReferenceWithVector} + * @throws IllegalStateException always, to indicate that this object cannot be + * represented as a {@link NodeReferenceWithVector}. + */ + @Nonnull + public NodeReferenceWithVector asNodeReferenceWithVector() { + throw new IllegalStateException("method should not be called"); + } + + /** + * Compares this {@code NodeReference} to the specified object for equality. + *
+ * <p>
+ * The result is {@code true} if and only if the argument is not {@code null} and is a {@code NodeReference} object + * that has the same {@code primaryKey} as this object. + * + * @param o the object to compare with this {@code NodeReference} for equality. + * @return {@code true} if the given object is equal to this one; + * {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReference)) { + return false; + } + final NodeReference that = (NodeReference)o; + return Objects.equals(primaryKey, that.primaryKey); + } + + /** + * Generates a hash code for this object based on the primary key. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hashCode(primaryKey); + } + + /** + * Returns a string representation of the object. + * @return a string representation of this object. + */ + @Override + public String toString() { + return "NR[primaryKey=" + primaryKey + "]"; + } + + /** + * Helper to extract the primary keys from a given collection of node references. + * @param neighbors an iterable of {@link NodeReference} objects from which to extract primary keys. + * @return a lazily-evaluated {@code Iterable} of {@link Tuple}s, representing the primary keys of the input nodes. + */ + @Nonnull + public static Iterable primaryKeys(@Nonnull Iterable neighbors) { + return () -> Streams.stream(neighbors) + .map(NodeReference::getPrimaryKey) + .iterator(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java new file mode 100644 index 0000000000..1a2053133d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -0,0 +1,86 @@ +/* + * NodeReferenceAndNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +/** + * A container class that pairs a {@link NodeReferenceWithDistance} with its corresponding {@link Node} object. + *
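A short, hypothetical example of the {@code primaryKeys} helper defined above (not part of the patch):

```java
List<NodeReference> neighbors = List.of(
        new NodeReference(Tuple.from(1L)),
        new NodeReference(Tuple.from(2L)));

// Lazily maps each reference to its primary key.
for (Tuple primaryKey : NodeReference.primaryKeys(neighbors)) {
    System.out.println(primaryKey); // prints the two keys in order
}
```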
+ * <p>
+ * This is often used during graph traversal or searching, where a reference to a node (along with its distance from a + * query point) is first identified, and then the complete node data is fetched. This class holds these two related + * pieces of information together. + * @param <N> the type of {@link NodeReference} used within the {@link Node} + */ +public class NodeReferenceAndNode<N extends NodeReference> { + @Nonnull + private final NodeReferenceWithDistance nodeReferenceWithDistance; + @Nonnull + private final Node<N> node; + + /** + * Constructs a new instance that pairs a node reference (with distance) with its + * corresponding {@link Node} object. + * @param nodeReferenceWithDistance the reference to a node, which also includes distance information. Must not be + * {@code null}. + * @param node the actual {@code Node} object that the reference points to. Must not be {@code null}. + */ + public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, @Nonnull final Node<N> node) { + this.nodeReferenceWithDistance = nodeReferenceWithDistance; + this.node = node; + } + + /** + * Gets the node reference and its associated distance. + * @return the non-null {@link NodeReferenceWithDistance} object. + */ + @Nonnull + public NodeReferenceWithDistance getNodeReferenceWithDistance() { + return nodeReferenceWithDistance; + } + + /** + * Gets the underlying node represented by this object. + * @return the associated {@link Node} instance, never {@code null}. + */ + @Nonnull + public Node<N> getNode() { + return node; + } + + /** + * Helper to extract the references from a given collection of objects of this container class. + * @param referencesAndNodes an iterable of {@link NodeReferenceAndNode} objects from which to extract the + * references. + * @return a {@link List} of the {@link NodeReferenceWithDistance}s contained in the given list + */ + @Nonnull + public static <N extends NodeReference> List<NodeReferenceWithDistance> getReferences(@Nonnull List<NodeReferenceAndNode<N>> referencesAndNodes) { + final ImmutableList.Builder<NodeReferenceWithDistance> referencesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode<N> referenceWithNode : referencesAndNodes) { + referencesBuilder.add(referenceWithNode.getNodeReferenceWithDistance()); + } + return referencesBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java new file mode 100644 index 0000000000..7b46f65f69 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -0,0 +1,92 @@ +/* + * NodeReferenceWithDistance.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents a reference to a node that includes its vector and its distance from a query vector. + *
+ * <p>
+ * This class extends {@link NodeReferenceWithVector} by additionally associating a distance value, typically the result + * of a distance calculation in a nearest neighbor search. Objects of this class are immutable. + */ +public class NodeReferenceWithDistance extends NodeReferenceWithVector { + private final double distance; + + /** + * Constructs a new instance of {@code NodeReferenceWithDistance}. + *
+ * <p>
+ * This constructor initializes the reference with the node's primary key, its vector, and the calculated distance + * from some origin vector (e.g., a query vector). It calls the superclass constructor to set the {@code primaryKey} + * and {@code vector}. + * @param primaryKey the primary key of the referenced node, represented as a {@link Tuple}. Must not be null. + * @param vector the vector associated with the referenced node. Must not be null. + * @param distance the calculated distance of this node reference to some query vector or similar. + */ + public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final double distance) { + super(primaryKey, vector); + this.distance = distance; + } + + /** + * Gets the distance. + * @return the current distance value + */ + public double getDistance() { + return distance; + } + + /** + * Compares this object against the specified object for equality. + *
+ * <p>
+ * The result is {@code true} if and only if the argument is not {@code null}, + * is a {@code NodeReferenceWithDistance} object, has the same properties as + * determined by the superclass's {@link #equals(Object)} method, and has + * the same {@code distance} value. + * @param o the object to compare with this instance for equality. + * @return {@code true} if the specified object is equal to this {@code NodeReferenceWithDistance}; + * {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithDistance)) { + return false; + } + if (!super.equals(o)) { + return false; + } + final NodeReferenceWithDistance that = (NodeReferenceWithDistance)o; + return Double.compare(distance, that.distance) == 0; + } + + /** + * Generates a hash code for this object. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), distance); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java new file mode 100644 index 0000000000..7b29bedb09 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -0,0 +1,123 @@ +/* + * NodeReferenceWithVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Objects; + +import javax.annotation.Nonnull; + +/** + * Represents a reference to a node that includes an associated vector. + *
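A sketch of ranking candidates by their stored distance, as done during nearest-neighbor search. The vector construction uses only APIs visible in this patch ({@code Half.shortBitsToHalf}, {@code Vector.HalfVector}); the generic parameter on {@code Vector} is an assumption, since generics are garbled in this hunk:

```java
Half one = Half.shortBitsToHalf((short)0x3C00); // 1.0 in IEEE 754 half precision
Vector<Half> vector = new Vector.HalfVector(new Half[] {one, one});

List<NodeReferenceWithDistance> candidates = new ArrayList<>(List.of(
        new NodeReferenceWithDistance(Tuple.from("far"), vector, 0.9),
        new NodeReferenceWithDistance(Tuple.from("near"), vector, 0.1)));
candidates.sort(Comparator.comparingDouble(NodeReferenceWithDistance::getDistance));
// candidates now begins with the reference whose distance is 0.1
```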
+ * <p>
+ * This class extends {@link NodeReference} by adding a {@link Vector} field. It encapsulates both the primary key + * of a node and its corresponding vector data, which is particularly useful in vector-based search and + * indexing scenarios. Primarily, node references are used to refer to {@link Node}s in a storage-independent way, i.e. + * a node reference always contains the vector of a node while the node itself (depending on the storage adapter) + * may not. + */ +public class NodeReferenceWithVector extends NodeReference { + @Nonnull + private final Vector vector; + + /** + * Constructs a new {@code NodeReferenceWithVector} with a specified primary key and vector. + *
+ * <p>
+ * The primary key is used to initialize the parent class via a call to {@code super()}, + * while the vector is stored as a field in this instance. Both parameters are expected + * to be non-null. + * + * @param primaryKey the primary key of the node, must not be null + * @param vector the vector associated with the node, must not be null + */ + public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { + super(primaryKey); + this.vector = vector; + } + + /** + * Gets the vector of {@code Half} objects. + *
+ * <p>
+ * This method provides access to the internal vector. The returned vector is guaranteed + * not to be null, as indicated by the {@code @Nonnull} annotation. + * + * @return the vector of {@code Half} objects; will never be {@code null}. + */ + @Nonnull + public Vector getVector() { + return vector; + } + + /** + * Gets the vector as a {@code Vector} of {@code Double}s. + * @return a non-null {@code Vector} containing the elements of this vector. + */ + @Nonnull + public Vector.DoubleVector getDoubleVector() { + return vector.toDoubleVector(); + } + + /** + * Returns this instance cast as a {@code NodeReferenceWithVector}. + * @return this instance as a {@code NodeReferenceWithVector}, which is never {@code null}. + */ + @Nonnull + @Override + public NodeReferenceWithVector asNodeReferenceWithVector() { + return this; + } + + /** + * Compares this {@code NodeReferenceWithVector} to the specified object for equality. + * @param o the object to compare with this {@code NodeReferenceWithVector}. + * @return {@code true} if the objects are equal; {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithVector)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equal(vector, ((NodeReferenceWithVector)o).vector); + } + + /** + * Computes the hash code for this object. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hashCode(super.hashCode(), vector); + } + + /** + * Returns a string representation of this object. + * @return a concise string representation of this object. + */ + @Override + public String toString() { + return "NRV[primaryKey=" + getPrimaryKey() + + ";vector=" + vector.toString(3) + + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java new file mode 100644 index 0000000000..f8a009d32b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java @@ -0,0 +1,78 @@ +/* + * OnReadListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; +import java.util.concurrent.CompletableFuture; + +/** + * Interface for call backs whenever we read node data from the database. + */ +public interface OnReadListener { + OnReadListener NOOP = new OnReadListener() { + }; + + /** + * A callback method that can be overridden to intercept the result of an asynchronous node read. + *
+ * <p>
+ * This method provides a hook for subclasses to inspect or modify the {@code CompletableFuture} after an + * asynchronous read operation is initiated. The default implementation is a no-op that simply returns the original + * future. This method is intended to be used to measure elapsed time between the creation of a + * {@link CompletableFuture} and its completion. + * @param the type of the {@code NodeReference} + * @param future the {@code CompletableFuture} representing the pending asynchronous read operation. + * @return a {@code CompletableFuture} that will complete with the read {@code Node}. + * By default, this is the same future that was passed as an argument. + */ + @SuppressWarnings("unused") + default CompletableFuture> onAsyncRead(@Nonnull CompletableFuture> future) { + return future; + } + + /** + * Callback method invoked when a node is read during a traversal process. + *
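A hypothetical listener that uses {@code onAsyncRead} to measure node-fetch latency, as the contract above suggests; the generic-method signature is an assumption, since type parameters are garbled in this hunk:

```java
// Measures elapsed time between future creation and completion.
class TimingOnReadListener implements OnReadListener {
    @Override
    public <N extends NodeReference> CompletableFuture<Node<N>> onAsyncRead(
            @Nonnull final CompletableFuture<Node<N>> future) {
        final long startNanos = System.nanoTime();
        return future.whenComplete((node, throwable) ->
                System.out.printf("node read completed in %d ns%n",
                        System.nanoTime() - startNanos));
    }
}
```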
+ * <p>
+ * This default implementation does nothing. Implementors can override this method to add custom logic that should + * be executed for each node encountered. This serves as an optional hook for processing nodes as they are read. + * @param layer the layer or depth of the node in the structure, starting from 0. + * @param node the {@link Node} that was just read (guaranteed to be non-null). + */ + @SuppressWarnings("unused") + default void onNodeRead(int layer, @Nonnull Node node) { + // nothing + } + + /** + * Callback invoked when a key-value pair is read from a specific layer. + *
+ * <p>
+ * This method is typically called during a scan or iteration over data for each key/value pair. + * The default implementation is a no-op and does nothing. + * @param layer the layer from which the key-value pair was read. + * @param key the key that was read, guaranteed to be non-null. + * @param value the value associated with the key, guaranteed to be non-null. + */ + @SuppressWarnings("unused") + default void onKeyValueRead(int layer, + @Nonnull byte[] key, + @Nonnull byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java new file mode 100644 index 0000000000..d645bf8421 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -0,0 +1,86 @@ +/* + * OnWriteListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; + +/** + * Interface for call backs whenever we write data to the database. + */ +public interface OnWriteListener { + OnWriteListener NOOP = new OnWriteListener() { + }; + + /** + * Callback method invoked after a node has been successfully written to a specific layer. + *
+ * <p>
+ * This is a default method with an empty implementation, allowing implementing classes to override it only if they + * need to react to this event. + * @param layer the index of the layer where the node was written. + * @param node the {@link Node} that was written; guaranteed to be non-null. + */ + @SuppressWarnings("unused") + default void onNodeWritten(final int layer, @Nonnull final Node node) { + // nothing + } + + /** + * Callback method invoked when a neighbor is written for a specific node. + *
+ * <p>
+ * This method serves as a notification that a neighbor relationship has been established or updated. It is + * typically called after a write operation successfully adds a {@code neighbor} to the specified {@code node} + * within a given {@code layer}. + *
+ * <p>
+ * As a {@code default} method, the base implementation does nothing. Implementers can override this to perform + * custom actions, such as updating caches or triggering subsequent events in response to the change. + * @param layer the index of the layer where the neighbor write operation occurred + * @param node the {@link Node} for which the neighbor was written; must not be null + * @param neighbor the {@link NodeReference} of the neighbor that was written; must not be null + */ + @SuppressWarnings("unused") + default void onNeighborWritten(final int layer, @Nonnull final Node node, + @Nonnull final NodeReference neighbor) { + // nothing + } + + /** + * Callback method invoked when a neighbor of a specific node is deleted. + *
+ * <p>
+ * This is a default method and its base implementation is a no-op. Implementors of the interface can override this + * method to react to the deletion of a neighbor node, for example, to clean up related resources or update internal + * state. + * @param layer the layer index where the deletion occurred + * @param node the {@link Node} whose neighbor was deleted + * @param neighborPrimaryKey the primary key (as a {@link Tuple}) of the neighbor that was deleted + */ + @SuppressWarnings("unused") + default void onNeighborDeleted(final int layer, @Nonnull final Node node, + @Nonnull final Tuple neighborPrimaryKey) { + // nothing + } + + @SuppressWarnings("unused") + default void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java new file mode 100644 index 0000000000..dedad69f21 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ -0,0 +1,455 @@ +/* + * StorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * Defines the contract for storing and retrieving HNSW graph data to/from a persistent store. + *
+ * <p>
+ * This interface provides an abstraction layer over the underlying database, handling the serialization and + * deserialization of HNSW graph components such as nodes, vectors, and their relationships. Implementations of this + * interface are responsible for managing the physical layout of data within a given {@link Subspace}. + * The generic type {@code N} represents the specific type of {@link NodeReference} that this storage adapter manages. + * + * @param the type of {@link NodeReference} this storage adapter manages + */ +interface StorageAdapter { + /** + * Subspace for entry nodes; these are kept separately from the data. + */ + byte SUBSPACE_PREFIX_ENTRY_NODE = 0x01; + /** + * Subspace for data. + */ + byte SUBSPACE_PREFIX_DATA = 0x02; + + /** + * Returns the configuration of the HNSW graph. + *
+ * <p>
+ * This configuration object contains all the parameters used to build and search the graph, + * such as the number of neighbors to connect (M), the size of the dynamic list for + * construction (efConstruction), and the beam width for searching (ef). + * @return the {@code HNSW.Config} for this graph, never {@code null}. + */ + @Nonnull + HNSW.Config getConfig(); + + /** + * Gets the factory used to create new nodes. + *
+ * <p>
+ * This factory is responsible for instantiating new nodes of type {@code N}. + * @return the non-null factory for creating nodes. + */ + @Nonnull + NodeFactory getNodeFactory(); + + /** + * Gets the kind of node this storage adapter manages (and instantiates if needed). + * @return the kind of this node, never {@code null} + */ + @Nonnull + NodeKind getNodeKind(); + + /** + * Returns a view of this object as a {@code StorageAdapter} that is optimized + * for compact data representation. + * @return a non-null {@code StorageAdapter} for {@code NodeReference} objects, + * optimized for compact storage. + */ + @Nonnull + StorageAdapter asCompactStorageAdapter(); + + /** + * Returns a view of this storage as a {@code StorageAdapter} that handles inlined vectors. + *
+ * <p>
+ * The returned adapter is specifically designed to work with {@link NodeReferenceWithVector}, assuming that the + * vector data is stored directly within the node reference itself. + * @return a non-null {@link StorageAdapter} + */ + @Nonnull + StorageAdapter asInliningStorageAdapter(); + + /** + * Get the subspace used to store this HNSW structure. + * @return the subspace + */ + @Nonnull + Subspace getSubspace(); + + /** + * Gets the subspace that contains the data for this object. + *
+ * <p>
+ * This subspace represents the portion of the keyspace dedicated to storing the actual data, as opposed to metadata + * or other system-level information. + * @return the subspace containing the data, which is guaranteed to be non-null + */ + @Nonnull + Subspace getDataSubspace(); + + /** + * Get the on-write listener. + * @return the on-write listener. + */ + @Nonnull + OnWriteListener getOnWriteListener(); + + /** + * Get the on-read listener. + * @return the on-read listener. + */ + @Nonnull + OnReadListener getOnReadListener(); + + /** + * Asynchronously fetches a node from a specific layer, identified by its primary key. + *
+ * <p>
+ * The fetch operation is performed within the scope of the provided {@link ReadTransaction}, ensuring a consistent + * view of the data. The returned {@link CompletableFuture} will be completed with the node once it has been + * retrieved from the underlying data store. + * @param readTransaction the {@link ReadTransaction} context for this read operation + * @param layer the layer from which to fetch the node + * @param primaryKey the {@link Tuple} representing the primary key of the node to retrieve + * @return a non-null {@link CompletableFuture} which will complete with the fetched {@code Node}. + */ + @Nonnull + CompletableFuture> fetchNode(@Nonnull ReadTransaction readTransaction, + int layer, + @Nonnull Tuple primaryKey); + + /** + * Writes a node and its neighbor changes to the data store within a given transaction. + *
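A hedged sketch of a read path through the adapter; the {@code db} handle and the adapter instance are assumptions ({@code Database.readAsync} is the standard FDB Java API), and the snippet is callable as written only from within this package:

```java
CompletableFuture<Node<NodeReference>> nodeFuture =
        db.readAsync(readTransaction ->
                storageAdapter.fetchNode(readTransaction, /* layer */ 0, Tuple.from(42L)));
```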
+ * <p>
+ * This method is responsible for persisting the state of a {@link Node} and applying any modifications to its + * neighboring nodes as defined in the {@code NeighborsChangeSet}. The entire operation is performed atomically as + * part of the provided {@link Transaction}. + * @param transaction the non-null transaction context for this write operation. + * @param node the non-null node to be written to the data store. + * @param layer the layer index where the node resides. + * @param changeSet the non-null set of changes describing additions or removals of + * neighbors for the given {@link Node}. + */ + void writeNode(@Nonnull Transaction transaction, @Nonnull Node<N> node, int layer, + @Nonnull NeighborsChangeSet<N> changeSet); + + /** + * Scans a specified layer of the HNSW structure, returning an iterable sequence of nodes. + *
+ * <p>
+ * This method allows for paginated scanning of a layer. The scan can be started from the beginning of the layer by + * passing {@code null} for the {@code lastPrimaryKey}, or it can be resumed from a previous point by providing the + * key of the last item from the prior scan. The number of nodes returned is limited by {@code maxNumRead}. + * + * @param readTransaction the transaction to use for the read operation + * @param layer the index of the layer to scan + * @param lastPrimaryKey the primary key of the last node from a previous scan, + * or {@code null} to start from the beginning of the layer + * @param maxNumRead the maximum number of nodes to return in this scan + * @return an {@link Iterable} that provides the nodes found in the specified layer range + */ + Iterable> scanLayer(@Nonnull ReadTransaction readTransaction, int layer, @Nullable Tuple lastPrimaryKey, + int maxNumRead); + + /** + * Fetches the entry node reference for the HNSW index. + *
+ * <p>
+ * This method performs an asynchronous read to retrieve the stored entry point of the index. The entry point + * information, which includes its primary key, vector, and the layer value, is packed into a single key-value + * pair within a dedicated subspace. If no entry node is found, it indicates that the index is empty. + * + * @param readTransaction the transaction to use for the read operation + * @param subspace the subspace where the HNSW index data is stored + * @param onReadListener a listener to be notified of the key-value read operation + * @return a {@link CompletableFuture} that will complete with the {@link EntryNodeReference} + * for the index's entry point, or with {@code null} if the index is empty + */ + @Nonnull + static CompletableFuture fetchEntryNodeReference(@Nonnull final ReadTransaction readTransaction, + @Nonnull final Subspace subspace, + @Nonnull final OnReadListener onReadListener) { + final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE)); + final byte[] key = entryNodeSubspace.pack(); + + return readTransaction.get(key) + .thenApply(valueBytes -> { + if (valueBytes == null) { + return null; // not a single node in the index + } + onReadListener.onKeyValueRead(-1, key, valueBytes); + + final Tuple entryTuple = Tuple.fromBytes(valueBytes); + final int layer = (int)entryTuple.getLong(0); + final Tuple primaryKey = entryTuple.getNestedTuple(1); + final Tuple vectorTuple = entryTuple.getNestedTuple(2); + return new EntryNodeReference(primaryKey, StorageAdapter.vectorFromTuple(vectorTuple), layer); + }); + } + + /** + * Writes an {@code EntryNodeReference} to the database within a given transaction and subspace. + *
+ * <p>
+ * This method serializes the provided {@link EntryNodeReference} into a key-value pair. The key is determined by + * a dedicated subspace for entry nodes, and the value is a tuple containing the layer, primary key, and vector from + * the reference. After writing the data, it notifies the provided {@link OnWriteListener}. + * @param transaction the database transaction to use for the write operation + * @param subspace the subspace where the entry node reference will be stored + * @param entryNodeReference the {@link EntryNodeReference} object to write + * @param onWriteListener the listener to be notified after the key-value pair is written + */ + static void writeEntryNodeReference(@Nonnull final Transaction transaction, + @Nonnull final Subspace subspace, + @Nonnull final EntryNodeReference entryNodeReference, + @Nonnull final OnWriteListener onWriteListener) { + final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE)); + final byte[] key = entryNodeSubspace.pack(); + final byte[] value = Tuple.from(entryNodeReference.getLayer(), + entryNodeReference.getPrimaryKey(), + StorageAdapter.tupleFromVector(entryNodeReference.getVector())).pack(); + transaction.set(key, + value); + onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value); + } + + /** + * Creates a {@code HalfVector} from a given {@code Tuple}. + *
+ * <p>
+ * This method assumes the vector data is stored as a byte array at the first position (index 0) of the tuple. It + * extracts this byte array and then delegates to the {@link #vectorFromBytes(byte[])} method for the actual + * conversion. + * @param vectorTuple the tuple containing the vector data as a byte array at index 0. Must not be {@code null}. + * @return a new {@code HalfVector} instance created from the tuple's data. + * This method never returns {@code null}. + */ + @Nonnull + static Vector vectorFromTuple(final Tuple vectorTuple) { + return vectorFromBytes(vectorTuple.getBytes(0)); + } + + /** + * Creates a {@link Vector} from a byte array. + *
+ * <p>
+ * This method decodes the input byte array by interpreting the first byte of the array as the precision shift. + * The byte array must have the proper size, i.e. the invariant {@code (bytesLength - 1) % precision == 0} must + * hold. + * @param vectorBytes the non-null byte array to convert. + * @return a new {@link Vector} instance created from the byte array. + * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} does not meet the invariant + * {@code (bytesLength - 1) % precision == 0} + */ + @Nonnull + static Vector vectorFromBytes(final byte[] vectorBytes) { + final int bytesLength = vectorBytes.length; + final int precisionShift = (int)vectorBytes[0]; + final int precision = 1 << precisionShift; + Verify.verify((bytesLength - 1) % precision == 0); + final int numDimensions = bytesLength >>> precisionShift; + switch (precisionShift) { + case 1: + return halfVectorFromBytes(vectorBytes, 1, numDimensions); + case 3: + return doubleVectorFromBytes(vectorBytes, 1, numDimensions); + default: + throw new RuntimeException("unable to deserialize vector"); + } + } + + /** + * Creates a {@link Vector.HalfVector} from a byte array. + *
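A round-trip sketch of the layout described above: one leading precision-shift byte (1 for 16-bit halfs, 3 for 64-bit doubles), followed by big-endian components. The static helpers live on the package-private {@code StorageAdapter} interface, so this is callable as written only from within the package:

```java
Half one = Half.shortBitsToHalf((short)0x3C00); // 1.0 in half precision
Vector.HalfVector vector = new Vector.HalfVector(new Half[] {one, one});

byte[] bytes = StorageAdapter.bytesFromVector(vector); // length 1 + 2 * 2 = 5
// bytes[0] == 0x01, i.e. precision shift 1 (16-bit components)
Vector.HalfVector decoded =
        (Vector.HalfVector)StorageAdapter.vectorFromBytes(bytes);
```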
+ * <p>
+ * This method interprets the input byte array as a sequence of 16-bit half-precision floating-point numbers. Each + * consecutive pair of bytes is converted into a {@code Half} value, which then becomes a component of the resulting + * vector. + * @param vectorBytes the non-null byte array to convert. The array must contain at least + * {@code offset + 2 * numDimensions} bytes, as each pair of bytes represents a single {@link Half} component. + * @return a new {@link Vector.HalfVector} instance created from the byte array. + */ + @Nonnull + static Vector.HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes, final int offset, final int numDimensions) { + final Half[] vectorHalfs = new Half[numDimensions]; + for (int i = 0; i < numDimensions; i ++) { + vectorHalfs[i] = Half.shortBitsToHalf(shortFromBytes(vectorBytes, offset + (i << 1))); + } + return new Vector.HalfVector(vectorHalfs); + } + + /** + * Creates a {@link Vector.DoubleVector} from a byte array. + *
+ * <p>
+ * This method interprets the input byte array as a sequence of 64-bit double-precision floating-point numbers. Each + * run of eight bytes is converted into a {@code double} value, which then becomes a component of the resulting + * vector. + * @param vectorBytes the non-null byte array to convert. + * @return a new {@link Vector.DoubleVector} instance created from the byte array. + */ + @Nonnull + static Vector.DoubleVector doubleVectorFromBytes(@Nonnull final byte[] vectorBytes, int offset, final int numDimensions) { + final double[] vectorComponents = new double[numDimensions]; + for (int i = 0; i < numDimensions; i ++) { + vectorComponents[i] = Double.longBitsToDouble(longFromBytes(vectorBytes, offset + (i << 3))); + } + return new Vector.DoubleVector(vectorComponents); + } + + /** + * Converts a {@link Vector} into a {@link Tuple}. + *
+ * <p>
+ * This method first serializes the given vector into a byte array using the {@link Vector#getRawData()} getter + * method. It then creates a {@link Tuple} from the resulting byte array. + * @param vector the vector of {@code Half} precision floating-point numbers to convert. Cannot be null. + * @return a new, non-null {@code Tuple} instance representing the contents of the vector. + */ + @Nonnull + @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod") + static Tuple tupleFromVector(final Vector vector) { + return Tuple.from(vector.getRawData()); + } + + /** + * Converts a {@link Vector} of {@link Half} precision floating-point numbers into a byte array. + *
+ * <p>
+ * This method iterates through the input vector, converting each {@link Half} element into its 16-bit short + * representation. It then serializes this short into two bytes, placing them sequentially into the resulting byte + * array. The final array's length will be {@code 1 + 2 * vector.size()}, including the leading precision-shift byte. + * @param halfVector the vector of {@link Half} precision numbers to convert. Must not be null. + * @return a new byte array representing the serialized vector data. This array is never null. + */ + @Nonnull + static byte[] bytesFromVector(@Nonnull final Vector.HalfVector halfVector) { + final byte[] vectorBytes = new byte[1 + 2 * halfVector.size()]; + vectorBytes[0] = (byte)halfVector.precisionShift(); + for (int i = 0; i < halfVector.size(); i ++) { + final byte[] componentBytes = bytesFromShort(Half.halfToShortBits(Half.valueOf(halfVector.getComponent(i)))); + final int offset = 1 + (i << 1); + vectorBytes[offset] = componentBytes[0]; + vectorBytes[offset + 1] = componentBytes[1]; + } + return vectorBytes; + } + + /** + * Converts a {@link Vector} of {@code double} precision floating-point numbers into a byte array. + *
+ * <p>
+ * This method iterates through the input vector, converting each {@code double} element into its 64-bit long + * representation. It then serializes this long into eight bytes, placing them sequentially into the resulting byte + * array. The final array's length will be {@code 1 + 8 * vector.size()}, including the leading precision-shift byte. + * @param doubleVector the vector of {@code double} precision numbers to convert. Must not be null. + * @return a new byte array representing the serialized vector data. This array is never null. + */ + @Nonnull + static byte[] bytesFromVector(final Vector.DoubleVector doubleVector) { + final byte[] vectorBytes = new byte[1 + 8 * doubleVector.size()]; + vectorBytes[0] = (byte)doubleVector.precisionShift(); + for (int i = 0; i < doubleVector.size(); i ++) { + final byte[] componentBytes = bytesFromLong(Double.doubleToLongBits(doubleVector.getComponent(i))); + final int offset = 1 + (i << 3); + vectorBytes[offset] = componentBytes[0]; + vectorBytes[offset + 1] = componentBytes[1]; + vectorBytes[offset + 2] = componentBytes[2]; + vectorBytes[offset + 3] = componentBytes[3]; + vectorBytes[offset + 4] = componentBytes[4]; + vectorBytes[offset + 5] = componentBytes[5]; + vectorBytes[offset + 6] = componentBytes[6]; + vectorBytes[offset + 7] = componentBytes[7]; + } + return vectorBytes; + } + + /** + * Constructs a short from two bytes in a byte array in big-endian order. + *

+ * This method reads two consecutive bytes from the {@code bytes} array, starting at the given {@code offset}. The + * byte at {@code offset} is treated as the most significant byte (MSB), and the byte at {@code offset + 1} is the + * least significant byte (LSB). + * @param bytes the source byte array from which to read the short. + * @param offset the starting index in the byte array. + * @return the short value constructed from the two bytes. + */ + static short shortFromBytes(final byte[] bytes, final int offset) { + int high = bytes[offset] & 0xFF; // Convert to unsigned int + int low = bytes[offset + 1] & 0xFF; + + return (short) ((high << 8) | low); + } + + /** + * Converts a {@code short} value into a 2-element byte array. + *
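+     * <p>
+     * For example, the round trip through {@link #shortFromBytes(byte[], int)} is the identity:
+     * <pre>{@code
+     * byte[] b = bytesFromShort((short)0x1234);   // b[0] == 0x12, b[1] == 0x34
+     * short s = shortFromBytes(b, 0);             // s == 0x1234
+     * }</pre>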

+ * The conversion is performed in big-endian byte order, where the most significant byte (MSB) is placed at index 0 + * and the least significant byte (LSB) is at index 1. + * @param value the {@code short} value to be converted. + * @return a new 2-element byte array representing the short value in big-endian order. + */ + static byte[] bytesFromShort(final short value) { + byte[] result = new byte[2]; + result[0] = (byte) ((value >> 8) & 0xFF); // high byte first + result[1] = (byte) (value & 0xFF); // low byte second + return result; + } + + /** + * Constructs a long from eight bytes in a byte array in big-endian order. + *

+     * This method reads eight consecutive bytes from the {@code bytes} array, starting at the given {@code offset}.
+     * The byte array is treated as being in big-endian order.
+     * @param bytes the source byte array from which to read the long.
+     * @param offset the starting index in the byte array.
+     * @return the long value constructed from the eight bytes.
+     */
+    private static long longFromBytes(final byte[] bytes, final int offset) {
+        return ((bytes[offset    ] & 0xFFL) << 56) |
+               ((bytes[offset + 1] & 0xFFL) << 48) |
+               ((bytes[offset + 2] & 0xFFL) << 40) |
+               ((bytes[offset + 3] & 0xFFL) << 32) |
+               ((bytes[offset + 4] & 0xFFL) << 24) |
+               ((bytes[offset + 5] & 0xFFL) << 16) |
+               ((bytes[offset + 6] & 0xFFL) << 8) |
+               ((bytes[offset + 7] & 0xFFL));
+    }
+
+    /**
+     * Converts a {@code long} value into an 8-element byte array.
+     *
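+     * <p>
+     * The encoding matches {@code ByteBuffer}'s big-endian default, so, for example:
+     * <pre>{@code
+     * byte[] b = bytesFromLong(0x0102030405060708L);
+     * // b equals ByteBuffer.allocate(8).putLong(0x0102030405060708L).array()
+     * }</pre>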

+     * The conversion is performed in big-endian byte order.
+     * @param value the {@code long} value to be converted.
+     * @return a new 8-element byte array representing the long value in big-endian order.
+     */
+    @Nonnull
+    private static byte[] bytesFromLong(final long value) {
+        byte[] result = new byte[8];
+        result[0] = (byte)(value >>> 56);
+        result[1] = (byte)(value >>> 48);
+        result[2] = (byte)(value >>> 40);
+        result[3] = (byte)(value >>> 32);
+        result[4] = (byte)(value >>> 24);
+        result[5] = (byte)(value >>> 16);
+        result[6] = (byte)(value >>> 8);
+        result[7] = (byte)value;
+        return result;
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
new file mode 100644
index 0000000000..a2ad52b2fe
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
@@ -0,0 +1,507 @@
+/*
+ * Vector.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.christianheina.langx.half4j.Half;
+import com.google.common.base.Suppliers;
+import com.google.common.base.Verify;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.ImmutableList;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+/**
+ * An abstract base class representing a mathematical vector.
+ *
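+ * <p>
+ * A minimal usage sketch:
+ * <pre>{@code
+ * Vector v = new Vector.DoubleVector(new double[] {1.0, 2.0, 3.0});
+ * Vector.HalfVector h = v.toHalfVector();   // same components, Half (16-bit) serialization precision; memoized
+ * byte[] raw = v.getRawData();              // serialized form: precision byte plus big-endian components; memoized
+ * }</pre>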

+ * This class provides a common framework for vectors whose components are stored as {@code double} values but may
+ * be serialized at different precisions. It includes common operations and functionality such as size, component
+ * access, equality checks, and conversions. Concrete implementations must provide the specific logic for precision
+ * conversions and the raw data representation.
+ */
+public abstract class Vector {
+    @Nonnull
+    final double[] data;
+
+    @Nonnull
+    protected Supplier<Integer> hashCodeSupplier;
+
+    @Nonnull
+    private final Supplier<byte[]> toRawDataSupplier;
+
+    /**
+     * Constructs a new Vector with the given data.
+     *
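+     * <p>
+     * Note the aliasing contract, for example:
+     * <pre>{@code
+     * double[] data = new double[] {1.0, 2.0};
+     * Vector v = new Vector.DoubleVector(data);
+     * data[0] = 42.0;   // violates the contract -- it would silently change v
+     * }</pre>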

+     * This constructor uses the provided array directly as the backing store for the vector. It does not create a
+     * defensive copy, so any subsequent modification of the input array would be reflected in this vector's state.
+     * Callers must therefore not modify {@code data} after invoking this constructor; the array is deliberately not
+     * copied for performance reasons.
+     * @param data the components of this vector
+     * @throws NullPointerException if the provided {@code data} array is null.
+     */
+    public Vector(@Nonnull final double[] data) {
+        this.data = data;
+        this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode);
+        this.toRawDataSupplier = Suppliers.memoize(this::computeRawData);
+    }
+
+    /**
+     * Returns the number of elements in the vector.
+     * @return the number of elements
+     */
+    public int size() {
+        return data.length;
+    }
+
+    /**
+     * Gets the component of this object at the specified dimension.
+     *

+     * The dimension is a zero-based index. For a 3D vector, for example, dimension 0 might correspond to the
+     * x-component, 1 to the y-component, and 2 to the z-component. This method provides direct access to the
+     * underlying data element.
+     * @param dimension the zero-based index of the component to retrieve.
+     * @return the component at the specified dimension.
+     * @throws IndexOutOfBoundsException if the {@code dimension} is negative or
+     *         greater than or equal to the number of dimensions of this object.
+     */
+    double getComponent(int dimension) {
+        return data[dimension];
+    }
+
+    /**
+     * Returns the underlying data array.
+     *

+     * The returned array is guaranteed to be non-null. Note that this method
+     * returns a direct reference to the internal array, not a copy.
+     * @return the data array of type {@code double[]}, never {@code null}.
+     */
+    @Nonnull
+    public double[] getData() {
+        return data;
+    }
+
+    /**
+     * Gets the raw byte data representation of this object.
+     *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ + @Nonnull + public byte[] getRawData() { + return toRawDataSupplier.get(); + } + + /** + * Computes the raw byte data representation of this object. + *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ + @Nonnull + protected abstract byte[] computeRawData(); + + /** + * Converts this object into a {@code Vector} of {@link Half} precision floating-point numbers. + *

+ * As this is an abstract method, implementing classes are responsible for defining the specific conversion logic + * from their internal representation to a {@code Vector} of {@link Half} objects. If this object already is a + * {@code HalfVector} this method should return {@code this}. + * @return a non-null {@link Vector} containing the {@link Half} precision floating-point representation of this + * object. + */ + @Nonnull + public abstract HalfVector toHalfVector(); + + /** + * Converts this vector into a {@link DoubleVector}. + *

+ * This method provides a way to obtain a double-precision floating-point representation of the vector. If the + * vector is already an instance of {@code DoubleVector}, this method may return the instance itself. Otherwise, + * it will create a new {@code DoubleVector} containing the same elements, which may involve a conversion of the + * underlying data type. + * @return a non-null {@link DoubleVector} representation of this vector. + */ + @Nonnull + public abstract DoubleVector toDoubleVector(); + + /** + * Returns the number of bytes used for the serialization of this vector per component. + * @return the component size, i.e. the number of bytes used for the serialization of this vector per component. + */ + public int precision() { + return (1 << precisionShift()); + } + + /** + * Returns the number of bits we need to shift {@code 1} to express {@link #precision()} used for the serialization + * of this vector per component. + * @return returns the number of bits we need to shift {@code 1} to express {@link #precision()} + */ + public abstract int precisionShift(); + + /** + * Compares this vector to the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is not {@code null} and is a {@code Vector} object that + * has the same data elements as this object. This method performs a deep equality check on the underlying data + * elements using {@link Objects#deepEquals(Object, Object)}. + * @param o the object to compare with this {@code Vector} for equality. + * @return {@code true} if the given object is a {@code Vector} equivalent to this vector, {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!(o instanceof Vector)) { + return false; + } + final Vector vector = (Vector)o; + return Objects.deepEquals(data, vector.data); + } + + /** + * Returns a hash code value for this object. The hash code is computed once and memoized. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } + + /** + * Computes a hash code based on the internal {@code data} array. + * @return the computed hash code for this object. + */ + private int computeHashCode() { + return Arrays.hashCode(data); + } + + /** + * Returns a string representation of the object. + *

+     * This method provides a default string representation by calling
+     * {@link #toString(int)} with a predefined limit of 3 dimensions.
+     *
+     * @return a string representation of this object, truncated to at most 3 dimensions.
+     */
+    @Override
+    public String toString() {
+        return toString(3);
+    }
+
+    /**
+     * Generates a string representation of the data array, with an option to limit the number of dimensions shown.
+     *
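+     * <p>
+     * For example:
+     * <pre>{@code
+     * new DoubleVector(new double[] {1.0, 2.0, 3.0, 4.0}).toString(2);   // "[1.0,2.0, ...]"
+     * }</pre>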

+     * If the specified {@code limitDimensions} is less than the actual number of dimensions in the data array,
+     * the resulting string will be a truncated view, ending with {@code ", ..."} to indicate that more elements exist.
+     * Otherwise, the method returns a complete string representation of the entire array.
+     * @param limitDimensions the maximum number of array elements to include in the string. A non-positive
+     *        value will cause a {@link com.google.common.base.VerifyException}.
+     * @return a string representation of the data array, potentially truncated.
+     * @throws com.google.common.base.VerifyException if {@code limitDimensions} is not positive
+     */
+    public String toString(final int limitDimensions) {
+        Verify.verify(limitDimensions > 0);
+        if (limitDimensions < data.length) {
+            return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions))
+                    .mapToObj(String::valueOf)
+                    .collect(Collectors.joining(",")) + ", ...]";
+        } else {
+            return "[" + Arrays.stream(data)
+                    .mapToObj(String::valueOf)
+                    .collect(Collectors.joining(",")) + "]";
+        }
+    }
+
+    /**
+     * A vector class encoding a vector over half components. Conversion to {@link DoubleVector} is supported and
+     * memoized.
+     */
+    public static class HalfVector extends Vector {
+        @Nonnull
+        private final Supplier<DoubleVector> toDoubleVectorSupplier;
+
+        public HalfVector(@Nonnull final Half[] halfData) {
+            this(computeDoubleData(halfData));
+        }
+
+        public HalfVector(@Nonnull final double[] data) {
+            super(data);
+            this.toDoubleVectorSupplier = Suppliers.memoize(this::computeDoubleVector);
+        }
+
+        @Nonnull
+        @Override
+        public HalfVector toHalfVector() {
+            return this;
+        }
+
+        @Nonnull
+        @Override
+        public DoubleVector toDoubleVector() {
+            return toDoubleVectorSupplier.get();
+        }
+
+        @Nonnull
+        public DoubleVector computeDoubleVector() {
+            return new DoubleVector(data);
+        }
+
+        @Override
+        public int precisionShift() {
+            return 1;
+        }
+
+        @Nonnull
+        @Override
+        protected byte[] computeRawData() {
+            return StorageAdapter.bytesFromVector(this);
+        }
+
+        @Nonnull
+        private static double[] computeDoubleData(@Nonnull Half[] halfData) {
+            double[] result = new double[halfData.length];
+            for (int i = 0; i < halfData.length; i ++) {
+                result[i] = halfData[i].doubleValue();
+            }
+            return result;
+        }
+    }
+
+    /**
+     * A vector class encoding a vector over double components. Conversion to {@link HalfVector} is supported and
+     * memoized.
+     */
+    public static class DoubleVector extends Vector {
+        @Nonnull
+        private final Supplier<HalfVector> toHalfVectorSupplier;
+
+        public DoubleVector(@Nonnull final Double[] doubleData) {
+            this(computeDoubleData(doubleData));
+        }
+
+        public DoubleVector(@Nonnull final double[] data) {
+            super(data);
+            this.toHalfVectorSupplier = Suppliers.memoize(this::computeHalfVector);
+        }
+
+        @Nonnull
+        @Override
+        public HalfVector toHalfVector() {
+            return toHalfVectorSupplier.get();
+        }
+
+        @Nonnull
+        @Override
+        public DoubleVector toDoubleVector() {
+            return this;
+        }
+
+        @Nonnull
+        public HalfVector computeHalfVector() {
+            return new HalfVector(data);
+        }
+
+        @Override
+        public int precisionShift() {
+            return 3;
+        }
+
+        @Nonnull
+        @Override
+        protected byte[] computeRawData() {
+            return StorageAdapter.bytesFromVector(this);
+        }
+
+        @Nonnull
+        private static double[] computeDoubleData(@Nonnull Double[] doubleData) {
+            double[] result = new double[doubleData.length];
+            for (int i = 0; i < doubleData.length; i ++) {
+                result[i] = doubleData[i];
+            }
+            return result;
+        }
+    }
+
+    /**
+     * Calculates the distance between two vectors using a specified metric.
+     *
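+     * <p>
+     * For example, a sketch using this package's {@code Metric.EuclideanMetric}:
+     * <pre>{@code
+     * double d = Vector.distance(new Metric.EuclideanMetric(),
+     *                            new Vector.DoubleVector(new double[] {1.0, 2.0}),
+     *                            new Vector.DoubleVector(new double[] {4.0, 6.0}));   // 5.0
+     * }</pre>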

+     * This static utility method provides a convenient way to compute the distance by passing the primitive
+     * {@code double} arrays backing the {@code Vector} objects to the provided {@link Metric} instance, which
+     * performs the actual distance computation.
+     * @param metric the {@link Metric} to use for the distance calculation.
+     * @param vector1 the first vector.
+     * @param vector2 the second vector.
+     * @return the calculated distance between the two vectors as a {@code double}.
+     */
+    public static double distance(@Nonnull Metric metric,
+                                  @Nonnull final Vector vector1,
+                                  @Nonnull final Vector vector2) {
+        return metric.distance(vector1.getData(), vector2.getData());
+    }
+
+    /**
+     * Calculates the comparative distance between two vectors using a specified metric.
+     *

+     * This utility method passes the primitive {@code double} arrays backing the input vectors to the
+     * {@code comparativeDistance} method of the provided {@link Metric} object, which performs the actual distance
+     * computation.
+     * @param metric the {@link Metric} to use for the distance calculation. Must not be null.
+     * @param vector1 the first vector for the comparison. Must not be null.
+     * @param vector2 the second vector for the comparison. Must not be null.
+     * @return the calculated comparative distance as a {@code double}.
+     * @throws NullPointerException if {@code metric}, {@code vector1}, or {@code vector2} is null.
+     */
+    static double comparativeDistance(@Nonnull Metric metric,
+                                      @Nonnull final Vector vector1,
+                                      @Nonnull final Vector vector2) {
+        return metric.comparativeDistance(vector1.getData(), vector2.getData());
+    }
+
+    /**
+     * Abstract iterator implementation to read the IVecs/FVecs data format that is used by publicly available
+     * embedding datasets.
+     * @param <N> the component type of the vectors which must extend {@link Number}
+     * @param <T> the type of object this iterator creates and uses to represent a stored vector in memory
+     */
+    public abstract static class StoredVecsIterator<N extends Number, T> extends AbstractIterator<T> {
+        @Nonnull
+        private final FileChannel fileChannel;
+
+        protected StoredVecsIterator(@Nonnull final FileChannel fileChannel) {
+            this.fileChannel = fileChannel;
+        }
+
+        @Nonnull
+        protected abstract N[] newComponentArray(int size);
+
+        @Nonnull
+        protected abstract N toComponent(@Nonnull ByteBuffer byteBuffer);
+
+        @Nonnull
+        protected abstract T toTarget(@Nonnull N[] components);
+
+        @Nullable
+        @Override
+        protected T computeNext() {
+            try {
+                // each record starts with a 4-byte little-endian header holding the number of dimensions
+                final ByteBuffer headerBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
+                final int bytesRead = fileChannel.read(headerBuf);
+                if (bytesRead < 4) {
+                    if (bytesRead == -1) {
+                        return endOfData();
+                    }
+                    throw new IOException("corrupt vecs file");
+                }
+                headerBuf.flip();
+                final int dims = headerBuf.getInt();
+                if (dims <= 0) {
+                    throw new IOException("Invalid dimension " + dims + " at position " + (fileChannel.position() - 4));
+                }
+                final ByteBuffer vecBuf = ByteBuffer.allocate(dims * 4).order(ByteOrder.LITTLE_ENDIAN);
+                while (vecBuf.hasRemaining()) {
+                    int read = fileChannel.read(vecBuf);
+                    if (read < 0) {
+                        throw new EOFException("unexpected EOF when reading vector data");
+                    }
+                }
+                vecBuf.flip();
+                final N[] rawVecData = newComponentArray(dims);
+                for (int i = 0; i < dims; i++) {
+                    rawVecData[i] = toComponent(vecBuf);
+                }
+
+                return toTarget(rawVecData);
+            } catch (final IOException ioE) {
+                throw new RuntimeException(ioE);
+            }
+        }
+    }
+
+    /**
+     * Iterator to read floating point vectors from a {@link FileChannel} providing an iterator of
+     * {@link DoubleVector}s.
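+     * <p>
+     * A minimal reading sketch ({@code path} and {@code process} are placeholders):
+     * <pre>{@code
+     * try (FileChannel channel = FileChannel.open(path, StandardOpenOption.READ)) {
+     *     Iterator<DoubleVector> vectors = new StoredFVecsIterator(channel);
+     *     while (vectors.hasNext()) {
+     *         process(vectors.next());
+     *     }
+     * }
+     * }</pre>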
+     */
+    public static class StoredFVecsIterator extends StoredVecsIterator<Double, DoubleVector> {
+        public StoredFVecsIterator(@Nonnull final FileChannel fileChannel) {
+            super(fileChannel);
+        }
+
+        @Nonnull
+        @Override
+        protected Double[] newComponentArray(final int size) {
+            return new Double[size];
+        }
+
+        @Nonnull
+        @Override
+        protected Double toComponent(@Nonnull final ByteBuffer byteBuffer) {
+            return (double)byteBuffer.getFloat();
+        }
+
+        @Nonnull
+        @Override
+        protected DoubleVector toTarget(@Nonnull final Double[] components) {
+            return new DoubleVector(components);
+        }
+    }
+
+    /**
+     * Iterator to read vectors from a {@link FileChannel} into a list of integers.
+     */
+    public static class StoredIVecsIterator extends StoredVecsIterator<Integer, List<Integer>> {
+        public StoredIVecsIterator(@Nonnull final FileChannel fileChannel) {
+            super(fileChannel);
+        }
+
+        @Nonnull
+        @Override
+        protected Integer[] newComponentArray(final int size) {
+            return new Integer[size];
+        }
+
+        @Nonnull
+        @Override
+        protected Integer toComponent(@Nonnull final ByteBuffer byteBuffer) {
+            return byteBuffer.getInt();
+        }
+
+        @Nonnull
+        @Override
+        protected List<Integer> toTarget(@Nonnull final Integer[] components) {
+            return ImmutableList.copyOf(components);
+        }
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java
new file mode 100644
index 0000000000..791fd0728a
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java
@@ -0,0 +1,24 @@
+/*
+ * package-info.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Classes and interfaces related to the HNSW implementation as used for vector indexes.
+ */
+package com.apple.foundationdb.async.hnsw;
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
index f60c17da63..2623cff1dc 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
@@ -36,7 +36,6 @@
  * Storage adapter used for serialization and deserialization of nodes.
  */
 interface StorageAdapter {
-
     /**
      * Get the {@link RTree.Config} associated with this storage adapter.
* @return the configuration used by this storage adapter diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java new file mode 100644 index 0000000000..f138fd8417 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java @@ -0,0 +1,75 @@ +/* + * HNSWHelpersTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSWHelpersTest { + @Test + public void bytesToHex_MultipleBytesWithLeadingZeros_ReturnsTrimmedHexTest() { + final byte[] bytes = new byte[] {0, 1, 16, (byte)255}; // Represents 000110FF + final String result = HNSWHelpers.bytesToHex(bytes); + assertEquals("0x110FF", result); + } + + @Test + public void bytesToHex_NegativeByteValues_ReturnsCorrectUnsignedHexTest() { + final byte[] bytes = new byte[] {-1, -2}; // 0xFFFE + final String result = HNSWHelpers.bytesToHex(bytes); + assertEquals("0xFFFE", result); + } + + @Test + public void halfValueOf_NegativeFloat_ReturnsCorrectHalfValue_Test() { + final float inputValue = -56.75f; + final Half expected = Half.valueOf(inputValue); + final Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_PositiveFloat_ReturnsCorrectHalfValue_Test() { + final float inputValue = 123.4375f; + Half expected = Half.valueOf(inputValue); + Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_NegativeDouble_ReturnsCorrectHalfValue_Test() { + final double inputValue = -56.75d; + final Half expected = Half.valueOf(inputValue); + final Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_PositiveDouble_ReturnsCorrectHalfValue_Test() { + final double inputValue = 123.4375d; + Half expected = Half.valueOf(inputValue); + Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java new file mode 100644 index 0000000000..6f9515d8e9 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -0,0 +1,541 @@ +/* + * HNSWTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. 
and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Database;
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.async.hnsw.Vector.HalfVector;
+import com.apple.foundationdb.async.rtree.RTree;
+import com.apple.foundationdb.test.TestDatabaseExtension;
+import com.apple.foundationdb.test.TestExecutors;
+import com.apple.foundationdb.test.TestSubspaceExtension;
+import com.apple.foundationdb.tuple.Tuple;
+import com.apple.test.Tags;
+import com.google.common.base.Verify;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.ObjectArrays;
+import com.google.common.collect.Sets;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
+
+/**
+ * Tests covering insertion into and nearest-neighbor search of {@link HNSW} structures, modeled after the
+ * corresponding {@link RTree} tests.
+ */ +@Execution(ExecutionMode.CONCURRENT) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +@Tag(Tags.RequiresFDB) +@Tag(Tags.Slow) +public class HNSWTest { + private static final Logger logger = LoggerFactory.getLogger(HNSWTest.class); + private static final int NUM_TEST_RUNS = 5; + private static final int NUM_SAMPLES = 10_000; + + @RegisterExtension + static final TestDatabaseExtension dbExtension = new TestDatabaseExtension(); + @RegisterExtension + TestSubspaceExtension rtSubspace = new TestSubspaceExtension(dbExtension); + @RegisterExtension + TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); + + private Database db; + + @BeforeEach + public void setUpDb() { + db = dbExtension.getDatabase(); + } + + private static Stream randomSeeds() { + return LongStream.generate(() -> new Random().nextLong()) + .limit(5) + .boxed(); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + public void testCompactSerialization(final long seed) { + final Random random = new Random(seed); + final CompactStorageAdapter storageAdapter = + new CompactStorageAdapter(HNSW.DEFAULT_CONFIG, CompactNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomCompactNode = + createRandomCompactNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomCompactNode, 0); + return randomCompactNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> { + Assertions.assertAll( + () -> Assertions.assertInstanceOf(CompactNode.class, node), + () -> Assertions.assertEquals(NodeKind.COMPACT, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> Assertions.assertEquals(node.asCompactNode().getVector(), + originalNode.asCompactNode().getVector()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertEquals(neighbors, originalNeighbors); + } + ); + }).join()); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + public void testInliningSerialization(final long seed) { + final Random random = new Random(seed); + final InliningStorageAdapter storageAdapter = + new InliningStorageAdapter(HNSW.DEFAULT_CONFIG, InliningNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomInliningNode = + createRandomInliningNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomInliningNode, 0); + return randomInliningNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> Assertions.assertAll( + () -> Assertions.assertInstanceOf(InliningNode.class, node), + () -> Assertions.assertEquals(NodeKind.INLINING, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + 
neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); // should not be necessary given the order in which neighbors are stored
+                            final ArrayList originalNeighbors =
+                                    Lists.newArrayList(originalNode.getNeighbors());
+                            originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey));
+                            Assertions.assertEquals(neighbors, originalNeighbors);
+                        }
+                )).join());
+    }
+
+    static Stream<Arguments> randomSeedsWithOptions() {
+        return Sets.cartesianProduct(ImmutableSet.of(true, false),
+                        ImmutableSet.of(true, false),
+                        ImmutableSet.of(true, false))
+                .stream()
+                .flatMap(arguments ->
+                        LongStream.generate(() -> new Random().nextLong())
+                                .limit(2)
+                                .mapToObj(seed -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray()))));
+    }
+
+    @ParameterizedTest(name = "seed={0} useInlining={1} extendCandidates={2} keepPrunedConnections={3}")
+    @MethodSource("randomSeedsWithOptions")
+    public void testBasicInsert(final long seed, final boolean useInlining, final boolean extendCandidates,
+                                final boolean keepPrunedConnections) {
+        final Random random = new Random(seed);
+        final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
+        final AtomicLong nextNodeIdAtomic = new AtomicLong(0L);
+
+        final TestOnReadListener onReadListener = new TestOnReadListener();
+
+        final int dimensions = 128;
+        final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
+                HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric)
+                        .setUseInlining(useInlining).setExtendCandidates(extendCandidates)
+                        .setKeepPrunedConnections(keepPrunedConnections)
+                        .setM(32).setMMax(32).setMMax0(64).build(),
+                OnWriteListener.NOOP, onReadListener);
+
+        final int k = 10;
+        final HalfVector queryVector = VectorTest.createRandomHalfVector(random, dimensions);
+        final TreeSet<NodeReferenceWithDistance> nodesOrderedByDistance =
+                new TreeSet<>(Comparator.comparing(NodeReferenceWithDistance::getDistance));
+
+        for (int i = 0; i < 1000;) {
+            i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener,
+                    tr -> {
+                        final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic);
+                        final HalfVector dataVector = VectorTest.createRandomHalfVector(random, dimensions);
+                        final double distance = Vector.comparativeDistance(metric, dataVector, queryVector);
+                        final NodeReferenceWithDistance nodeReferenceWithDistance =
+                                new NodeReferenceWithDistance(primaryKey, dataVector, distance);
+                        nodesOrderedByDistance.add(nodeReferenceWithDistance);
+                        if (nodesOrderedByDistance.size() > k) {
+                            nodesOrderedByDistance.pollLast();
+                        }
+                        return nodeReferenceWithDistance;
+                    });
+        }
+
+        onReadListener.reset();
+        final long beginTs = System.nanoTime();
+        final List<NodeReferenceAndNode> results =
+                db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join());
+        final long endTs = System.nanoTime();
+
+        final ImmutableSet<Tuple> trueNN =
+                ImmutableSet.copyOf(NodeReference.primaryKeys(nodesOrderedByDistance));
+
+        int recallCount = 0;
+        for (NodeReferenceAndNode nodeReferenceAndNode : results) {
+            final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance();
+            logger.info("nodeId={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0),
+                    nodeReferenceWithDistance.getDistance());
+            if (trueNN.contains(nodeReferenceAndNode.getNode().getPrimaryKey())) {
+                recallCount ++;
+            }
+        }
+        final double recall = (double)recallCount / (double)k;
+        Assertions.assertTrue(recall > 0.93);
+
+        
logger.info("search transaction took elapsedTime={}ms; read nodes={}, read bytes={}, recall={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), + String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + + final Set usedIds = + LongStream.range(0, 1000) + .boxed() + .collect(Collectors.toSet()); + + hnsw.scanLayer(db, 0, 100, node -> Assertions.assertTrue(usedIds.remove(node.getPrimaryKey().getLong(0)))); + } + + private int basicInsertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference == null) { + return i; + } + hnsw.insert(tr, newNodeReference).join(); + } + final long endTs = System.nanoTime(); + logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", + batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + private int insertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + final ImmutableList.Builder nodeReferenceWithVectorBuilder = + ImmutableList.builder(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference != null) { + nodeReferenceWithVectorBuilder.add(newNodeReference); + } + } + hnsw.insertBatch(tr, nodeReferenceWithVectorBuilder.build()).join(); + final long endTs = System.nanoTime(); + logger.info("inserted batch batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", + batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + @Test + @Timeout(value = 10, unit = TimeUnit.MINUTES) + public void testSIFTInsertSmall() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 100; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); + + try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new Vector.StoredFVecsIterator(fileChannel); + + int i = 0; + while (vectorIterator.hasNext()) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + if (!vectorIterator.hasNext()) { + return null; + } + final Vector.DoubleVector doubleVector = vectorIterator.next(); + final Tuple currentPrimaryKey = 
createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = doubleVector.toHalfVector(); + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + }); + } + } + + validateSIFTSmall(hnsw, k); + } + + private void validateSIFTSmall(@Nonnull final HNSW hnsw, final int k) throws IOException { + final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs"); + final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs"); + + final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener(); + + try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ); + final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) { + final Iterator queryIterator = new Vector.StoredFVecsIterator(queryChannel); + final Iterator> groundTruthIterator = new Vector.StoredIVecsIterator(groundTruthChannel); + + Verify.verify(queryIterator.hasNext() == groundTruthIterator.hasNext()); + + while (queryIterator.hasNext()) { + final HalfVector queryVector = queryIterator.next().toHalfVector(); + final Set groundTruthIndices = ImmutableSet.copyOf(groundTruthIterator.next()); + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join()); + final long endTs = System.nanoTime(); + logger.info("retrieved result in elapsedTimeMs={}, reading numNodes={}, readBytes={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + + int recallCount = 0; + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = + nodeReferenceAndNode.getNodeReferenceWithDistance(); + final int primaryKeyIndex = (int)nodeReferenceWithDistance.getPrimaryKey().getLong(0); + logger.trace("retrieved result nodeId = {} at distance = {} ", + primaryKeyIndex, nodeReferenceWithDistance.getDistance()); + if (groundTruthIndices.contains(primaryKeyIndex)) { + recallCount ++; + } + } + + final double recall = (double)recallCount / k; + Assertions.assertTrue(recall > 0.93); + + logger.info("query returned results recall={}", String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + } + } + } + + @Test + @Timeout(value = 10, unit = TimeUnit.MINUTES) + public void testSIFTInsertSmallUsingBatchAPI() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 100; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); + + try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new Vector.StoredFVecsIterator(fileChannel); + + int i = 0; + while (vectorIterator.hasNext()) { + i += insertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + if (!vectorIterator.hasNext()) { + return null; + } + final Vector.DoubleVector doubleVector = vectorIterator.next(); + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final 
HalfVector currentVector = doubleVector.toHalfVector(); + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + }); + } + } + validateSIFTSmall(hnsw, k); + } + + @Test + public void testManyRandomVectors() { + final Random random = new Random(); + for (long l = 0L; l < 3000000; l ++) { + final HalfVector randomVector = VectorTest.createRandomHalfVector(random, 768); + final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector); + final Vector roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple); + Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), randomVector, roundTripVector); + Assertions.assertEquals(randomVector, roundTripVector); + } + } + + private void writeNode(@Nonnull final Transaction transaction, + @Nonnull final StorageAdapter storageAdapter, + @Nonnull final Node node, + final int layer) { + final NeighborsChangeSet insertChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + node.getNeighbors()); + storageAdapter.writeNode(transaction, node, layer, insertChangeSet); + } + + @Nonnull + private Node createRandomCompactNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReference(random)); + } + + return nodeFactory.create(primaryKey, VectorTest.createRandomHalfVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private Node createRandomInliningNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReferenceWithVector(random, dimensionality)); + } + + return nodeFactory.create(primaryKey, VectorTest.createRandomHalfVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private NodeReference createRandomNodeReference(@Nonnull final Random random) { + return new NodeReference(createRandomPrimaryKey(random)); + } + + @Nonnull + private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull final Random random, final int dimensionality) { + return new NodeReferenceWithVector(createRandomPrimaryKey(random), VectorTest.createRandomHalfVector(random, dimensionality)); + } + + @Nonnull + private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { + return Tuple.from(random.nextLong()); + } + + @Nonnull + private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { + return Tuple.from(nextIdAtomic.getAndIncrement()); + } + + private static class TestOnReadListener implements OnReadListener { + final Map nodeCountByLayer; + final Map sumMByLayer; + final Map bytesReadByLayer; + + public TestOnReadListener() { + this.nodeCountByLayer = Maps.newConcurrentMap(); + this.sumMByLayer = Maps.newConcurrentMap(); + this.bytesReadByLayer = Maps.newConcurrentMap(); + } + + public Map getNodeCountByLayer() { + return nodeCountByLayer; + } + + public Map getBytesReadByLayer() { + return bytesReadByLayer; + } + + public Map getSumMByLayer() { + return sumMByLayer; + } + + public void reset() { + 
nodeCountByLayer.clear(); + bytesReadByLayer.clear(); + sumMByLayer.clear(); + } + + @Override + public void onNodeRead(final int layer, @Nonnull final Node node) { + nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + node.getNeighbors().size()); + } + + @Override + public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + bytesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + + key.length + value.length); + } + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java new file mode 100644 index 0000000000..78df74a7e4 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java @@ -0,0 +1,174 @@ +/* + * MetricTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MetricTest { + private final Metric.ManhattanMetric manhattanMetric = new Metric.ManhattanMetric(); + private final Metric.EuclideanMetric euclideanMetric = new Metric.EuclideanMetric(); + private final Metric.EuclideanSquareMetric euclideanSquareMetric = new Metric.EuclideanSquareMetric(); + private final Metric.CosineMetric cosineMetric = new Metric.CosineMetric(); + private Metric.DotProductMetric dotProductMetric; + + @BeforeEach + + public void setUp() { + dotProductMetric = new Metric.DotProductMetric(); + } + + @Test + public void manhattanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = manhattanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void manhattanMetricDistanceWithPositiveValueVectorsShouldReturnCorrectDistanceTest() { + // Arrange + double[] vector1 = {1.0, 2.0, 3.0}; + double[] vector2 = {4.0, 5.0, 6.0}; + double expectedDistance = 9.0; // |1-4| + |2-5| + |3-6| = 3 + 3 + 3 + + // Act + double actualDistance = manhattanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = euclideanMetric.distance(vector1, vector2); + + // Assert + 
assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { + // Arrange + double[] vector1 = {1.0, 2.0}; + double[] vector2 = {4.0, 6.0}; + double expectedDistance = 5.0; // sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5.0 + + // Act + double actualDistance = euclideanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanSquareMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = euclideanSquareMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanSquareMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { + // Arrange + double[] vector1 = {1.0, 2.0}; + double[] vector2 = {4.0, 6.0}; + double expectedDistance = 25.0; // (1-4)^2 + (2-6)^2 = 9 + 16 = 25.0 + + // Act + double actualDistance = euclideanSquareMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void cosineMetricDistanceWithIdenticalVectorsReturnsZeroTest() { + // Arrange + double[] vector1 = {5.0, 3.0, -2.0}; + double[] vector2 = {5.0, 3.0, -2.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = cosineMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void cosineMetricDistanceWithOrthogonalVectorsReturnsOneTest() { + // Arrange + double[] vector1 = {1.0, 0.0}; + double[] vector2 = {0.0, 1.0}; + double expectedDistance = 1.0; + + // Act + double actualDistance = cosineMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void dotProductMetricComparativeDistanceWithPositiveVectorsTest() { + double[] vector1 = {1.0, 2.0, 3.0}; + double[] vector2 = {4.0, 5.0, 6.0}; + double expected = -32.0; + + double actual = dotProductMetric.comparativeDistance(vector1, vector2); + + assertEquals(expected, actual, 0.00001); + } + + @Test + public void dotProductMetricComparativeDistanceWithOrthogonalVectorsReturnsZeroTest() { + double[] vector1 = {1.0, 0.0}; + double[] vector2 = {0.0, 1.0}; + double expected = -0.0; + + double actual = dotProductMetric.comparativeDistance(vector1, vector2); + + assertEquals(expected, actual, 0.00001); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java new file mode 100644 index 0000000000..fa7f27db21 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java @@ -0,0 +1,79 @@ +/* + * VectorTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.util.Random; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +public class VectorTest { + private static Stream randomSeeds() { + return LongStream.generate(() -> new Random().nextLong()) + .limit(5) + .boxed(); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + void testSerializationDeserializationHalfVector(final long seed) { + final Random random = new Random(seed); + final Vector.HalfVector randomVector = createRandomHalfVector(random, 128); + final Vector deserializedVector = StorageAdapter.vectorFromBytes(randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(Vector.HalfVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + void testSerializationDeserializationDoubleVector(final long seed) { + final Random random = new Random(seed); + final Vector.DoubleVector randomVector = createRandomDoubleVector(random, 128); + final Vector deserializedVector = StorageAdapter.vectorFromBytes(randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(Vector.DoubleVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @Nonnull + static Vector.HalfVector createRandomHalfVector(@Nonnull final Random random, final int dimensionality) { + final Half[] components = new Half[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // don't ask + components[d] = HNSWHelpers.halfValueOf(random.nextDouble()); + } + return new Vector.HalfVector(components); + } + + @Nonnull + static Vector.DoubleVector createRandomDoubleVector(@Nonnull final Random random, final int dimensionality) { + final double[] components = new double[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // don't ask + components[d] = random.nextDouble(); + } + return new Vector.DoubleVector(components); + } +} diff --git a/gradle/codequality/pmd-rules.xml b/gradle/codequality/pmd-rules.xml index 500ef17c69..4d8745d875 100644 --- a/gradle/codequality/pmd-rules.xml +++ b/gradle/codequality/pmd-rules.xml @@ -16,6 +16,7 @@ + diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c4e6482b97..419df00cd0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -37,6 +37,7 @@ generatedAnnotation = "1.3.2" grpc = "1.64.1" grpc-commonProtos = "2.37.0" guava = "33.3.1-jre" +half4j = "0.0.2" h2 = "1.3.148" icu = "69.1" lucene = "8.11.1" @@ -95,6 +96,7 @@ grpc-services = { module = "io.grpc:grpc-services", version.ref = "grpc" } grpc-stub = { module = "io.grpc:grpc-stub", version.ref = "grpc" } grpc-util = { module = "io.grpc:grpc-util", version.ref = "grpc" } guava = { module = "com.google.guava:guava", version.ref = "guava" } +half4j = 
{ module = "com.christianheina.langx:half4j", version.ref = "half4j"} icu = { module = "com.ibm.icu:icu4j", version.ref = "icu" } javaPoet = { module = "com.squareup:javapoet", version.ref = "javaPoet" } jsr305 = { module = "com.google.code.findbugs:jsr305", version.ref = "jsr305" } diff --git a/gradle/scripts/log4j-test.properties b/gradle/scripts/log4j-test.properties index 447ee2f55a..1ae7583751 100644 --- a/gradle/scripts/log4j-test.properties +++ b/gradle/scripts/log4j-test.properties @@ -26,7 +26,7 @@ appender.console.name = STDOUT appender.console.layout.type = PatternLayout appender.console.layout.pattern = %d [%level] %logger{1.} - %m %X%n%ex{full} -rootLogger.level = debug +rootLogger.level = info rootLogger.appenderRefs = stdout rootLogger.appenderRef.stdout.ref = STDOUT