diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java index 84e7db99ab..02570bc2fa 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -26,6 +26,7 @@ import com.apple.foundationdb.linear.Quantizer; import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Verify; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,6 +99,34 @@ public NodeFactory getNodeFactory() { return nodeFactory; } + @Override + public boolean isInliningStorageAdapter() { + final boolean isInliningStorageAdapter = getNodeFactory().getNodeKind() == NodeKind.INLINING; + Verify.verify(!isInliningStorageAdapter || this instanceof InliningStorageAdapter); + return isInliningStorageAdapter; + } + + @Nonnull + @Override + public InliningStorageAdapter asInliningStorageAdapter() { + Verify.verify(isInliningStorageAdapter()); + return (InliningStorageAdapter)this; + } + + @Override + public boolean isCompactStorageAdapter() { + final boolean isCompactStorageAdapter = getNodeFactory().getNodeKind() == NodeKind.COMPACT; + Verify.verify(!isCompactStorageAdapter || this instanceof CompactStorageAdapter); + return isCompactStorageAdapter; + } + + @Nonnull + @Override + public CompactStorageAdapter asCompactStorageAdapter() { + Verify.verify(isCompactStorageAdapter()); + return (CompactStorageAdapter)this; + } + @Override @Nonnull public Subspace getSubspace() { @@ -130,23 +159,6 @@ public OnReadListener getOnReadListener() { return onReadListener; } - /** - * Asynchronously fetches a node from a specific layer of the HNSW. - *

- * The node is identified by its {@code layer} and {@code primaryKey}. The entire fetch operation is - * performed within the given {@link ReadTransaction}. After the underlying - * fetch operation completes, the retrieved node is validated by the - * {@link #checkNode(Node)} method before the returned future is completed. - * - * @param readTransaction the non-null transaction to use for the read operation - * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector - * into the storage space that is currently being used - * @param layer the layer of the tree from which to fetch the node - * @param primaryKey the non-null primary key that identifies the node to fetch - * - * @return a {@link CompletableFuture} that will complete with the fetched {@link AbstractNode} - * once it has been read from storage and validated - */ @Nonnull @Override public CompletableFuture> fetchNode(@Nonnull final ReadTransaction readTransaction, @@ -169,7 +181,7 @@ public CompletableFuture> fetchNode(@Nonnull final ReadTransacti * @param primaryKey the primary key that uniquely identifies the node to be fetched; must not be {@code null} * * @return a {@link CompletableFuture} that will be completed with the fetched {@link AbstractNode}. - * The future will complete with {@code null} if no node is found for the given key and layer. + * The future will complete with {@code null} if no node is found for the given key and layer. */ @Nonnull protected abstract CompletableFuture> fetchNodeInternal(@Nonnull ReadTransaction readTransaction, @@ -185,7 +197,7 @@ protected abstract CompletableFuture> fetchNodeInternal(@Nonnull * @return the node that was passed in */ @Nullable - private > T checkNode(@Nullable final T node) { + protected > T checkNode(@Nullable final T node) { return node; } @@ -200,23 +212,23 @@ private > T checkNode(@Nullable final T node) { * * @param transaction the non-null {@link Transaction} context for this write operation * @param quantizer the quantizer to use - * @param node the non-null {@link Node} to be written to storage * @param layer the layer index where the node is being written + * @param node the non-null {@link Node} to be written to storage * @param changeSet the non-null {@link NeighborsChangeSet} detailing the modifications * to the node's neighbors */ @Override public void writeNode(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, - @Nonnull final AbstractNode node, final int layer, + final int layer, @Nonnull final AbstractNode node, @Nonnull final NeighborsChangeSet changeSet) { - writeNodeInternal(transaction, quantizer, node, layer, changeSet); + writeNodeInternal(transaction, quantizer, layer, node, changeSet); if (logger.isTraceEnabled()) { logger.trace("written node with key={} at layer={}", node.getPrimaryKey(), layer); } } /** - * Writes a single node to the data store as part of a larger transaction. + * Writes a single node to the given layer of the data store as part of a larger transaction. *

* This is an abstract method that concrete implementations must provide. * It is responsible for the low-level persistence of the given {@code node} at a @@ -225,12 +237,28 @@ public void writeNode(@Nonnull final Transaction transaction, @Nonnull final Qua * * @param transaction the non-null transaction context for the write operation * @param quantizer the quantizer to use - * @param node the non-null {@link Node} to write * @param layer the layer or level of the node in the structure + * @param node the non-null {@link Node} to write * @param changeSet the non-null {@link NeighborsChangeSet} detailing additions or * removals of neighbor links */ protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Quantizer quantizer, - @Nonnull AbstractNode node, int layer, + int layer, @Nonnull AbstractNode node, @Nonnull NeighborsChangeSet changeSet); + + @Override + public void deleteNode(@Nonnull final Transaction transaction, final int layer, @Nonnull final Tuple primaryKey) { + deleteNodeInternal(transaction, layer, primaryKey); + if (logger.isTraceEnabled()) { + logger.trace("deleted node with key={} at layer={}", primaryKey, layer); + } + } + + /** + * Deletes a single node from the given layer of the data store as part of a larger transaction. + * @param transaction the transaction to use + * @param layer the layer + * @param primaryKey the primary key of the node + */ + protected abstract void deleteNodeInternal(@Nonnull Transaction transaction, int layer, @Nonnull Tuple primaryKey); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java index 490b4bc844..f7ee479920 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -67,6 +67,15 @@ public BaseNeighborsChangeSet getParent() { return null; } + /** + * Returns {@code false} as this change set is a base change set. It does not represent any changes. + * @return {@code false} as this change set does not have any changes. + */ + @Override + public boolean hasChanges() { + return false; + } + /** * Retrieves the list of neighbors associated with this object. *

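The new isInliningStorageAdapter()/asInliningStorageAdapter() and isCompactStorageAdapter()/asCompactStorageAdapter() accessors let callers narrow a StorageAdapter without raw casts; the as* methods verify the node kind before downcasting. A minimal caller-side sketch of that dispatch (the surrounding method and its output are hypothetical, not part of this change):

    static void visitAdapter(final StorageAdapter<?> storageAdapter) {
        if (storageAdapter.isInliningStorageAdapter()) {
            // safe narrowing; asInliningStorageAdapter() verifies the node kind first
            final InliningStorageAdapter inliningAdapter = storageAdapter.asInliningStorageAdapter();
            System.out.println("inlining adapter: " + inliningAdapter);
        } else {
            // the only other node kind in this change is COMPACT
            final CompactStorageAdapter compactAdapter = storageAdapter.asCompactStorageAdapter();
            System.out.println("compact adapter: " + compactAdapter);
        }
    }
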
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index b03c296f67..b6fa91b0ab 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -73,6 +73,12 @@ public CompactStorageAdapter(@Nonnull final Config config, super(config, nodeFactory, subspace, onWriteListener, onReadListener); } + @Nonnull + @Override + public Transformed getVector(@Nonnull final NodeReference nodeReference, @Nonnull final AbstractNode node) { + return node.asCompactNode().getVector(); + } + /** * Asynchronously fetches a node from the database for a given layer and primary key. *

@@ -88,8 +94,6 @@ public CompactStorageAdapter(@Nonnull final Config config, * * @return a future that will complete with the fetched {@link AbstractNode} or {@code null} if the node cannot * be fetched - * - * @throws IllegalStateException if the node cannot be found in the database for the given key */ @Nonnull @Override @@ -97,8 +101,7 @@ protected CompletableFuture> fetchNodeInternal(@Nonn @Nonnull final AffineOperator storageTransform, final int layer, @Nonnull final Tuple primaryKey) { - final byte[] keyBytes = getDataSubspace().pack(Tuple.from(layer, primaryKey)); - + final byte[] keyBytes = getNodeKey(layer, primaryKey); return readTransaction.get(keyBytes) .thenApply(valueBytes -> { if (valueBytes == null) { @@ -216,16 +219,16 @@ private AbstractNode compactNodeFromTuples(@Nonnull final AffineO * * @param transaction the {@link Transaction} to use for the write operation. * @param quantizer the quantizer to use - * @param node the {@link AbstractNode} to be serialized and written; it is processed as a {@link CompactNode}. * @param layer the graph layer index for the node, used to construct the storage key. + * @param node the {@link AbstractNode} to be serialized and written; it is processed as a {@link CompactNode}. * @param neighborsChangeSet a {@link NeighborsChangeSet} containing the additions and removals, which are * merged to determine the final set of neighbors to be written. */ @Override public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, - @Nonnull final AbstractNode node, final int layer, + final int layer, @Nonnull final AbstractNode node, @Nonnull final NeighborsChangeSet neighborsChangeSet) { - final byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())); + final byte[] key = getNodeKey(layer, node.getPrimaryKey()); final List nodeItems = Lists.newArrayListWithExpectedSize(3); nodeItems.add(NodeKind.COMPACT.getSerialized()); @@ -254,6 +257,33 @@ public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull f } } + @Override + protected void deleteNodeInternal(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] key = getNodeKey(layer, primaryKey); + transaction.clear(key); + getOnWriteListener().onNodeDeleted(layer, primaryKey); + getOnWriteListener().onKeyDeleted(layer, key); + } + + /** + * Constructs the raw database key for a node based on its layer and primary key. + *

+ * This key is created by packing a tuple containing the specified {@code layer} and the node's {@code primaryKey} + * within the data subspace. The resulting byte array is suitable for use in direct database lookups and preserves + * the sort order of the components. + * + * @param layer the layer index where the node resides + * @param primaryKey the primary key that uniquely identifies the node within its layer, + * encapsulated in a {@link Tuple} + * + * @return a byte array representing the packed key for the specified node + */ + @Nonnull + private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { + return getDataSubspace().pack(Tuple.from(layer, primaryKey)); + } + /** * Scans a given layer for nodes, returning an iterable over the results. *

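The getNodeKey(...) helper above centralizes the key layout for compact nodes: the layer is the leading tuple element inside the data subspace, so all nodes of one layer sort contiguously and can be scanned with a single prefix range. A self-contained sketch of that layout (the subspace prefix and the primary-key value are made up for illustration):

    import com.apple.foundationdb.subspace.Subspace;
    import com.apple.foundationdb.tuple.Tuple;

    public final class NodeKeyLayoutExample {
        public static void main(final String[] args) {
            final Subspace dataSubspace = new Subspace(Tuple.from("hnsw", "data")); // hypothetical prefix
            final Tuple primaryKey = Tuple.from(42L);                               // hypothetical record key
            final byte[] nodeKey = dataSubspace.pack(Tuple.from(1, primaryKey));    // key for (layer=1, primaryKey)
            final byte[] layerPrefix = dataSubspace.pack(Tuple.from(1));            // prefix covering all of layer 1
            System.out.println(nodeKey.length + " vs " + layerPrefix.length);
        }
    }
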
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java index efa2d3181b..ac3aae3279 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java @@ -32,13 +32,13 @@ */ @SuppressWarnings("checkstyle:MemberName") public final class Config { - public static final boolean DEFAULT_DETERMINISTIC_SEEDING = false; @Nonnull public static final Metric DEFAULT_METRIC = Metric.EUCLIDEAN_METRIC; public static final boolean DEFAULT_USE_INLINING = false; public static final int DEFAULT_M = 16; public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; public static final int DEFAULT_M_MAX = DEFAULT_M; public static final int DEFAULT_EF_CONSTRUCTION = 200; + public static final int DEFAULT_EF_REPAIR = 64; public static final boolean DEFAULT_EXTEND_CANDIDATES = false; public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; // stats @@ -48,12 +48,11 @@ public final class Config { // RaBitQ public static final boolean DEFAULT_USE_RABITQ = false; public static final int DEFAULT_RABITQ_NUM_EX_BITS = 4; - // concurrency public static final int DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES = 16; - public static final int DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES = 16; + public static final int DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES = 10; + public static final int DEFAULT_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER = 2; - private final boolean deterministicSeeding; @Nonnull private final Metric metric; private final int numDimensions; @@ -62,6 +61,7 @@ public final class Config { private final int mMax; private final int mMax0; private final int efConstruction; + private final int efRepair; private final boolean extendCandidates; private final boolean keepPrunedConnections; private final double sampleVectorStatsProbability; @@ -71,13 +71,15 @@ public final class Config { private final int raBitQNumExBits; private final int maxNumConcurrentNodeFetches; private final int maxNumConcurrentNeighborhoodFetches; + private final int maxNumConcurrentDeleteFromLayer; - private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, final int numDimensions, - final boolean useInlining, final int m, final int mMax, final int mMax0, - final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections, + private Config(@Nonnull final Metric metric, final int numDimensions, final boolean useInlining, final int m, + final int mMax, final int mMax0, final int efConstruction, final int efRepair, + final boolean extendCandidates, final boolean keepPrunedConnections, final double sampleVectorStatsProbability, final double maintainStatsProbability, final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, - final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { + final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches, + final int maxNumConcurrentDeleteFromLayer) { Preconditions.checkArgument(numDimensions >= 1, "numDimensions must be (1, MAX_INT]"); Preconditions.checkArgument(m >= 4 && m <= 200, "m must be [4, 200]"); Preconditions.checkArgument(mMax >= 4 && mMax <= 200, "mMax must be [4, 200]"); @@ -86,6 +88,8 @@ private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, Preconditions.checkArgument(mMax <= mMax0, "mMax must be less than or equal to mMax0"); 
Preconditions.checkArgument(efConstruction >= 100 && efConstruction <= 400, "efConstruction must be [100, 400]"); + Preconditions.checkArgument(efRepair >= m && efRepair <= 400, + "efRepair must be [m, 400]"); Preconditions.checkArgument(!useRaBitQ || (sampleVectorStatsProbability > 0.0d && sampleVectorStatsProbability <= 1.0d), "sampleVectorStatsProbability out of range"); @@ -98,10 +102,12 @@ private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, Preconditions.checkArgument(maxNumConcurrentNodeFetches > 0 && maxNumConcurrentNodeFetches <= 64, "maxNumConcurrentNodeFetches must be (0, 64]"); Preconditions.checkArgument(maxNumConcurrentNeighborhoodFetches > 0 && - maxNumConcurrentNeighborhoodFetches <= 64, - "maxNumConcurrentNeighborhoodFetches must be (0, 64]"); + maxNumConcurrentNeighborhoodFetches <= 20, + "maxNumConcurrentNeighborhoodFetches must be (0, 20]"); + Preconditions.checkArgument(maxNumConcurrentDeleteFromLayer > 0 && + maxNumConcurrentDeleteFromLayer <= 10, + "maxNumConcurrentDeleteFromLayer must be (0, 10]"); - this.deterministicSeeding = deterministicSeeding; this.metric = metric; this.numDimensions = numDimensions; this.useInlining = useInlining; @@ -109,6 +115,7 @@ private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, this.mMax = mMax; this.mMax0 = mMax0; this.efConstruction = efConstruction; + this.efRepair = efRepair; this.extendCandidates = extendCandidates; this.keepPrunedConnections = keepPrunedConnections; this.sampleVectorStatsProbability = sampleVectorStatsProbability; @@ -118,15 +125,7 @@ private Config(final boolean deterministicSeeding, @Nonnull final Metric metric, this.raBitQNumExBits = raBitQNumExBits; this.maxNumConcurrentNodeFetches = maxNumConcurrentNodeFetches; this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; - } - - /** - * Indicator that if {@code true} causes the insert logic of the HNSW to be seeded using a hash of the primary key - * of the record that is inserted. That can be useful for testing. If {@code isDeterministicSeeding} is - * {@code false}, we use {@link System#nanoTime()} for seeding. - */ - public boolean isDeterministicSeeding() { - return deterministicSeeding; + this.maxNumConcurrentDeleteFromLayer = maxNumConcurrentDeleteFromLayer; } /** @@ -210,14 +209,24 @@ public int getMMax0() { /** * Maximum size of the search queues (one independent queue per layer) that are used during the insertion of a new - * node. If {@code efConstruction} is set to {@code 1}, the search naturally follows a greedy approach - * (monotonous descent), whereas a high number for {@code efConstruction} allows for a more nuanced search that can - * tolerate (false) local minima. + * node. If {@code efConstruction} is set to a smaller number, the search naturally follows a more greedy approach + * (monotonous descent), whereas a higher number for {@code efConstruction} allows for a more nuanced search that + * can tolerate (false) local minima. */ public int getEfConstruction() { return efConstruction; } + /** + * Maximum number of candidate nodes that are considered when a HNSW layer is locally repaired as part of a + * delete operation. A smaller number causes the delete operation to create a smaller set of candidate nodes + * which improves repair performance but decreases repair quality; a higher number results in qualitatively + * better repairs at the expense of slower performance. 
+ */ + public int getEfRepair() { + return efRepair; + } + /** * Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is to be * extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be @@ -295,13 +304,20 @@ public int getMaxNumConcurrentNeighborhoodFetches() { return maxNumConcurrentNeighborhoodFetches; } + /** + * Maximum number of delete operations that can run concurrently during a delete operation. + */ + public int getMaxNumConcurrentDeleteFromLayer() { + return maxNumConcurrentDeleteFromLayer; + } + @Nonnull public ConfigBuilder toBuilder() { - return new ConfigBuilder(isDeterministicSeeding(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), - getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), + return new ConfigBuilder(getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), + getEfConstruction(), getEfRepair(), isExtendCandidates(), isKeepPrunedConnections(), getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), - getMaxNumConcurrentNeighborhoodFetches()); + getMaxNumConcurrentNeighborhoodFetches(), getMaxNumConcurrentDeleteFromLayer()); } @Override @@ -313,38 +329,41 @@ public boolean equals(final Object o) { return false; } final Config config = (Config)o; - return deterministicSeeding == config.deterministicSeeding && numDimensions == config.numDimensions && - useInlining == config.useInlining && m == config.m && mMax == config.mMax && mMax0 == config.mMax0 && - efConstruction == config.efConstruction && extendCandidates == config.extendCandidates && + return numDimensions == config.numDimensions && useInlining == config.useInlining && m == config.m && + mMax == config.mMax && mMax0 == config.mMax0 && efConstruction == config.efConstruction && + efRepair == config.efRepair && extendCandidates == config.extendCandidates && keepPrunedConnections == config.keepPrunedConnections && Double.compare(sampleVectorStatsProbability, config.sampleVectorStatsProbability) == 0 && Double.compare(maintainStatsProbability, config.maintainStatsProbability) == 0 && statsThreshold == config.statsThreshold && useRaBitQ == config.useRaBitQ && raBitQNumExBits == config.raBitQNumExBits && metric == config.metric && maxNumConcurrentNodeFetches == config.maxNumConcurrentNodeFetches && - maxNumConcurrentNeighborhoodFetches == config.maxNumConcurrentNeighborhoodFetches; + maxNumConcurrentNeighborhoodFetches == config.maxNumConcurrentNeighborhoodFetches && + maxNumConcurrentDeleteFromLayer == config.maxNumConcurrentDeleteFromLayer; } @Override public int hashCode() { - return Objects.hash(deterministicSeeding, metric, numDimensions, useInlining, m, mMax, mMax0, efConstruction, + return Objects.hash(metric, numDimensions, useInlining, m, mMax, mMax0, efConstruction, efRepair, extendCandidates, keepPrunedConnections, sampleVectorStatsProbability, maintainStatsProbability, - statsThreshold, useRaBitQ, raBitQNumExBits, maxNumConcurrentNodeFetches, maxNumConcurrentNeighborhoodFetches); + statsThreshold, useRaBitQ, raBitQNumExBits, maxNumConcurrentNodeFetches, + maxNumConcurrentNeighborhoodFetches, maxNumConcurrentDeleteFromLayer); } @Override @Nonnull public String toString() { - return "Config[deterministicSeeding=" + isDeterministicSeeding() + ", metric=" + getMetric() + - ", numDimensions=" + getNumDimensions() + ", isUseInlining=" + isUseInlining() + ", M=" + getM() + - ", 
MMax=" + getMMax() + ", MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() + - ", isExtendCandidates=" + isExtendCandidates() + + return "Config[metric=" + getMetric() + ", numDimensions=" + getNumDimensions() + + ", isUseInlining=" + isUseInlining() + ", M=" + getM() + ", MMax=" + getMMax() + + ", MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() + + ", efRepair=" + getEfRepair() + ", isExtendCandidates=" + isExtendCandidates() + ", isKeepPrunedConnections=" + isKeepPrunedConnections() + ", sampleVectorStatsProbability=" + getSampleVectorStatsProbability() + ", mainStatsProbability=" + getMaintainStatsProbability() + ", statsThreshold=" + getStatsThreshold() + ", useRaBitQ=" + isUseRaBitQ() + ", raBitQNumExBits=" + getRaBitQNumExBits() + ", maxNumConcurrentNodeFetches=" + getMaxNumConcurrentNodeFetches() + ", maxNumConcurrentNeighborhoodFetches=" + getMaxNumConcurrentNeighborhoodFetches() + + ", maxNumConcurrentDeleteFromLayer=" + getMaxNumConcurrentDeleteFromLayer() + "]"; } @@ -356,7 +375,6 @@ public String toString() { @CanIgnoreReturnValue @SuppressWarnings("checkstyle:MemberName") public static class ConfigBuilder { - private boolean deterministicSeeding = DEFAULT_DETERMINISTIC_SEEDING; @Nonnull private Metric metric = DEFAULT_METRIC; private boolean useInlining = DEFAULT_USE_INLINING; @@ -364,6 +382,7 @@ public static class ConfigBuilder { private int mMax = DEFAULT_M_MAX; private int mMax0 = DEFAULT_M_MAX_0; private int efConstruction = DEFAULT_EF_CONSTRUCTION; + private int efRepair = DEFAULT_EF_REPAIR; private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES; private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; @@ -376,23 +395,25 @@ public static class ConfigBuilder { private int maxNumConcurrentNodeFetches = DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES; private int maxNumConcurrentNeighborhoodFetches = DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES; + private int maxNumConcurrentDeleteFromLayer = DEFAULT_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER; public ConfigBuilder() { } - public ConfigBuilder(final boolean deterministicSeeding, @Nonnull final Metric metric, final boolean useInlining, - final int m, final int mMax, final int mMax0, final int efConstruction, + public ConfigBuilder(@Nonnull final Metric metric, final boolean useInlining, final int m, final int mMax, + final int mMax0, final int efConstruction, final int efRepair, final boolean extendCandidates, final boolean keepPrunedConnections, final double sampleVectorStatsProbability, final double maintainStatsProbability, final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, - final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { - this.deterministicSeeding = deterministicSeeding; + final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches, + final int maxNumConcurrentDeleteFromLayer) { this.metric = metric; this.useInlining = useInlining; this.m = m; this.mMax = mMax; this.mMax0 = mMax0; this.efConstruction = efConstruction; + this.efRepair = efRepair; this.extendCandidates = extendCandidates; this.keepPrunedConnections = keepPrunedConnections; this.sampleVectorStatsProbability = sampleVectorStatsProbability; @@ -402,16 +423,7 @@ public ConfigBuilder(final boolean deterministicSeeding, @Nonnull final Metric m this.raBitQNumExBits = raBitQNumExBits; this.maxNumConcurrentNodeFetches = maxNumConcurrentNodeFetches; this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; - } - - 
public boolean isDeterministicSeeding() { - return deterministicSeeding; - } - - @Nonnull - public ConfigBuilder setDeterministicSeeding(final boolean deterministicSeeding) { - this.deterministicSeeding = deterministicSeeding; - return this; + this.maxNumConcurrentDeleteFromLayer = maxNumConcurrentDeleteFromLayer; } @Nonnull @@ -475,6 +487,16 @@ public ConfigBuilder setEfConstruction(final int efConstruction) { return this; } + public int getEfRepair() { + return efRepair; + } + + @Nonnull + public ConfigBuilder setEfRepair(final int efRepair) { + this.efRepair = efRepair; + return this; + } + public boolean isExtendCandidates() { return extendCandidates; } @@ -563,12 +585,21 @@ public ConfigBuilder setMaxNumConcurrentNeighborhoodFetches(final int maxNumConc return this; } + public int getMaxNumConcurrentDeleteFromLayer() { + return maxNumConcurrentDeleteFromLayer; + } + + public ConfigBuilder setMaxNumConcurrentDeleteFromLayer(final int maxNumConcurrentDeleteFromLayer) { + this.maxNumConcurrentDeleteFromLayer = maxNumConcurrentDeleteFromLayer; + return this; + } + public Config build(final int numDimensions) { - return new Config(isDeterministicSeeding(), getMetric(), numDimensions, isUseInlining(), getM(), getMMax(), - getMMax0(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), + return new Config(getMetric(), numDimensions, isUseInlining(), getM(), getMMax(), + getMMax0(), getEfConstruction(), getEfRepair(), isExtendCandidates(), isKeepPrunedConnections(), getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), - getMaxNumConcurrentNeighborhoodFetches()); + getMaxNumConcurrentNeighborhoodFetches(), getMaxNumConcurrentDeleteFromLayer()); } } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java index 1d6b5ff4a6..012ea9e54d 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -83,6 +83,16 @@ public NeighborsChangeSet getParent() { return parent; } + @Override + public boolean hasChanges() { + // + // We can probably do better by testing if the deletion has an effect on the merge, i.e. if the neighbors that + // are being deleted by this set are in fact part of the underlying set. That case is currently impossible so + // we just return true for now. + // + return true; + } + /** * Merges the neighbors from the parent context, filtering out any neighbors that have been marked as deleted. *

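Stepping back to the Config changes above: efRepair and maxNumConcurrentDeleteFromLayer are ordinary builder knobs next to the existing ones, and deterministicSeeding is gone. A hedged sketch of how a caller might set them, using only builder methods shown in this diff (the values and the dimension count are illustrative and must satisfy the preconditions efRepair in [m, 400] and maxNumConcurrentDeleteFromLayer in (0, 10]):

    static Config exampleConfig() {
        return new Config.ConfigBuilder()
                .setEfConstruction(200)
                .setEfRepair(64)                        // candidate-queue size for local layer repair on delete
                .setMaxNumConcurrentDeleteFromLayer(2)  // parallelism of per-layer delete work
                .build(768);                            // numDimensions of the indexed vectors
    }
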
@@ -132,8 +142,8 @@ public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @No if (tuplePredicate.test(deletedNeighborPrimaryKey)) { storageAdapter.deleteNeighbor(transaction, layer, node.asInliningNode(), deletedNeighborPrimaryKey); if (logger.isTraceEnabled()) { - logger.trace("deleted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), - deletedNeighborPrimaryKey); + logger.trace("deleted neighbor of layer={}, primaryKey={} targeting primaryKey={}", + layer, node.getPrimaryKey(), deletedNeighborPrimaryKey); } } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index 639cd2f273..15fd1fa02f 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -26,7 +26,6 @@ import com.apple.foundationdb.annotation.API; import com.apple.foundationdb.async.AsyncUtil; import com.apple.foundationdb.async.MoreAsyncUtil; -import com.apple.foundationdb.linear.AffineOperator; import com.apple.foundationdb.linear.Estimator; import com.apple.foundationdb.linear.FhtKacRotator; import com.apple.foundationdb.linear.Metric; @@ -38,6 +37,8 @@ import com.apple.foundationdb.tuple.Tuple; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -64,6 +65,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; @@ -194,9 +196,9 @@ public OnReadListener getOnReadListener() { } @Nonnull - private AffineOperator storageTransform(@Nullable final AccessInfo accessInfo) { + private StorageTransform storageTransform(@Nullable final AccessInfo accessInfo) { if (accessInfo == null || !accessInfo.canUseRaBitQ()) { - return AffineOperator.identity(); + return StorageTransform.identity(); } return new StorageTransform(accessInfo.getRotatorSeed(), @@ -256,7 +258,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { } final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference(); - final AffineOperator storageTransform = storageTransform(accessInfo); + final StorageTransform storageTransform = storageTransform(accessInfo); final Transformed transformedQueryVector = storageTransform.transform(queryVector); final Quantizer quantizer = quantizer(accessInfo); final Estimator estimator = quantizer.estimator(); @@ -266,8 +268,8 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { entryNodeReference.getVector(), estimator.distance(transformedQueryVector, entryNodeReference.getVector())); - final int entryLayer = entryNodeReference.getLayer(); - return forLoop(entryLayer, entryState, + final int topLayer = entryNodeReference.getLayer(); + return forLoop(topLayer, entryState, layer -> layer > 0, layer -> layer - 1, (layer, previousNodeReference) -> { @@ -308,7 +310,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { private CompletableFuture> searchFinalLayer(@Nonnull final StorageAdapter storageAdapter, final @Nonnull 
ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Estimator estimator, final int k, final int efSearch, @@ -324,8 +326,8 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { @Nonnull private ImmutableList - postProcessNearestNeighbors(@Nonnull final AffineOperator storageTransform, final int k, - @Nonnull final List> nearestNeighbors, + postProcessNearestNeighbors(@Nonnull final StorageTransform storageTransform, final int k, + @Nonnull final List> nearestNeighbors, final boolean includeVectors) { final int lastIndex = Math.max(nearestNeighbors.size() - k, 0); @@ -335,7 +337,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { for (int i = nearestNeighbors.size() - 1; i >= lastIndex; i --) { final var nodeReferenceAndNode = nearestNeighbors.get(i); final var nodeReference = - Objects.requireNonNull(nodeReferenceAndNode).getNodeReferenceWithDistance(); + Objects.requireNonNull(nodeReferenceAndNode).getNodeReference(); final AbstractNode node = nodeReferenceAndNode.getNode(); @Nullable final RealVector reconstructedVector = includeVectors ? storageTransform.untransform(node.asCompactNode().getVector()) : null; @@ -359,7 +361,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { * @param storageAdapter the {@link StorageAdapter} for accessing the graph data * @param readTransaction the {@link ReadTransaction} to use for the search * @param estimator a distance estimator - * @param nodeReference the starting point for the search on this layer, which includes the node and its distance to + * @param nodeReferenceWithDistance the starting point for the search on this layer, which includes the node and its distance to * the query vector * @param layer the zero-based index of the layer to search within * @param queryVector the query vector for which to find the nearest neighbor @@ -371,21 +373,123 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { private CompletableFuture greedySearchLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Estimator estimator, - @Nonnull final NodeReferenceWithDistance nodeReference, + @Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, final int layer, @Nonnull final Transformed queryVector) { - return searchLayer(storageAdapter, readTransaction, storageTransform, estimator, - ImmutableList.of(nodeReference), layer, 1, Maps.newConcurrentMap(), queryVector) - .thenApply(searchResult -> - Iterables.getOnlyElement(searchResult).getNodeReferenceWithDistance()); + if (storageAdapter.isInliningStorageAdapter()) { + return greedySearchInliningLayer(storageAdapter.asInliningStorageAdapter(), readTransaction, + storageTransform, estimator, nodeReferenceWithDistance, layer, queryVector); + } else { + return searchLayer(storageAdapter, readTransaction, storageTransform, estimator, + ImmutableList.of(nodeReferenceWithDistance), layer, 1, Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> + Iterables.getOnlyElement(searchResult).getNodeReference()); + } + } + + @Nonnull + private CompletableFuture greedySearchInliningLayer(@Nonnull final InliningStorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Estimator 
estimator, + @Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, + final int layer, + @Nonnull final Transformed queryVector) { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + final Map> nodeCache = Maps.newHashMap(); + final Map> updatedNodes = Maps.newHashMap(); + + final AtomicReference nearestNodeReferenceAtomic = + new AtomicReference<>(null); + + final Queue candidates = + // This initial capacity is somewhat arbitrary as m is not necessarily a limit, + // but it gives us a number that is better than the default. + new PriorityQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.add(nodeReferenceWithDistance); + + return AsyncUtil.whileTrue(() -> { + final NodeReferenceWithDistance candidateReference = Objects.requireNonNull(candidates.poll()); + return onReadListener.onAsyncRead( + fetchNodeIfNotCached(storageAdapter, readTransaction, storageTransform, layer, + candidateReference, nodeCache)) + .thenCompose(node -> { + if (node == null) { + // + // This cannot happen under normal circumstances as the storage adapter returns a node with no + // neighbors if it already has been deleted. Therefore, it is correct to throw here. + // + throw new IllegalStateException("unable to fetch node"); + } + final InliningNode candidateNode = node.asInliningNode(); + + if (updatedNodes.containsKey(candidateReference.getPrimaryKey())) { + return CompletableFuture.completedFuture(updatedNodes.get(candidateReference.getPrimaryKey())); + } + + return fetchBaseNode(readTransaction, storageTransform, candidateReference.getPrimaryKey()) + .thenAccept(baseCompactNode -> { + if (baseCompactNode == null) { + // node does not exist on layer 0 + return; + } + + // + // Node does still exist or an updated version exists -- create new reference + // and push it back into the queue + // + final Transformed baseVector = baseCompactNode.getVector(); + + final double distance = + estimator.distance(baseVector, queryVector); + + final NodeReferenceWithDistance updatedNodeReference = + new NodeReferenceWithDistance(baseCompactNode.getPrimaryKey(), + baseVector, + distance); + candidates.add(updatedNodeReference); + updatedNodes.put(candidateReference.getPrimaryKey(), + nodeFactory.create(candidateReference.getPrimaryKey(), + baseCompactNode.getVector(), candidateNode.getNeighbors())); + }) + .thenApply(ignored -> null); // keep Java happy about the return type + + }) + .thenApply(candidateNode -> { + if (candidateNode != null) { + // + // This node definitely does exist. And it's the nearest one. + // + nearestNodeReferenceAtomic.set(candidateReference); + candidates.clear(); + + // + // Find some new candidates. + // + double minDistance = candidateReference.getDistance(); + + for (final NodeReferenceWithVector neighbor : candidateNode.getNeighbors()) { + final double distance = + estimator.distance(neighbor.getVector(), queryVector); + if (distance < minDistance) { + candidates.add( + new NodeReferenceWithDistance(neighbor.getPrimaryKey(), neighbor.getVector(), + distance)); + } + } + } + return !candidates.isEmpty(); + }); + }, executor).thenApply(ignored -> nearestNodeReferenceAtomic.get()); } /** * Searches a single layer of the graph to find the nearest neighbors to a query vector. *

- * This method implements the greedy search algorithm used in HNSW (Hierarchical Navigable Small World) + * This method implements the search algorithm used in HNSW (Hierarchical Navigable Small World) * graphs for a specific layer. It begins with a set of entry points and iteratively explores the graph, * always moving towards nodes that are closer to the {@code queryVector}. *

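The layer search described above is a best-first expansion: candidates stay ordered by distance to the query, and neighbors are only enqueued while they improve on the best node found so far. A simplified, synchronous sketch of that idea on a plain in-memory adjacency map (all types here are stand-ins; the real code is asynchronous, bounded by efSearch, and fetches nodes through the StorageAdapter):

    import java.util.Comparator;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.PriorityQueue;
    import java.util.Set;

    public final class GreedyDescentExample {
        static long greedyDescent(final Map<Long, List<Long>> neighbors,
                                  final Map<Long, Double> distanceToQuery,
                                  final long entryPoint) {
            final PriorityQueue<Long> candidates =
                    new PriorityQueue<>(Comparator.comparingDouble(distanceToQuery::get));
            final Set<Long> visited = new HashSet<>();
            candidates.add(entryPoint);
            visited.add(entryPoint);
            long best = entryPoint;
            while (!candidates.isEmpty()) {
                final long current = candidates.poll();
                if (distanceToQuery.get(current) < distanceToQuery.get(best)) {
                    best = current;   // found a closer node; keep descending from here
                }
                for (final long neighbor : neighbors.getOrDefault(current, List.of())) {
                    // only explore neighbors that are closer than the best node seen so far
                    if (distanceToQuery.get(neighbor) < distanceToQuery.get(best) && visited.add(neighbor)) {
                        candidates.add(neighbor);
                    }
                }
            }
            return best;   // key of the nearest node reached on this layer
        }
    }
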
@@ -413,10 +517,10 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { * best candidate nodes found in this layer, paired with their full node data. */ @Nonnull - private CompletableFuture>> + private CompletableFuture>> searchLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Estimator estimator, @Nonnull final Collection nodeReferences, final int layer, @@ -450,9 +554,11 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { return fetchNodeIfNotCached(storageAdapter, readTransaction, storageTransform, layer, candidate, nodeCache) .thenApply(candidateNode -> - Iterables.filter(candidateNode.getNeighbors(), + candidateNode == null + ? ImmutableList.of() + : Iterables.filter(candidateNode.getNeighbors(), neighbor -> !visited.contains(Objects.requireNonNull(neighbor).getPrimaryKey()))) - .thenCompose(neighborReferences -> fetchNeighborhood(storageAdapter, readTransaction, + .thenCompose(neighborReferences -> fetchNeighborhoodReferences(storageAdapter, readTransaction, storageTransform, layer, neighborReferences, nodeCache)) .thenApply(neighborReferences -> { for (final NodeReferenceWithVector current : neighborReferences) { @@ -484,15 +590,36 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { searchResult.stream() .map(nodeReferenceAndNode -> "(primaryKey=" + - nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + + nodeReferenceAndNode.getNodeReference().getPrimaryKey() + ",distance=" + - nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") + nodeReferenceAndNode.getNodeReference().getDistance() + ")") .collect(Collectors.joining(","))); } return searchResult; }); } + /** + * Gets a node from the cache or throws an exception. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param primaryKey the {@link Tuple} representing the primary key of the node + * @param nodeCache the cache to check for the node + * + * @return a {@link CompletableFuture} that will be completed with the cached {@link AbstractNode} + * @throws IllegalArgumentException if the node is not already present in the cache + */ + @Nonnull + private AbstractNode + nodeFromCache(@Nonnull final Tuple primaryKey, + @Nonnull final Map> nodeCache) { + final AbstractNode nodeFromCache = nodeCache.get(primaryKey); + if (nodeFromCache == null) { + throw new IllegalStateException("node should already have been fetched: " + primaryKey); + } + return nodeFromCache; + } + /** * Asynchronously fetches a node if it is not already present in the cache. *

@@ -501,8 +628,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { * fetched from the underlying storage using the {@code storageAdapter}. Once fetched, the node * is added to the {@code nodeCache} before the future is completed. *

- * This is a convenience method that delegates to - * {@link #fetchNodeIfNecessaryAndApply(StorageAdapter, ReadTransaction, AffineOperator, int, NodeReference, Function, BiFunction)}. + * This is a convenience method that delegates to {@link #fetchNodeIfNecessaryAndApply}. * * @param the type of the node reference, which must extend {@link NodeReference} * @param storageAdapter the storage adapter used to fetch the node from persistent storage @@ -519,14 +645,17 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, final int layer, @Nonnull final NodeReference nodeReference, @Nonnull final Map> nodeCache) { return fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, storageTransform, layer, nodeReference, nR -> nodeCache.get(nR.getPrimaryKey()), (nR, node) -> { - nodeCache.put(nR.getPrimaryKey(), node); + // TODO maybe use a placeholder instance for null so we won't try multiple times + if (node != null) { + nodeCache.put(nR.getPrimaryKey(), node); + } return node; }); } @@ -565,7 +694,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, final int layer, @Nonnull final R nodeReference, @Nonnull final Function fetchBypassFunction, @@ -578,7 +707,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { return onReadListener.onAsyncRead( storageAdapter.fetchNode(readTransaction, storageTransform, layer, nodeReference.getPrimaryKey())) - .thenApply(node -> biMapFunction.apply(nodeReference, Objects.requireNonNull(node))); + .thenApply(node -> biMapFunction.apply(nodeReference, node)); } /** @@ -605,12 +734,12 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { */ @Nonnull private CompletableFuture> - fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, - @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, - final int layer, - @Nonnull final Iterable neighborReferences, - @Nonnull final Map> nodeCache) { + fetchNeighborhoodReferences(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + final int layer, + @Nonnull final Iterable neighborReferences, + @Nonnull final Map> nodeCache) { return fetchSomeNodesAndApply(storageAdapter, readTransaction, storageTransform, layer, neighborReferences, neighborReference -> { if (neighborReference.isNodeReferenceWithVector()) { @@ -624,13 +753,17 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { neighborNode.asCompactNode().getVector()); }, (neighborReference, neighborNode) -> { - // - // At this point we know that the node needed to be fetched which excludes INLINING nodes - // as they never have to be fetched. Therefore, we can safely treat the nodes as compact nodes. 
- // - nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); - return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), - neighborNode.asCompactNode().getVector()); + if (neighborNode != null) { + // + // At this point we know that the node needed to be fetched, which means this branch cannot be + // reached for INLINING nodes as they never have to be fetched. Therefore, we can safely treat + // the nodes as compact nodes. + // + nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), + neighborNode.asCompactNode().getVector()); + } + return null; }); } @@ -659,12 +792,12 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { * objects, pairing each requested reference with its corresponding node. */ @Nonnull - private CompletableFuture>> + private CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, final int layer, - @Nonnull final Iterable nodeReferences, + @Nonnull final Iterable nodeReferences, @Nonnull final Map> nodeCache) { return fetchSomeNodesAndApply(storageAdapter, readTransaction, storageTransform, layer, nodeReferences, nodeReference -> { @@ -674,9 +807,12 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { } return new NodeReferenceAndNode<>(nodeReference, node); }, - (nodeReferenceWithDistance, node) -> { - nodeCache.put(nodeReferenceWithDistance.getPrimaryKey(), node); - return new NodeReferenceAndNode<>(nodeReferenceWithDistance, node); + (nodeReference, node) -> { + if (node != null) { + nodeCache.put(nodeReference.getPrimaryKey(), node); + return new NodeReferenceAndNode<>(nodeReference, node); + } + return null; }); } @@ -704,14 +840,14 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { * @param biMapFunction The function to apply when a node is successfully fetched, mapping the original * reference and the fetched {@link AbstractNode} to a result of type {@code U}. * - * @return A {@link CompletableFuture} that, upon completion, will hold a {@link java.util.List} of results - * of type {@code U}, corresponding to each processed node reference. + * @return A {@link CompletableFuture} that, upon completion, will hold a {@link java.util.List} of non-null results + * of type {@code U} */ @Nonnull private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, final int layer, @Nonnull final Iterable nodeReferences, @Nonnull final Function fetchBypassFunction, @@ -720,19 +856,28 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, storageTransform, layer, currentNeighborReference, fetchBypassFunction, biMapFunction), getConfig().getMaxNumConcurrentNodeFetches(), - getExecutor()); + getExecutor()) + .thenApply(results -> { + final ImmutableList.Builder filteredListBuilder = ImmutableList.builder(); + for (final U result : results) { + if (result != null) { + filteredListBuilder.add(result); + } + } + return filteredListBuilder.build(); + }); } /** * Inserts a new vector with its associated primary key into the HNSW graph. *

- * The method first determines a random layer for the new node, called the {@code insertionLayer}. + * The method first determines a layer for the new node, called the {@code top layer}. * It then traverses the graph from the entry point downwards, greedily searching for the nearest * neighbors to the {@code newVector} at each layer. This search identifies the optimal * connection points for the new node. *

* Once the nearest neighbors are found, the new node is linked into the graph structure at all - * layers up to its {@code insertionLayer}. Special handling is included for inserting the + * layers up to its {@code top layer}. Special handling is included for inserting the * first-ever node into the graph or when a new node's layer is higher than any existing node, * which updates the graph's entry point. All operations are performed asynchronously. * @@ -746,7 +891,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, @Nonnull final RealVector newVector) { final SplittableRandom random = random(newPrimaryKey); - final int insertionLayer = insertionLayer(random); + final int insertionLayer = topLayer(newPrimaryKey); if (logger.isTraceEnabled()) { logger.trace("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer); } @@ -768,7 +913,7 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N } final AccessInfo accessInfo = accessInfoAndNodeExistence.getAccessInfo(); - final AffineOperator storageTransform = storageTransform(accessInfo); + final StorageTransform storageTransform = storageTransform(accessInfo); final Transformed transformedNewVector = storageTransform.transform(newVector); final Quantizer quantizer = quantizer(accessInfo); final Estimator estimator = quantizer.estimator(); @@ -830,24 +975,81 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N insertIntoLayers(transaction, storageTransform, quantizer, newPrimaryKey, transformedNewVector, nodeReference, lMax, insertionLayer)) .thenCompose(ignored -> - addToStatsIfNecessary(random.split(), transaction, currentAccessInfo, transformedNewVector)); + addToStatsIfNecessary(random, transaction, currentAccessInfo, transformedNewVector)); }).thenCompose(ignored -> AsyncUtil.DONE); } @Nonnull - @VisibleForTesting - CompletableFuture exists(@Nonnull final ReadTransaction readTransaction, - @Nonnull final Tuple primaryKey) { - final StorageAdapter storageAdapter = getStorageAdapterForLayer(0); + private CompletableFuture>> + filterExisting(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Iterable> nodeReferenceAndNodes) { + if (!storageAdapter.isInliningStorageAdapter()) { + return CompletableFuture.completedFuture(ImmutableList.copyOf(nodeReferenceAndNodes)); + } + + return forEach(nodeReferenceAndNodes, + nodeReferenceAndNode -> { + final AbstractNode node = nodeReferenceAndNode.getNode(); + final NodeReferenceWithVector nodeReference = nodeReferenceAndNode.getNodeReference(); + return fetchBaseNode(readTransaction, storageTransform, nodeReference.getPrimaryKey()) + .thenApply(baseCompactNode -> { + if (baseCompactNode == null) { + return null; + } + // + // The node does exist on layer 0 meaning the base node is a compact node, and we + // can use its vector going forward. This may be necessary if this is a dangling + // reference and the record has been reinserted after deletion. 
+ // + final NodeReferenceWithVector updatedNodeReference = + new NodeReferenceWithVector(baseCompactNode.getPrimaryKey(), + baseCompactNode.getVector()); + return new NodeReferenceAndNode<>(updatedNodeReference, node); + }); + }, + getConfig().getMaxNumConcurrentNodeFetches(), + getExecutor()) + .thenApply(results -> { + final ImmutableList.Builder> filteredListBuilder = + ImmutableList.builder(); + for (final NodeReferenceAndNode result : results) { + if (result != null) { + filteredListBuilder.add(result); + } + } + return filteredListBuilder.build(); + }); + } + + @Nonnull + private CompletableFuture exists(@Nonnull final ReadTransaction readTransaction, + @Nonnull final Tuple primaryKey) { // - // Call fetchNode() to check for the node's existence; we are handing in the identity operator, since we don't - // care about the vector itself at all. + // Call fetchBaseNode() to check for the node's existence; we are handing in the identity operator, + // since we do not care about the vector itself at all. // - return storageAdapter.fetchNode(readTransaction, AffineOperator.identity(), 0, primaryKey) + return fetchBaseNode(readTransaction, StorageTransform.identity(), primaryKey) .thenApply(Objects::nonNull); } + @Nonnull + private CompletableFuture fetchBaseNode(@Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Tuple primaryKey) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(0); + + return storageAdapter.fetchNode(readTransaction, storageTransform, 0, primaryKey) + .thenApply(node -> { + if (node == null) { + return null; + } + return node.asCompactNode(); + }); + } + /** * Method to keep stats if necessary. Stats need to be kept and maintained when the client would like to use * e.g. RaBitQ as RaBitQ needs a stable somewhat correct centroid in order to function properly. @@ -919,8 +1121,8 @@ private CompletableFuture addToStatsIfNecessary(@Nonnull final SplittableR final AccessInfo newAccessInfo = new AccessInfo(currentAccessInfo.getEntryNodeReference().withVector(transformedEntryNodeVector), rotatorSeed, rotatedCentroid); - StorageAdapter.writeAccessInfo(transaction, getSubspace(), newAccessInfo, onWriteListener); - StorageAdapter.removeAllSampledVectors(transaction, getSubspace()); + StorageAdapter.writeAccessInfo(transaction, getSubspace(), newAccessInfo, getOnWriteListener()); + StorageAdapter.deleteAllSampledVectors(transaction, getSubspace(), getOnWriteListener()); if (logger.isTraceEnabled()) { logger.trace("established rotatorSeed={}, centroid with count={}, centroid={}", rotatorSeed, partialCount, rotatedCentroid); @@ -951,8 +1153,8 @@ private AggregatedVector aggregateVectors(@Nonnull final Iterable * This method implements the second phase of the HNSW insertion algorithm. It begins at a starting layer, which is * the minimum of the graph's maximum layer ({@code lMax}) and the new node's randomly assigned - * {@code insertionLayer}. It then iterates downwards to layer 0. In each layer, it invokes - * {@link #insertIntoLayer(StorageAdapter, Transaction, AffineOperator, Quantizer, List, int, Tuple, Transformed)} + * {@code layer}. It then iterates downwards to layer 0. In each layer, it invokes + * {@link #insertIntoLayer(StorageAdapter, Transaction, StorageTransform, Quantizer, List, int, Tuple, Transformed)} * to perform the search and connect the new node. The set of nearest neighbors found at layer {@code L} serves as * the entry points for the search at layer {@code L-1}. *

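insert(...) now derives the node's top layer from the primary key via topLayer(newPrimaryKey) instead of drawing it from a free-running random source. The implementation of topLayer is not part of this excerpt; for orientation only, the classic HNSW level assignment looks like the sketch below, made deterministic here by seeding from a key-derived value purely as an assumption:

    import java.util.SplittableRandom;

    // NOT the actual topLayer() implementation -- just the standard HNSW level formula,
    // with the generator seeded from a value derived from the primary key.
    static int exampleLevel(final long primaryKeySeed, final int m) {
        final SplittableRandom random = new SplittableRandom(primaryKeySeed);
        final double normalization = 1.0 / Math.log(m);   // mL = 1 / ln(M)
        final double u = 1.0 - random.nextDouble();       // uniform in (0, 1]
        return (int) Math.floor(-Math.log(u) * normalization);
    }
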
@@ -974,7 +1176,7 @@ private AggregatedVector aggregateVectors(@Nonnull final Iterable insertIntoLayers(@Nonnull final Transaction transaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Quantizer quantizer, @Nonnull final Tuple newPrimaryKey, @Nonnull final Transformed newVector, @@ -990,7 +1192,8 @@ private CompletableFuture insertIntoLayers(@Nonnull final Transaction tran (layer, previousNodeReferences) -> { final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); return insertIntoLayer(storageAdapter, transaction, storageTransform, quantizer, - previousNodeReferences, layer, newPrimaryKey, newVector); + previousNodeReferences, layer, newPrimaryKey, newVector) + .thenApply(NodeReferenceAndNode::getReferences); }, executor).thenCompose(ignored -> AsyncUtil.DONE); } @@ -1027,14 +1230,14 @@ private CompletableFuture insertIntoLayers(@Nonnull final Transaction tran * @param newVector the vector associated with the new node * * @return a {@code CompletableFuture} that completes with a list of the nearest neighbors found during the - * initial search phase. This list serves as the entry point for insertion into the next lower layer - * (i.e., {@code layer - 1}). + * initial search phase. This list serves as the entry point for insertion into the next lower layer + * (i.e., {@code layer - 1}). */ @Nonnull - private CompletableFuture> + private CompletableFuture>> insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final Transaction transaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Quantizer quantizer, @Nonnull final List nearestNeighbors, final int layer, @@ -1048,63 +1251,75 @@ private CompletableFuture insertIntoLayers(@Nonnull final Transaction tran return searchLayer(storageAdapter, transaction, storageTransform, estimator, nearestNeighbors, layer, config.getEfConstruction(), nodeCache, newVector) - .thenCompose(searchResult -> { - final List references = NodeReferenceAndNode.getReferences(searchResult); - - return selectNeighbors(storageAdapter, transaction, storageTransform, estimator, searchResult, - layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVector) - .thenCompose(selectedNeighbors -> { - final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); - - final AbstractNode newNode = - nodeFactory.create(newPrimaryKey, newVector, - NodeReferenceAndNode.getReferences(selectedNeighbors)); - - final NeighborsChangeSet newNodeChangeSet = - new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), - newNode.getNeighbors()); - - storageAdapter.writeNode(transaction, quantizer, newNode, layer, newNodeChangeSet); - - // create change sets for each selected neighbor and insert new node into them - final Map> neighborChangeSetMap = - Maps.newLinkedHashMap(); - for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { - final NeighborsChangeSet baseSet = - new BaseNeighborsChangeSet<>(selectedNeighbor.getNode().getNeighbors()); - final NeighborsChangeSet insertSet = - new InsertNeighborsChangeSet<>(baseSet, ImmutableList.of(newNode.getSelfReference(newVector))); - neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), - insertSet); - } + .thenCompose(searchResult -> + extendCandidatesIfNecessary(storageAdapter, transaction, storageTransform, estimator, + searchResult, layer, getConfig().isExtendCandidates(), nodeCache, newVector) 
+ .thenCompose(extendedCandidates -> + selectCandidates(storageAdapter, transaction, storageTransform, estimator, + extendedCandidates, layer, getConfig().getM(), nodeCache)) + .thenCompose(selectedNeighbors -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final AbstractNode newNode = + nodeFactory.create(newPrimaryKey, newVector, + NodeReferenceAndNode.getReferences(selectedNeighbors)); + + final NeighborsChangeSet newNodeChangeSet = + new InsertNeighborsChangeSet<>( + new BaseNeighborsChangeSet<>(ImmutableList.of()), + newNode.getNeighbors()); + + storageAdapter.writeNode(transaction, quantizer, layer, newNode, + newNodeChangeSet); + + // create change sets for each selected neighbor and insert new node into them + final Map> neighborChangeSetMap = + Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { + final NeighborsChangeSet baseSet = + new BaseNeighborsChangeSet<>( + selectedNeighbor.getNode().getNeighbors()); + final NeighborsChangeSet insertSet = + new InsertNeighborsChangeSet<>(baseSet, + ImmutableList.of(newNode.getSelfReference(newVector))); + neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), + insertSet); + } - final int currentMMax = layer == 0 ? getConfig().getMMax0() : getConfig().getMMax(); - return forEach(selectedNeighbors, - selectedNeighbor -> { - final AbstractNode selectedNeighborNode = selectedNeighbor.getNode(); - final NeighborsChangeSet changeSet = - Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); - return pruneNeighborsIfNecessary(storageAdapter, transaction, - storageTransform, estimator, selectedNeighbor, layer, - currentMMax, changeSet, nodeCache) - .thenApply(nodeReferencesAndNodes -> { - if (nodeReferencesAndNodes == null) { - return changeSet; - } - return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); - }); - }, getConfig().getMaxNumConcurrentNeighborhoodFetches(), getExecutor()) - .thenApply(changeSets -> { - for (int i = 0; i < selectedNeighbors.size(); i++) { - final NodeReferenceAndNode selectedNeighbor = selectedNeighbors.get(i); - final NeighborsChangeSet changeSet = changeSets.get(i); - storageAdapter.writeNode(transaction, quantizer, - selectedNeighbor.getNode(), layer, changeSet); - } - return ImmutableList.copyOf(references); - }); - }); - }).thenApply(nodeReferencesWithDistances -> { + final int currentMMax = + layer == 0 ? 
getConfig().getMMax0() : getConfig().getMMax(); + + return forEach(selectedNeighbors, + selectedNeighbor -> { + final NodeReferenceWithDistance selectedNeighborReference = + selectedNeighbor.getNodeReference(); + final AbstractNode selectedNeighborNode = selectedNeighbor.getNode(); + final NeighborsChangeSet changeSet = + Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + storageTransform, estimator, layer, selectedNeighborReference, + currentMMax, changeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return changeSet; + } + return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); + }); + }, getConfig().getMaxNumConcurrentNeighborhoodFetches(), getExecutor()) + .thenApply(changeSets -> { + for (int i = 0; i < selectedNeighbors.size(); i++) { + final NodeReferenceAndNode selectedNeighbor = + selectedNeighbors.get(i); + final NeighborsChangeSet changeSet = changeSets.get(i); + if (changeSet.hasChanges()) { + storageAdapter.writeNode(transaction, quantizer, + layer, selectedNeighbor.getNode(), changeSet); + } + } + return ImmutableList.copyOf(searchResult); + }); + })) + .thenApply(nodeReferencesWithDistances -> { if (logger.isTraceEnabled()) { logger.trace("end insert key={} at layer={}", newPrimaryKey, layer); } @@ -1134,15 +1349,15 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto */ private NeighborsChangeSet resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, - @Nonnull final Iterable> afterNeighbors) { + @Nonnull final Iterable> afterNeighbors) { final Map beforeNeighborsMap = Maps.newLinkedHashMap(); for (final N n : beforeChangeSet.merge()) { beforeNeighborsMap.put(n.getPrimaryKey(), n); } final Map afterNeighborsMap = Maps.newLinkedHashMap(); - for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { - final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReference(); afterNeighborsMap.put(nodeReferenceWithDistance.getPrimaryKey(), nodeReferenceAndNode.getNode().getSelfReference(nodeReferenceWithDistance.getVector())); @@ -1189,28 +1404,27 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto * @param estimator an estimator to estimate distances * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the * storage space that is currently being used - * @param selectedNeighbor the node whose neighborhood is being considered for pruning + * @param nodeReferenceWithVector the node reference of the node whose neighborhood is being considered for pruning * @param layer the graph layer on which the operation is performed * @param mMax the maximum number of neighbors a node is allowed to have on this layer * @param neighborChangeSet a set of pending changes to the neighborhood that must be included in the pruning - * calculation + * calculation * @param nodeCache a cache of nodes to avoid redundant database fetches * * @return a {@link CompletableFuture} which completes with a list of the newly selected neighbors for the pruned node. * If no pruning was necessary, it completes with {@code null}. 
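+     *         Pruning is only attempted when the merged change set already holds at least {@code mMax} neighbors;
+     *         below that threshold the future completes immediately with {@code null}.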
*/ @Nonnull - private CompletableFuture>> + private CompletableFuture>> pruneNeighborsIfNecessary(@Nonnull final StorageAdapter storageAdapter, @Nonnull final Transaction transaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Estimator estimator, - @Nonnull final NodeReferenceAndNode selectedNeighbor, final int layer, + @Nonnull final NodeReferenceWithVector nodeReferenceWithVector, final int mMax, @Nonnull final NeighborsChangeSet neighborChangeSet, @Nonnull final Map> nodeCache) { - final AbstractNode selectedNeighborNode = selectedNeighbor.getNode(); final int numNeighbors = Iterables.size(neighborChangeSet.merge()); // this is a view over the iterable neighbors in the set if (numNeighbors < mMax) { @@ -1218,29 +1432,25 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto } else { if (logger.isTraceEnabled()) { logger.trace("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}", - selectedNeighborNode.getPrimaryKey(), numNeighbors, mMax); + nodeReferenceWithVector.getPrimaryKey(), numNeighbors, mMax); } - return fetchNeighborhood(storageAdapter, transaction, storageTransform, layer, neighborChangeSet.merge(), nodeCache) - .thenCompose(nodeReferenceWithVectors -> { + return fetchNeighborhoodReferences(storageAdapter, transaction, storageTransform, layer, neighborChangeSet.merge(), nodeCache) + .thenApply(neighborReferenceWithVectors -> { final ImmutableList.Builder nodeReferencesWithDistancesBuilder = ImmutableList.builder(); - for (final NodeReferenceWithVector nodeReferenceWithVector : nodeReferenceWithVectors) { - final var vector = nodeReferenceWithVector.getVector(); - final double distance = - estimator.distance(vector, - selectedNeighbor.getNodeReferenceWithDistance().getVector()); + for (final NodeReferenceWithVector neighborReferenceWithVector : neighborReferenceWithVectors) { + final var neighborVector = neighborReferenceWithVector.getVector(); + final double distance = estimator.distance(neighborVector, nodeReferenceWithVector.getVector()); nodeReferencesWithDistancesBuilder.add( - new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(), - vector, distance)); + new NodeReferenceWithDistance(neighborReferenceWithVector.getPrimaryKey(), + neighborVector, distance)); } - return fetchSomeNodesIfNotCached(storageAdapter, transaction, storageTransform, layer, - nodeReferencesWithDistancesBuilder.build(), nodeCache); + return nodeReferencesWithDistancesBuilder.build(); }) .thenCompose(nodeReferencesAndNodes -> - selectNeighbors(storageAdapter, transaction, storageTransform, estimator, + selectCandidates(storageAdapter, transaction, storageTransform, estimator, nodeReferencesAndNodes, layer, - mMax, false, nodeCache, - selectedNeighbor.getNodeReferenceWithDistance().getVector())); + mMax, nodeCache)); } } @@ -1266,80 +1476,72 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto * @param estimator the estimator in use * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the * storage space that is currently being used - * @param nearestNeighbors the initial pool of candidate neighbors, typically from a search in a higher layer + * @param initialCandidates the initial pool of candidate neighbors, typically from a search in a higher layer * @param layer the layer in the HNSW graph where the selection is being performed * @param m the maximum number of neighbors to 
select - * @param isExtendCandidates a flag indicating whether to extend the initial candidate pool by fetching the * neighbors of the {@code nearestNeighbors} * @param nodeCache a cache of nodes to avoid redundant storage lookups - * @param vector the query vector for which neighbors are being selected * * @return a {@link CompletableFuture} which will complete with a list of the selected neighbors, * each represented as a {@link NodeReferenceAndNode} */ - private CompletableFuture>> - selectNeighbors(@Nonnull final StorageAdapter storageAdapter, - @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, - @Nonnull final Estimator estimator, - @Nonnull final Iterable> nearestNeighbors, - final int layer, - final int m, - final boolean isExtendCandidates, - @Nonnull final Map> nodeCache, - @Nonnull final Transformed vector) { + private CompletableFuture>> + selectCandidates(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Estimator estimator, + @Nonnull final Iterable initialCandidates, + final int layer, + final int m, + @Nonnull final Map> nodeCache) { final Metric metric = getConfig().getMetric(); - return extendCandidatesIfNecessary(storageAdapter, readTransaction, storageTransform, estimator, - nearestNeighbors, layer, isExtendCandidates, nodeCache, vector) - .thenApply(extendedCandidates -> { - final List selected = Lists.newArrayListWithExpectedSize(m); - final Queue candidates = - new PriorityQueue<>(extendedCandidates.size(), - Comparator.comparing(NodeReferenceWithDistance::getDistance)); - candidates.addAll(extendedCandidates); - final Queue discardedCandidates = - getConfig().isKeepPrunedConnections() - ? new PriorityQueue<>(config.getM(), - Comparator.comparing(NodeReferenceWithDistance::getDistance)) - : null; - - while (!candidates.isEmpty() && selected.size() < m) { - final NodeReferenceWithDistance nearestCandidate = candidates.poll(); - boolean shouldSelect = true; - // if the metric does not support triangle inequality, we shold not use the heuristic - if (metric.satisfiesTriangleInequality()) { - for (final NodeReferenceWithDistance alreadySelected : selected) { - if (estimator.distance(nearestCandidate.getVector(), - alreadySelected.getVector()) < nearestCandidate.getDistance()) { - shouldSelect = false; - break; - } - } - } - if (shouldSelect) { - selected.add(nearestCandidate); - } else if (discardedCandidates != null) { - discardedCandidates.add(nearestCandidate); - } - } - if (discardedCandidates != null) { // isKeepPrunedConnections is set to true - while (!discardedCandidates.isEmpty() && selected.size() < m) { - selected.add(discardedCandidates.poll()); - } + final List selected = Lists.newArrayListWithExpectedSize(m); + final Queue candidates = + new PriorityQueue<>(getConfig().getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + initialCandidates.forEach(candidates::add); + final Queue discardedCandidates = + getConfig().isKeepPrunedConnections() + ? 
new PriorityQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)) + : null; + + while (!candidates.isEmpty() && selected.size() < m) { + final NodeReferenceWithDistance nearestCandidate = candidates.poll(); + boolean shouldSelect = true; + // if the metric does not support triangle inequality, we shold not use the heuristic + if (metric.satisfiesTriangleInequality()) { + for (final NodeReferenceWithDistance alreadySelected : selected) { + if (estimator.distance(nearestCandidate.getVector(), + alreadySelected.getVector()) < nearestCandidate.getDistance()) { + shouldSelect = false; + break; } + } + } + if (shouldSelect) { + selected.add(nearestCandidate); + } else if (discardedCandidates != null) { + discardedCandidates.add(nearestCandidate); + } + } - return ImmutableList.copyOf(selected); - }).thenCompose(selectedNeighbors -> - fetchSomeNodesIfNotCached(storageAdapter, readTransaction, storageTransform, layer, - selectedNeighbors, nodeCache)) + if (discardedCandidates != null) { // isKeepPrunedConnections is set to true + while (!discardedCandidates.isEmpty() && selected.size() < m) { + selected.add(discardedCandidates.poll()); + } + } + + return fetchSomeNodesIfNotCached(storageAdapter, readTransaction, storageTransform, layer, + selected, nodeCache) .thenApply(selectedNeighbors -> { if (logger.isTraceEnabled()) { logger.trace("selected neighbors={}", selectedNeighbors.stream() .map(selectedNeighbor -> - "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() + - ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")") + "(primaryKey=" + selectedNeighbor.getNodeReference().getPrimaryKey() + + ",distance=" + selectedNeighbor.getNodeReference().getDistance() + ")") .collect(Collectors.joining(","))); } return selectedNeighbors; @@ -1363,7 +1565,7 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto * @param estimator the estimator * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the * storage space that is currently being used - * @param candidates an {@link Iterable} of initial candidate nodes, which have already been evaluated + * @param candidates an {@link Collection} of initial candidate nodes, which have already been evaluated * @param layer the graph layer from which to fetch nodes * @param isExtendCandidates a boolean flag; if {@code true}, the candidate set is extended with neighbors * @param nodeCache a cache mapping primary keys to {@link AbstractNode} objects to avoid redundant fetches @@ -1375,56 +1577,163 @@ layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVecto private CompletableFuture> extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, - @Nonnull final AffineOperator storageTransform, + @Nonnull final StorageTransform storageTransform, @Nonnull final Estimator estimator, - @Nonnull final Iterable> candidates, - int layer, - boolean isExtendCandidates, + @Nonnull final Collection> candidates, + final int layer, + final boolean isExtendCandidates, @Nonnull final Map> nodeCache, @Nonnull final Transformed vector) { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + if (isExtendCandidates) { - final Set candidatesSeen = Sets.newConcurrentHashSet(); - for (final NodeReferenceAndNode candidate : candidates) { - candidatesSeen.add(candidate.getNode().getPrimaryKey()); + return 
neighborReferences(storageAdapter, readTransaction, storageTransform, null, candidates, + CandidatePredicate.tautology(), layer, nodeCache) + .thenApply(neighborsOfCandidates -> { + for (final NodeReferenceWithVector nodeReferenceWithVector : neighborsOfCandidates) { + final double distance = estimator.distance(nodeReferenceWithVector.getVector(), vector); + resultBuilder.add(new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(), + nodeReferenceWithVector.getVector(), distance)); + } + return resultBuilder.build(); + }); + } else { + // + // Add all given candidates to the result. + // + for (final NodeReferenceAndNode candidate : candidates) { + resultBuilder.add(candidate.getNodeReference()); } - final ImmutableList.Builder neighborsOfCandidatesBuilder = ImmutableList.builder(); - for (final NodeReferenceAndNode candidate : candidates) { - for (final N neighbor : candidate.getNode().getNeighbors()) { - final Tuple neighborPrimaryKey = neighbor.getPrimaryKey(); - if (!candidatesSeen.contains(neighborPrimaryKey)) { - candidatesSeen.add(neighborPrimaryKey); - neighborsOfCandidatesBuilder.add(neighbor); - } - } - } + return CompletableFuture.completedFuture(resultBuilder.build()); + } + } - final Iterable neighborsOfCandidates = neighborsOfCandidatesBuilder.build(); + /** + * Compute and if necessary fetch the neighbor references (with vectors) and the neighboring nodes of an iterable + * of initial nodes that is passed in. Note that the neighbor of an initial node might be another initial node. + * If that is the case the node is returned. + * + * @param the type of the {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access node data from storage + * @param readTransaction the active {@link ReadTransaction} for database access + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param random the random to be used for sampling + * @param initialNodeReferenceAndNodes an {@link Iterable} of initial candidate nodes, which have already been evaluated + * @param layer the graph layer from which to fetch nodes + * @param nodeCache a cache mapping primary keys to {@link AbstractNode} objects to avoid redundant fetches + * + * @return a {@link CompletableFuture} which will complete with a list of fetched nodes + */ + private CompletableFuture>> + neighbors(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final SplittableRandom random, + @Nonnull final Collection> initialNodeReferenceAndNodes, + @Nonnull final CandidatePredicate samplingPredicate, + final int layer, + @Nonnull final Map> nodeCache) { + return neighborReferences(storageAdapter, readTransaction, storageTransform, random, + initialNodeReferenceAndNodes, samplingPredicate, layer, nodeCache) + .thenCompose(neighbors -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, storageTransform, layer, + neighbors, nodeCache)) + .thenCompose(neighbors -> + filterExisting(storageAdapter, readTransaction, storageTransform, neighbors)); + } - return fetchNeighborhood(storageAdapter, readTransaction, storageTransform, layer, - neighborsOfCandidates, nodeCache) - .thenApply(withVectors -> { - final ImmutableList.Builder extendedCandidatesBuilder = - ImmutableList.builder(); - for (final NodeReferenceAndNode candidate : candidates) { - 
extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance()); - } + /** + * Compute and if necessary fetch the neighbor references (with vectors) of an iterable of initial nodes that is + * passed in. Note that the neighbor of an initial node might be another initial node. If that is the case the node + * is returned. + * + * @param the type of the {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access node data from storage + * @param readTransaction the active {@link ReadTransaction} for database access + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param random a {@link SplittableRandom} to be used for sampling + * @param initialNodeReferenceAndNodes an {@link Iterable} of initial candidate nodes, which have already been + * evaluated + * @param samplingPredicate a predicate that restricts the number of neighbors to be fetched + * @param layer the graph layer from which to fetch nodes + * @param nodeCache a cache mapping primary keys to {@link AbstractNode} objects to avoid redundant fetches + * + * @return a {@link CompletableFuture} which will complete with a list of {@link NodeReferenceWithVector} + */ + private CompletableFuture> + neighborReferences(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final StorageTransform storageTransform, + @Nullable final SplittableRandom random, + @Nonnull final Collection> initialNodeReferenceAndNodes, + @Nonnull final CandidatePredicate samplingPredicate, + final int layer, + @Nonnull final Map> nodeCache) { + final Iterable toBeFetched = + findNeighborReferences(initialNodeReferenceAndNodes, random, samplingPredicate); + return fetchNeighborhoodReferences(storageAdapter, readTransaction, storageTransform, layer, toBeFetched, + nodeCache); + } - for (final NodeReferenceWithVector withVector : withVectors) { - final double distance = estimator.distance(vector, withVector.getVector()); - extendedCandidatesBuilder.add(new NodeReferenceWithDistance(withVector.getPrimaryKey(), - withVector.getVector(), distance)); - } - return extendedCandidatesBuilder.build(); - }); - } else { - final ImmutableList.Builder resultBuilder = ImmutableList.builder(); - for (final NodeReferenceAndNode candidate : candidates) { - resultBuilder.add(candidate.getNodeReferenceWithDistance()); + /** + * Return the union of the nodes passed in and their neighbors. 
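+     * The references of the initial nodes and any not-yet-seen neighbors that are not themselves initial nodes are
+     * collected first; that combined set is then filtered through the supplied candidate predicate, which is how
+     * callers sample the candidate set down to roughly the desired size.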
+ * + * @param the type of the {@link NodeReference} storage space that is currently being used + * @param initialNodeReferenceAndNodes an {@link Iterable} of initial candidate nodes + * + * @return a {@link CompletableFuture} which will complete with a set of {@link NodeReference}s + */ + private Set + findNeighborReferences(@Nonnull final Collection> initialNodeReferenceAndNodes, + @Nullable final SplittableRandom random, + @Nonnull final CandidatePredicate candidatePredicate) { + final Set neighborReferences = Sets.newLinkedHashSet(); + final ImmutableMap.Builder> initialNodesMapBuilder = ImmutableMap.builder(); + for (final NodeReferenceAndNode nodeReferenceAndNode : initialNodeReferenceAndNodes) { + initialNodesMapBuilder.put(nodeReferenceAndNode.getNode().getPrimaryKey(), nodeReferenceAndNode); + neighborReferences.add(nodeReferenceAndNode.getNodeReference()); + } + + final ImmutableMap> initialNodesMap = initialNodesMapBuilder.build(); + final Set nodeReferencesSeen = Sets.newHashSet(); + + for (final NodeReferenceAndNode nodeReferenceAndNode : initialNodeReferenceAndNodes) { + for (final N neighbor : nodeReferenceAndNode.getNode().getNeighbors()) { + final Tuple neighborPrimaryKey = neighbor.getPrimaryKey(); + + // + // We need to distinguish between initial node references and non-initial node references: + // Initial nodes references are of type T (and sometimes already contain a vector in which case + // we do not want to refetch the node later if we don't have to). The initial nodes already have been + // added earlier in this method (with or without a vector). The neighbors that are not initial most + // likely do not contain a vector which is fine but if T != N, we need to be careful in order to not + // create duplicates in this set. + // + @Nullable final NodeReferenceAndNode initialNode = initialNodesMap.get(neighborPrimaryKey); + if (initialNode == null && !nodeReferencesSeen.contains(neighborPrimaryKey)) { + // + // This is a node that is currently not known to us. It is not an initial node. We need to fetch it, + // and we need to mark it as seen so we won't consider it more than once. + // + neighborReferences.add(neighbor); + nodeReferencesSeen.add(neighborPrimaryKey); + } } + } - return CompletableFuture.completedFuture(resultBuilder.build()); + // sample down the set of neighbors by testing the candidate predicate + final ImmutableSet.Builder resultBuilder = ImmutableSet.builder(); + for (final NodeReference neighborReference : neighborReferences) { + if (candidatePredicate.test(random, initialNodesMap.keySet(), + neighborReferences.size(), neighborReference)) { + resultBuilder.add(neighborReference); + } } + + return resultBuilder.build(); } /** @@ -1478,14 +1787,610 @@ private void writeLonelyNodeOnLayer(@Nonnull final Qua @Nonnull final Tuple primaryKey, @Nonnull final Transformed vector) { storageAdapter.writeNode(transaction, quantizer, - storageAdapter.getNodeFactory() - .create(primaryKey, vector, ImmutableList.of()), layer, + layer, storageAdapter.getNodeFactory() + .create(primaryKey, vector, ImmutableList.of()), new BaseNeighborsChangeSet<>(ImmutableList.of())); if (logger.isTraceEnabled()) { logger.trace("written lonely node at key={} on layer={}", primaryKey, layer); } } + /** + * Deletes a record using its associated primary key from the HNSW graph. + *

+ * This method implements a multi-layer deletion algorithm that maintains the structural integrity of the HNSW + * graph. The deletion process consists of several key phases: + *

+     *   • Layer Determination: First determines the top layer for the node using the same deterministic
+     *     algorithm used during insertion, ensuring consistent layer assignment across operations.
+     *   • Existence Verification: Checks whether the node actually exists in the graph before attempting
+     *     deletion. If the node doesn't exist, the operation completes immediately without error.
+     *   • Multi-Layer Deletion: Removes the node from all layers spanning from layer 0 (base layer
+     *     containing all nodes) up to and including the node's top layer. The deletion is performed in parallel
+     *     across all layers for optimal performance.
+     *   • Graph Repair: For each layer where the node is deleted, the algorithm repairs the local graph
+     *     structure by identifying the deleted node's neighbors and reconnecting them appropriately. This process:
+     *       • finds candidate replacement connections among the neighbors of neighbors,
+     *       • selects optimal new connections using the HNSW distance heuristics,
+     *       • updates neighbor lists to maintain graph connectivity and search performance, and
+     *       • applies connection limits (M, MMax) and prunes excess connections if necessary.
+     *   • Entry Point Management: If the deleted node was serving as the graph's entry point (the starting
+     *     node for search operations), the method automatically selects a new entry point from the remaining nodes
+     *     at the highest available layer. If no nodes remain after deletion, the access information is cleared,
+     *     effectively resetting the graph to an empty state.
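+     *
+     * A minimal usage sketch (illustrative only; {@code db} is assumed to be an open FoundationDB {@code Database}
+     * and {@code hnsw} a fully configured index instance):
+     * <pre>{@code
+     * // delete the record whose primary key is (1234) and wait for all layer repairs to complete
+     * db.runAsync(tr -> hnsw.delete(tr, Tuple.from(1234L))).join();
+     * }</pre>
+     *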
+ * All operations are performed transactionally and asynchronously, ensuring consistency and enabling + * non-blocking execution in concurrent environments. + * + * @param transaction the {@link Transaction} context for all database operations, ensuring atomicity + * and consistency of the deletion and repair operations + * @param primaryKey the unique {@link Tuple} primary key identifying the node to be deleted from the graph + * + * @return a {@link CompletableFuture} that completes when the deletion operation is fully finished, + * including all graph repairs and entry point updates. The future completes with {@code null} + * on successful deletion. + */ + @Nonnull + public CompletableFuture delete(@Nonnull final Transaction transaction, @Nonnull final Tuple primaryKey) { + final SplittableRandom random = random(primaryKey); + final int topLayer = topLayer(primaryKey); + if (logger.isTraceEnabled()) { + logger.trace("node with key={} to be deleted form layer={}", primaryKey, topLayer); + } + + return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener()) + .thenCombine(exists(transaction, primaryKey), + (accessInfo, nodeExists) -> { + if (!nodeExists) { + if (logger.isTraceEnabled()) { + logger.trace("record does not exists in HNSW with key={} on layer={}", + primaryKey, topLayer); + } + } + return new AccessInfoAndNodeExistence(accessInfo, nodeExists); + }) + .thenCompose(accessInfoAndNodeExistence -> { + if (!accessInfoAndNodeExistence.isNodeExists()) { + return AsyncUtil.DONE; + } + + final AccessInfo accessInfo = accessInfoAndNodeExistence.getAccessInfo(); + final EntryNodeReference entryNodeReference = + accessInfo == null ? null : accessInfo.getEntryNodeReference(); + final StorageTransform storageTransform = storageTransform(accessInfo); + final Quantizer quantizer = quantizer(accessInfo); + + return deleteFromLayers(transaction, storageTransform, quantizer, random, primaryKey, topLayer) + .thenCompose(potentialEntryNodeReferences -> { + if (entryNodeReference != null && primaryKey.equals(entryNodeReference.getPrimaryKey())) { + // find (and store) a new entry reference + for (int i = potentialEntryNodeReferences.size() - 1; i >= 0; i --) { + final EntryNodeReference potentialEntyNodeReference = + potentialEntryNodeReferences.get(i); + if (potentialEntyNodeReference != null) { + StorageAdapter.writeAccessInfo(transaction, getSubspace(), + accessInfo.withNewEntryNodeReference(potentialEntyNodeReference), + getOnWriteListener()); + // early out + return AsyncUtil.DONE; + } + } + + // there is no data in the structure, delete access info to start new + StorageAdapter.deleteAccessInfo(transaction, getSubspace(), getOnWriteListener()); + } + return AsyncUtil.DONE; + }); + }); + } + + /** + * Deletes a node from the HNSW graph across multiple layers, using a primary key and a given top layer. + * + * @param transaction the transaction to use for database operations + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param quantizer the quantizer to be used for this insert + * @param primaryKey the primary key of the new node being inserted + * @param topLayer the top layer for the node. + * + * @return a {@link CompletableFuture} that completes when the new node has been successfully inserted into all + * its designated layers and contains an existing neighboring entry node reference on that layer. 
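+     *         The returned list contains one entry per layer, ordered from layer {@code 0} up to {@code topLayer};
+     *         an individual entry may be {@code null} if no surviving neighbor could be determined on that layer.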
+ */ + @Nonnull + private CompletableFuture> deleteFromLayers(@Nonnull final Transaction transaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Quantizer quantizer, + @Nonnull final SplittableRandom random, + @Nonnull final Tuple primaryKey, + final int topLayer) { + // delete the node from all layers in parallel (inside layer in [0, topLayer]) + return forEach(() -> IntStream.rangeClosed(0, topLayer).iterator(), + layer -> + deleteFromLayer(getStorageAdapterForLayer(layer), transaction, storageTransform, quantizer, + random.split(), layer, primaryKey), + getConfig().getMaxNumConcurrentDeleteFromLayer(), + executor); + } + + /** + * Deletes a node from a specified layer of the HNSW graph. This method orchestrates the complete deletion process + * for a single layer. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter for reading from and writing to the graph + * @param transaction the transaction context for the database operations + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param quantizer the quantizer for this insert + * @param layer the layer number to insert the new node into + * @param toBeDeletedPrimaryKey the primary key of the new node to be inserted + * + * @return a {@code CompletableFuture} that completes with a {@code null} + */ + @Nonnull + private CompletableFuture + deleteFromLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Quantizer quantizer, + @Nonnull final SplittableRandom random, + final int layer, + @Nonnull final Tuple toBeDeletedPrimaryKey) { + if (logger.isTraceEnabled()) { + logger.trace("begin delete key={} at layer={}", toBeDeletedPrimaryKey, layer); + } + final Estimator estimator = quantizer.estimator(); + final Map> nodeCache = Maps.newConcurrentMap(); + final Map> candidateChangeSetMap = + Maps.newConcurrentMap(); + + return storageAdapter.fetchNode(transaction, storageTransform, layer, toBeDeletedPrimaryKey) + .thenCompose(toBeDeletedNode -> { + final NodeReferenceAndNode toBeDeletedNodeReferenceAndNode = + new NodeReferenceAndNode<>(new NodeReference(toBeDeletedPrimaryKey), toBeDeletedNode); + + return findDeletionRepairCandidates(storageAdapter, transaction, storageTransform, random, layer, + toBeDeletedNodeReferenceAndNode, nodeCache) + .thenCompose(candidates -> { + initializeCandidateChangeSetMap(toBeDeletedPrimaryKey, toBeDeletedNode, candidates, + candidateChangeSetMap); + // resolve the actually existing direct neighbors + final ImmutableList primaryNeighbors = + primaryNeighbors(toBeDeletedNode, candidateChangeSetMap); + + // + // Repair each primary neighbor in parallel, there should not be much actual I/O, + // except in edge cases, but we should still parallelize it. 
+ // + return forEach(primaryNeighbors, + neighborReference -> + repairNeighbor(storageAdapter, transaction, + storageTransform, estimator, layer, neighborReference, + candidates, candidateChangeSetMap, nodeCache), + getConfig().getMaxNumConcurrentNeighborhoodFetches(), executor) + .thenApply(ignored -> { + final ImmutableMap.Builder candidateReferencesMapBuilder = + ImmutableMap.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + final var candidatePrimaryKey = candidate.getNodeReference().getPrimaryKey(); + if (candidateChangeSetMap.containsKey(candidatePrimaryKey)) { + candidateReferencesMapBuilder.put(candidatePrimaryKey, candidate.getNodeReference()); + } + } + return candidateReferencesMapBuilder.build(); + }); + }) + .thenCompose(candidateReferencesMap -> { + final int currentMMax = + layer == 0 ? getConfig().getMMax0() : getConfig().getMMax(); + + // + // If we previously went beyond the mMax/mMax0, we need to prune the neighbors. + // Pruning is independent among different nodes -- we can therefore prune in + // parallel. + // + return forEach(candidateChangeSetMap.entrySet(), // for each modified neighbor set + changeSetEntry -> { + final NodeReferenceWithVector candidateReference = + Objects.requireNonNull(candidateReferencesMap.get(changeSetEntry.getKey())); + final NeighborsChangeSet candidateChangeSet = changeSetEntry.getValue(); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + storageTransform, estimator, layer, candidateReference, + currentMMax, candidateChangeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return candidateChangeSet; + } + + final var prunedCandidateChangeSet = + resolveChangeSetFromNewNeighbors(candidateChangeSet, + nodeReferencesAndNodes); + candidateChangeSetMap.put(changeSetEntry.getKey(), prunedCandidateChangeSet); + return prunedCandidateChangeSet; + }); + }, + getConfig().getMaxNumConcurrentNeighborhoodFetches(), executor) + .thenApply(ignored -> candidateReferencesMap); + }) + .thenApply(candidateReferencesMap -> { + // + // Finally delete the node we set out to delete and persist the change sets for all + // repaired nodes. + // + storageAdapter.deleteNode(transaction, layer, toBeDeletedPrimaryKey); + + for (final Map.Entry> changeSetEntry : candidateChangeSetMap.entrySet()) { + final NeighborsChangeSet changeSet = changeSetEntry.getValue(); + if (changeSet.hasChanges()) { + final AbstractNode candidateNode = + nodeFromCache(changeSetEntry.getKey(), nodeCache); + storageAdapter.writeNode(transaction, quantizer, + layer, candidateNode, changeSet); + } + } + + // + // Return the first item in the candidates reference map as a potential new + // entry node reference in order to avoid a costly search for a new global entry point. + // This reference is guaranteed to exist. + // + final Tuple firstPrimaryKey = + Iterables.getFirst(candidateReferencesMap.keySet(), null); + return firstPrimaryKey == null + ? 
null + : new EntryNodeReference(firstPrimaryKey, + Objects.requireNonNull(candidateReferencesMap.get(firstPrimaryKey)).getVector(), + layer); + }); + }).thenApply(result -> { + if (logger.isTraceEnabled()) { + logger.trace("end delete key={} at layer={}", toBeDeletedPrimaryKey, layer); + } + return result; + }); + } + + private void initializeCandidateChangeSetMap(@Nonnull final Tuple toBeDeletedPrimaryKey, + @Nonnull final AbstractNode toBeDeletedNode, + @Nonnull final List> candidates, + @Nonnull final Map> candidateChangeSetMap) { + for (final NodeReferenceAndNode candidate : candidates) { + final AbstractNode candidateNode = candidate.getNode(); + boolean foundToBeDeleted = false; + for (final N neighborOfCandidate : candidateNode.getNeighbors()) { + if (neighborOfCandidate.getPrimaryKey().equals(toBeDeletedPrimaryKey)) { + // + // Make sure a neighbor pointing to the node being deleted is deleted as well. + // + candidateChangeSetMap.put(candidateNode.getPrimaryKey(), + new DeleteNeighborsChangeSet<>( + new BaseNeighborsChangeSet<>(candidateNode.getNeighbors()), + ImmutableList.of(toBeDeletedPrimaryKey))); + foundToBeDeleted = true; + break; + } + } + if (!foundToBeDeleted) { + // if there is no reference back to the node being deleted, just create the base set + candidateChangeSetMap.put(candidateNode.getPrimaryKey(), + new BaseNeighborsChangeSet<>(candidateNode.getNeighbors())); + } + } + if (logger.isTraceEnabled()) { + logger.trace("number of neighbors to repair={}", toBeDeletedNode.getNeighbors().size()); + } + } + + /** + * Compile a list of node references that definitely exist. The neighbor list of a node may contain node + * references to neighbors that don't exist anymore (stale reference). The (non-existing) nodes that these node + * references might refer to must not be repaired as that may resurrect a node. + *

+ * We know that the candidate change set map only contains keys for nodes that exist AND that the candidate change + * set map contains all primary neighbors (if they exist). Therefore, we filter the neighbors list from the node by + * cross-referencing the change set map. + * @param type parameter extending {@link NodeReference} + * @param toBeDeletedNode the node that is being deleted. + * @param candidateChangeSetMap the initialized candidate change set map. + * @return a list of existing primary neighbors + */ + @Nonnull + private ImmutableList + primaryNeighbors(@Nonnull final AbstractNode toBeDeletedNode, + @Nonnull final Map> candidateChangeSetMap) { + // + // All entries in the change set map definitely exist and the candidate change set map hold all keys for all + // existing primary candidates. + // + final ImmutableList.Builder primaryNeighborsBuilder = ImmutableList.builder(); + for (final N potentialPrimaryNeighbor : toBeDeletedNode.getNeighbors()) { + if (candidateChangeSetMap.containsKey(potentialPrimaryNeighbor.getPrimaryKey())) { + primaryNeighborsBuilder.add(potentialPrimaryNeighbor); + } + } + return primaryNeighborsBuilder.build(); + } + + /** + * Find candidates starting from the node to be deleted. To this end we find all the existing first degree (primary) + * and second-degree (secondary) neighbors. As that set is too big to consider for the repair we rely on sampling + * to eventually compile a list of roughly {@code efRepair} number of candidates. + * + * @param type parameter extending {@link NodeReference} + * @param storageAdapter the storage adapter for the layer + * @param transaction the transaction + * @param storageTransform the storage transform + * @param random a {@link SplittableRandom} used for sampling the candidate set + * @param layer the layer + * @param toBeDeletedNodeReferenceAndNode the node that is about to be deleted + * @param nodeCache the node cache to avoid repeated fetches + * @return a future that if successful completes with {@code null} + */ + @Nonnull + private CompletableFuture>> + findDeletionRepairCandidates(final @Nonnull StorageAdapter storageAdapter, + final @Nonnull Transaction transaction, + final @Nonnull StorageTransform storageTransform, + final @Nonnull SplittableRandom random, + final int layer, + final NodeReferenceAndNode toBeDeletedNodeReferenceAndNode, + final Map> nodeCache) { + return neighbors(storageAdapter, transaction, storageTransform, random, + ImmutableList.of(toBeDeletedNodeReferenceAndNode), + ((r, initialNodeKeys, size, nodeReference) -> + shouldUsePrimaryCandidateForRepair(nodeReference, + toBeDeletedNodeReferenceAndNode.getNodeReference().getPrimaryKey())), layer, nodeCache) + .thenCompose(candidates -> + neighbors(storageAdapter, transaction, storageTransform, random, + candidates, + ((r, initialNodeKeys, size, nodeReference) -> + shouldUseSecondaryCandidateForRepair(r, initialNodeKeys, size, nodeReference, + toBeDeletedNodeReferenceAndNode.getNodeReference().getPrimaryKey())), + layer, nodeCache)) + .thenApply(candidates -> { + if (logger.isTraceEnabled()) { + final ImmutableList.Builder candidateStringsBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + candidateStringsBuilder.add(candidate.getNode().getPrimaryKey().toString()); + } + logger.trace("found at layer={} num={} candidates={}", layer, candidates.size(), + String.join(",", candidateStringsBuilder.build())); + } + return candidates; + }); + } + + /** + * Repair a neighbor node of the node that is 
being deleted using a set of candidates. All candidates contain only + * the vector (in addition to identifying information like the primary key). The logic in + * computes distances between the neighbor vector and each candidate vector which is required by + * {@link #repairInsForNeighborNode}. + * + * @param type parameter extending {@link NodeReference} + * @param storageAdapter the storage adapter for the layer + * @param transaction the transaction + * @param storageTransform the storage transform + * @param estimator an estimator for distances + * @param layer the layer + * @param neighborReference the reference for which this method repairs incoming references + * @param candidates the set of candidates + * @param neighborChangeSetMap the change set map which records all changes to all nodes that are being repaired + * @param nodeCache the node cache to avoid repeated fetches + * @return a future that if successful completes with {@code null} + */ + private @Nonnull CompletableFuture + repairNeighbor(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Estimator estimator, + final int layer, + @Nonnull final N neighborReference, + @Nonnull final Collection> candidates, + @Nonnull final Map> neighborChangeSetMap, + @Nonnull final Map> nodeCache) { + + return fetchNodeIfNotCached(storageAdapter, transaction, + storageTransform, layer, neighborReference, nodeCache) + .thenCompose(neighborNode -> { + final ImmutableList.Builder candidatesReferencesBuilder = + ImmutableList.builder(); + final Transformed neighborVector = + storageAdapter.getVector(neighborReference, neighborNode); + // transform the NodeReferencesWithVectors into NodeReferencesWithDistance + for (final NodeReferenceAndNode candidate : candidates) { + // do not add the candidate if that candidate is in fact the neighbor itself + if (!candidate.getNodeReference().getPrimaryKey().equals(neighborReference.getPrimaryKey())) { + final Transformed candidateVector = + candidate.getNodeReference().getVector(); + final double distance = + estimator.distance(candidateVector, neighborVector); + candidatesReferencesBuilder.add(new NodeReferenceWithDistance( + candidate.getNode().getPrimaryKey(), candidateVector, distance)); + } + } + return repairInsForNeighborNode(storageAdapter, transaction, storageTransform, estimator, + layer, neighborReference, candidatesReferencesBuilder.build(), + neighborChangeSetMap, nodeCache); + }); + } + + /** + * Repairs the ins of a neighbor node of the node that is being deleted using a set of candidates. Each such + * neighbor is part of a set that is referred to as {@code p_out} in literature. In this method we only repair + * incoming references to this node. As this method is called once per direct neighbor and all direct neighbors are + * in the candidate set, outgoing references from this node to other nodes (in {@code p_out}) are repaired when this + * method is called for the respective neighbors. 
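+     * Concretely, for every candidate selected by the distance heuristic, a reference to {@code neighborReference}
+     * is inserted into that candidate's change set, re-establishing the incoming edges that would otherwise be lost
+     * together with the deleted node.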
+ * + * @param type parameter extending {@link NodeReference} + * @param storageAdapter the storage adapter for the layer + * @param transaction the transaction + * @param storageTransform the storage transform + * @param estimator an estimator for distances + * @param layer the layer + * @param neighborReference the reference for which this method repairs incoming references + * @param candidates the set of candidates + * @param neighborChangeSetMap the change set map which records all changes to all nodes that are being repaired + * @param nodeCache the node cache to avoid repeated fetches + * @return a future that if successful completes with {@code null} + */ + private CompletableFuture + repairInsForNeighborNode(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final StorageTransform storageTransform, + @Nonnull final Estimator estimator, + final int layer, + @Nonnull final N neighborReference, + @Nonnull final Iterable candidates, + @Nonnull final Map> neighborChangeSetMap, + final Map> nodeCache) { + return selectCandidates(storageAdapter, transaction, storageTransform, estimator, candidates, + layer, getConfig().getM(), nodeCache) + .thenApply(selectedCandidates -> { + if (logger.isTraceEnabled()) { + final ImmutableList.Builder candidateStringsBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : selectedCandidates) { + candidateStringsBuilder.add(candidate.getNode().getPrimaryKey().toString()); + } + logger.trace("selected for neighbor={}, candidates={}", + neighborReference.getPrimaryKey(), + String.join(",", candidateStringsBuilder.build())); + } + return selectedCandidates; + }) + .thenCompose(selectedCandidates -> { + // create change sets for each selected neighbor and insert new node into them + for (final NodeReferenceAndNode selectedCandidate : selectedCandidates) { + neighborChangeSetMap.compute(selectedCandidate.getNode().getPrimaryKey(), + (ignored, oldChangeSet) -> { + Objects.requireNonNull(oldChangeSet); + // insert a reference to the neighbor + return new InsertNeighborsChangeSet<>(oldChangeSet, ImmutableList.of(neighborReference)); + }); + } + return AsyncUtil.DONE; + }); + } + + /** + * Gets the appropriate storage adapter for a given layer. + *

+ * This method selects a {@link StorageAdapter} implementation based on the layer number. The logic is intended to + * use an {@code InliningStorageAdapter} for layers greater than {@code 0} and a {@code CompactStorageAdapter} for + * layer 0. Note that we will only use inlining at all if the config indicates we should use inlining. + * + * @param layer the layer number for which to get the storage adapter + * @return a non-null {@link StorageAdapter} instance + */ + @Nonnull + private StorageAdapter getStorageAdapterForLayer(final int layer) { + return storageAdapterForLayer(getConfig(), getSubspace(), getOnWriteListener(), getOnReadListener(), layer); + } + + @Nonnull + private SplittableRandom random(@Nonnull final Tuple primaryKey) { + return new SplittableRandom(splitMixLong(primaryKey.hashCode())); + } + + /** + * Calculates a layer for a new element to be inserted or for an element to be deleted from. + *
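+     * The layer is computed purely from a hash of the primary key, so insertion and deletion of the same record
+     * always agree on that record's top layer.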

+ * The layer is selected according to a logarithmic distribution, which ensures that the probability of choosing + * a higher layer decreases exponentially. This is achieved by applying the inverse transform sampling method. + * The specific formula is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random number and + * {@code lambda} is a normalization factor derived from a system configuration parameter {@code M}. + * @param primaryKey the primary key of the record to be inserted/updated/deleted + * @return a non-negative integer representing the randomly selected layer + */ + private int topLayer(@Nonnull final Tuple primaryKey) { + double lambda = 1.0 / Math.log(getConfig().getM()); + double u = 1.0 - splitMixDouble(primaryKey.hashCode()); // Avoid log(0) + return (int) Math.floor(-Math.log(u) * lambda); + } + + /** + * Predicate to determine if a potential candidate is to be used as a candidate for repairing the HNSW. + * The predicate rejects the candidate reference if it is referring to the node that is being deleted, otherwise the + * predicate accepts the candidate reference. + * @param candidateReference a potential candidate that is either accepted or rejected + * @param toBeDeletedPrimaryKey the {@link Tuple} representing the node that is being deleted + * @return {@code true} iff {@code candidateReference} is accepted as an actual candidate for repair. + */ + private boolean shouldUsePrimaryCandidateForRepair(@Nonnull final NodeReference candidateReference, + @Nonnull final Tuple toBeDeletedPrimaryKey) { + final Tuple candidatePrimaryKey = candidateReference.getPrimaryKey(); + + // + // If the node reference is the record we are trying to delete we must reject it here as it is not a suitable + // candidate. + // + return !candidatePrimaryKey.equals(toBeDeletedPrimaryKey); + } + + /** + * Predicate to determine if a potential candidate is to be used ad a candidate for repairing the HNSW. + *

+     *   1. The predicate rejects the candidate reference if it is referring to the node that is being deleted.
+     *   2. The predicate always accepts a direct neighbor of the node that is about to be deleted.
+     *   3. Sample the remaining potential candidates such that eventually the repair algorithm can use
+     *      roughly {@code efRepair} actual candidates.
+ * @param random the PRNG to be used (splittable) + * @param initialNodeKeys a set of {@link Tuple}s that hold the primary neighbors of the node being deleted. + * @param numberOfCandidates the number of potential candidates the repair algorithm compiled + * @param candidateReference a potential candidate that is either accepted or rejected + * @param toBeDeletedPrimaryKey the {@link Tuple} representing the node that is being deleted + * @return {@code true} iff {@code candidateReference} is accepted as an actual candidate for repair. + */ + private boolean shouldUseSecondaryCandidateForRepair(@Nullable final SplittableRandom random, + @Nonnull final Set initialNodeKeys, + final int numberOfCandidates, + @Nonnull final NodeReference candidateReference, + @Nonnull final Tuple toBeDeletedPrimaryKey) { + final Tuple candidatePrimaryKey = candidateReference.getPrimaryKey(); + + // + // If the node reference is the record we are trying to delete we must reject it here as it is not a suitable + // candidate. + // + if (candidatePrimaryKey.equals(toBeDeletedPrimaryKey)) { + return false; + } + + // + // If the node reference is among the initial nodes we must accept it as they are very likely the best + // candidates. + // + if (initialNodeKeys.contains(candidatePrimaryKey)) { + return true; + } + + // + // Sample all the rest -- For the sampling rate, subtract the size of initialNodeKeys so that we get roughly + // efRepair nodes. + // + final double sampleRate = (double)(getConfig().getEfRepair() - initialNodeKeys.size()) / numberOfCandidates; + if (sampleRate >= 1) { + return true; + } + return Objects.requireNonNull(random).nextDouble() < sampleRate; + } + + private boolean shouldSampleVector(@Nonnull final SplittableRandom random) { + return random.nextDouble() < getConfig().getSampleVectorStatsProbability(); + } + + private boolean shouldMaintainStats(@Nonnull final SplittableRandom random) { + return random.nextDouble() < getConfig().getMaintainStatsProbability(); + } + /** * Scans all nodes within a given layer of the database. *
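+     * A usage sketch (illustrative; {@code config}, {@code subspace}, and {@code db} are assumed to come from the
+     * surrounding test fixture):
+     * <pre>{@code
+     * AtomicLong nodeCount = new AtomicLong();
+     * // visit all nodes on layer 0, fetching batches of 100 primary keys per transaction
+     * scanLayer(config, subspace, db, 0, 100, node -> nodeCount.incrementAndGet());
+     * }</pre>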

@@ -1500,11 +2405,14 @@ private void writeLonelyNodeOnLayer(@Nonnull final Qua * found in the layer. */ @VisibleForTesting - void scanLayer(@Nonnull final Database db, - final int layer, - final int batchSize, - @Nonnull final Consumer> nodeConsumer) { - final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + static void scanLayer(@Nonnull final Config config, + @Nonnull final Subspace subspace, + @Nonnull final Database db, + final int layer, + final int batchSize, + @Nonnull final Consumer> nodeConsumer) { + final StorageAdapter storageAdapter = + storageAdapterForLayer(config, subspace, OnWriteListener.NOOP, OnReadListener.NOOP, layer); final AtomicReference lastPrimaryKeyAtomic = new AtomicReference<>(); Tuple newPrimaryKey; do { @@ -1517,7 +2425,7 @@ void scanLayer(@Nonnull final Database db, lastPrimaryKeyAtomic.set(node.getPrimaryKey()); }); return lastPrimaryKeyAtomic.get(); - }, executor); + }); } while (newPrimaryKey != null); } @@ -1528,52 +2436,52 @@ void scanLayer(@Nonnull final Database db, * use an {@code InliningStorageAdapter} for layers greater than {@code 0} and a {@code CompactStorageAdapter} for * layer 0. Note that we will only use inlining at all if the config indicates we should use inlining. * - * @param layer the layer number for which to get the storage adapter; currently unused - * @return a non-null {@link StorageAdapter} instance, which will always be a - * {@link CompactStorageAdapter} in the current implementation + * @param config the config to use + * @param subspace the subspace of the HNSW object itself + * @param onWriteListener a listener that the new {@link StorageAdapter} will call back for any write events + * @param onReadListener a listener that the new {@link StorageAdapter} will call back for any read events + * @param layer the layer number for which to get the storage adapter + * @return a non-null {@link StorageAdapter} instance */ @Nonnull - private StorageAdapter getStorageAdapterForLayer(final int layer) { + @VisibleForTesting + static StorageAdapter + storageAdapterForLayer(@Nonnull final Config config, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener, + final int layer) { return config.isUseInlining() && layer > 0 - ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), - getOnReadListener()) - : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), - getOnReadListener()); - } - - @Nonnull - private SplittableRandom random(@Nonnull final Tuple primaryKey) { - if (config.isDeterministicSeeding()) { - return new SplittableRandom(primaryKey.hashCode()); - } else { - return new SplittableRandom(System.nanoTime()); - } + ? new InliningStorageAdapter(config, InliningNode.factory(), subspace, onWriteListener, onReadListener) + : new CompactStorageAdapter(config, CompactNode.factory(), subspace, onWriteListener, onReadListener); } /** - * Calculates a random layer for a new element to be inserted. - *

- * The layer is selected according to a logarithmic distribution, which ensures that - * the probability of choosing a higher layer decreases exponentially. This is - * achieved by applying the inverse transform sampling method. The specific formula - * is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random - * number and {@code lambda} is a normalization factor derived from a system - * configuration parameter {@code M}. - * @param random a random to use - * @return a non-negative integer representing the randomly selected layer. + * Returns a good double hash code for the argument of type {@code long}. It uses {@link #splitMixLong(long)} + * internally and then maps the {@code long} result to a {@code double} between {@code 0} and {@code 1}. + * This method is directly used in {@link #topLayer(Tuple)} to determine the top layer of a record given its + * primary key. + * @param x a {@code long} + * @return a high quality hash code of {@code x} as a {@code double} in the range {@code [0.0d, 1.0d)}. */ - private int insertionLayer(@Nonnull final SplittableRandom random) { - double lambda = 1.0 / Math.log(getConfig().getM()); - double u = 1.0 - random.nextDouble(); // Avoid log(0) - return (int) Math.floor(-Math.log(u) * lambda); - } - - private boolean shouldSampleVector(@Nonnull final SplittableRandom random) { - return random.nextDouble() < getConfig().getSampleVectorStatsProbability(); + private static double splitMixDouble(final long x) { + return (splitMixLong(x) >>> 11) * 0x1.0p-53; } - private boolean shouldMaintainStats(@Nonnull final SplittableRandom random) { - return random.nextDouble() < getConfig().getMaintainStatsProbability(); + /** + * Returns a good long hash code for the argument of type {@code long}. It is an implementation of the + * output mixing function {@code SplitMix64} as employed by many PRNG such as {@link SplittableRandom}. + * See Linear congruential generator for + * more information. 
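+     * The xor-shift and multiply rounds below use the standard SplitMix64 mixing constants; since every step is
+     * invertible, the function is a bijection on {@code long} values and distinct inputs never collide after mixing.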
+ * @param x a {@code long} + * @return a high quality hash code of {@code x} + */ + private static long splitMixLong(long x) { + x += 0x9e3779b97f4a7c15L; + x = (x ^ (x >>> 30)) * 0xbf58476d1ce4e5b9L; + x = (x ^ (x >>> 27)) * 0x94d049bb133111ebL; + x = x ^ (x >>> 31); + return x; } @Nonnull @@ -1585,6 +2493,16 @@ private static List drain(@Nonnull Queue queue) { return resultBuilder.build(); } + @FunctionalInterface + private interface CandidatePredicate { + @Nonnull + static CandidatePredicate tautology() { + return (random, initialNodeKeys, size, nodeReference) -> true; + } + + boolean test(@Nullable SplittableRandom random, @Nonnull Set initialNodeKeys, int size, NodeReference nodeReference); + } + private static class AccessInfoAndNodeExistence { @Nullable private final AccessInfo accessInfo; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java index fce0fdac34..5c0c36395a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -34,6 +34,7 @@ import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.ByteArrayUtil; import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import javax.annotation.Nonnull; @@ -75,6 +76,14 @@ public InliningStorageAdapter(@Nonnull final Config config, super(config, nodeFactory, subspace, onWriteListener, onReadListener); } + @Nonnull + @Override + public Transformed getVector(@Nonnull final NodeReferenceWithVector nodeReference, + @Nonnull final AbstractNode node) { + Verify.verify(nodeReference.isNodeReferenceWithVector()); + return nodeReference.asNodeReferenceWithVector().getVector(); + } + /** * Asynchronously fetches a single node from a given layer by its primary key. *

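Referring back to the HNSW.java hunk above, the PRNG-driven insertionLayer helper gives way to a deterministic SplitMix64 hash of the primary key. A minimal stand-alone sketch of how such a hash can drive layer selection follows; the helper names and the floor(-ln(u) / ln(m)) formula (carried over from the removed insertionLayer code) are illustrative assumptions, not the patch's actual topLayer(Tuple) implementation, which is not part of this hunk.

    final class LayerSelectionSketch {
        private LayerSelectionSketch() {
        }

        // SplitMix64 output mixing function: maps any long to a well-distributed long.
        static long splitMix64(long x) {
            x += 0x9e3779b97f4a7c15L;
            x = (x ^ (x >>> 30)) * 0xbf58476d1ce4e5b9L;
            x = (x ^ (x >>> 27)) * 0x94d049bb133111ebL;
            return x ^ (x >>> 31);
        }

        // Map the mixed long to a double in [0.0d, 1.0d) using its top 53 bits.
        static double splitMixDouble(final long x) {
            return (splitMix64(x) >>> 11) * 0x1.0p-53;
        }

        // Hypothetical layer derivation: the uniform u comes from the key hash instead of a
        // PRNG, so the same primary key always maps to the same top layer.
        static int layerForKey(final int primaryKeyHashCode, final int m) {
            final double u = 1.0d - splitMixDouble(primaryKeyHashCode); // avoid ln(0)
            return (int)Math.floor(-Math.log(u) / Math.log(m));
        }
    }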
@@ -82,6 +91,9 @@ public InliningStorageAdapter(@Nonnull final Config config, * It then performs an asynchronous range scan to retrieve all key-value pairs associated with that prefix. * Finally, it reconstructs the complete {@link AbstractNode} object from the collected raw data using * the {@code nodeFromRaw} method. + *

+     * Note that when using the inlining storage adapter it is not possible to distinguish between a node that has no
+     * neighbors and a node that is not present in the database (i.e. it was deleted).
      *
      * @param readTransaction the transaction to use for reading from the database
      * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the
@@ -208,15 +220,15 @@ private NodeReferenceWithVector neighborFromTuples(@Nonnull final AffineOperator
      *
      * @param transaction the transaction context for the write operation; must not be null
      * @param quantizer the quantizer to use
+     * @param layer the layer index where the node and its neighbor changes should be written
      * @param node the node to be written, which is expected to be an
      *             {@code InliningNode}; must not be null
-     * @param layer the layer index where the node and its neighbor changes should be written
      * @param neighborsChangeSet the set of changes to the node's neighbors to be
      *                           persisted; must not be null
      */
     @Override
     public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer,
-                                  @Nonnull final AbstractNode node, final int layer,
+                                  final int layer, @Nonnull final AbstractNode node,
                                   @Nonnull final NeighborsChangeSet neighborsChangeSet) {
         final InliningNode inliningNode = node.asInliningNode();
@@ -224,6 +236,16 @@ public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull f
         getOnWriteListener().onNodeWritten(layer, node);
     }

+    @Override
+    protected void deleteNodeInternal(@Nonnull final Transaction transaction, final int layer,
+                                      @Nonnull final Tuple primaryKey) {
+        final byte[] key = getNodeKey(layer, primaryKey);
+        final Range range = Range.startsWith(key);
+        transaction.clear(range);
+        getOnWriteListener().onNodeDeleted(layer, primaryKey);
+        getOnWriteListener().onRangeDeleted(layer, range);
+    }
+
     /**
      * Constructs the raw database key for a node based on its layer and primary key.
      *

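The deleteNodeInternal method added above removes an inlining node by clearing every key under the node's prefix. A minimal sketch of that idea in isolation, assuming a key layout of layer followed by primary key (an illustrative assumption, not necessarily the adapter's exact format):

    import com.apple.foundationdb.Range;
    import com.apple.foundationdb.Transaction;
    import com.apple.foundationdb.subspace.Subspace;
    import com.apple.foundationdb.tuple.Tuple;

    final class InliningDeleteSketch {
        private InliningDeleteSketch() {
        }

        // Clears all neighbor entries stored under the node's key prefix. Because an inlining
        // node exists only through its neighbor entries, a node without neighbors leaves no
        // keys behind, which is why an absent node cannot be told apart from a deleted one.
        static void deleteInliningNode(final Transaction transaction, final Subspace subspace,
                                       final int layer, final Tuple primaryKey) {
            final byte[] nodePrefix = subspace.pack(Tuple.from(layer).addAll(primaryKey));
            transaction.clear(Range.startsWith(nodePrefix));
        }
    }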
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java index b3b5ef8a12..0616446bf5 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java @@ -31,6 +31,7 @@ import javax.annotation.Nonnull; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Predicate; /** @@ -84,6 +85,11 @@ public NeighborsChangeSet getParent() { return parent; } + @Override + public boolean hasChanges() { + return true; + } + /** * Merges the neighbors from this level of the hierarchy with all neighbors from parent levels. *

@@ -95,7 +101,9 @@ public NeighborsChangeSet getParent() { @Nonnull @Override public Iterable merge() { - return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); + return Iterables.concat(Iterables.filter(getParent().merge(), + current -> !insertedNeighborsMap.containsKey(Objects.requireNonNull(current).getPrimaryKey())), + insertedNeighborsMap.values()); } /** diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java index 207c6a1f1f..98a8e92b9e 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -51,6 +51,12 @@ interface NeighborsChangeSet { @Nullable NeighborsChangeSet getParent(); + /** + * Method to indicate iff changes have been made that need to be persisted. + * @return {@code true} iff changes have been made in this or parent change sets. + */ + boolean hasChanges(); + /** * Merges multiple internal sequences into a single, consolidated iterable sequence. *

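The reworked InsertNeighborsChangeSet.merge() above filters the parent's view so that a locally inserted neighbor shadows any parent entry with the same primary key, and the new hasChanges() lets callers skip persisting change sets that carry nothing new. A small stand-alone illustration of that merge rule, using plain collections instead of the actual change-set classes:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Set;

    final class NeighborMergeSketch {
        private NeighborMergeSketch() {
        }

        // parentMerge: neighbor primary keys contributed by parent change sets;
        // insertedKeys: primary keys inserted at this level. Local inserts win on conflict,
        // so the merged view never contains two entries for the same primary key.
        static List<Long> merge(final List<Long> parentMerge, final Set<Long> insertedKeys) {
            final List<Long> merged = new ArrayList<>();
            for (final Long parentKey : parentMerge) {
                if (!insertedKeys.contains(parentKey)) {
                    merged.add(parentKey);
                }
            }
            merged.addAll(insertedKeys);
            return merged;
        }
    }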
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java index a6c4f33abe..7108624535 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -26,29 +26,31 @@ import java.util.List; /** - * A container class that pairs a {@link NodeReferenceWithDistance} with its corresponding {@link AbstractNode} object. + * A container class that pairs a {@link NodeReference} with its corresponding {@link AbstractNode} object. *

* This is often used during graph traversal or searching, where a reference to a node (along with its distance from a * query point) is first identified, and then the complete node data is fetched. This class holds these two related * pieces of information together. - * @param the type of {@link NodeReference} used within the {@link AbstractNode} + * @param the type of {@link NodeReference} referencing the node + * @param the type of {@link NodeReference} used within the {@link AbstractNode}, i.e. the type of the neighbor + * references */ -class NodeReferenceAndNode { +class NodeReferenceAndNode { @Nonnull - private final NodeReferenceWithDistance nodeReferenceWithDistance; + private final T nodeReference; @Nonnull private final AbstractNode node; /** * Constructs a new instance that pairs a node reference (with distance) with its * corresponding {@link AbstractNode} object. - * @param nodeReferenceWithDistance the reference to a node, which also includes distance information. Must not be + * @param nodeReference the reference to a node, which also includes distance information. Must not be * {@code null}. * @param node the actual {@link AbstractNode} object that the reference points to. Must not be {@code null}. */ - public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, + public NodeReferenceAndNode(@Nonnull final T nodeReference, @Nonnull final AbstractNode node) { - this.nodeReferenceWithDistance = nodeReferenceWithDistance; + this.nodeReference = nodeReference; this.node = node; } @@ -57,8 +59,8 @@ public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferen * @return the non-null {@link NodeReferenceWithDistance} object. */ @Nonnull - public NodeReferenceWithDistance getNodeReferenceWithDistance() { - return nodeReferenceWithDistance; + public T getNodeReference() { + return nodeReference; } /** @@ -70,6 +72,11 @@ public AbstractNode getNode() { return node; } + @Override + public String toString() { + return "NRaN[" + nodeReference + "," + node + ']'; + } + /** * Helper to extract the references from a given collection of objects of this container class. 
* @param referencesAndNodes an iterable of {@link NodeReferenceAndNode} objects from which to extract the @@ -77,10 +84,10 @@ public AbstractNode getNode() { * @return a {@link List} of {@link NodeReferenceAndNode}s */ @Nonnull - public static List getReferences(@Nonnull List> referencesAndNodes) { - final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); - for (final NodeReferenceAndNode referenceWithNode : referencesAndNodes) { - referencesBuilder.add(referenceWithNode.getNodeReferenceWithDistance()); + public static List getReferences(@Nonnull List> referencesAndNodes) { + final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode referenceWithNode : referencesAndNodes) { + referencesBuilder.add(referenceWithNode.getNodeReference()); } return referencesBuilder.build(); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java index aacc1ca8f2..4f2f434a0c 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -20,6 +20,7 @@ package com.apple.foundationdb.async.hnsw; +import com.apple.foundationdb.Range; import com.apple.foundationdb.tuple.Tuple; import javax.annotation.Nonnull; @@ -32,7 +33,7 @@ public interface OnWriteListener { }; /** - * Callback method invoked after a node has been successfully written to a specific layer. + * Callback method that is invoked after a node has been successfully written to a specific layer. *

 * This is a default method with an empty implementation, allowing implementing classes to override it only if they
 * need to react to this event.
@@ -63,6 +64,34 @@ default void onNeighborWritten(final int layer, @Nonnull final Node<p>
+     * This is a default method and its base implementation is a no-op. Implementors of the interface can override this
+     * method to react to a key/value pair having been written, for example, to collect metrics or update internal
+     * state.
+     * @param layer the layer the data was written to
+     * @param key the key
+     * @param value the value
+     */
+    @SuppressWarnings("unused")
+    default void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) {
+        // nothing
+    }
+
+    /**
+     * Callback method invoked after a node has been successfully deleted from a specific layer.
+     *

+ * This is a default method with an empty implementation, allowing implementing classes to override it only if they + * need to react to this event. + * @param layer the index of the layer where the node was deleted. + * @param primaryKey the {@link Tuple} used as key to identify the node that was deleted; guaranteed to be non-null. + */ + @SuppressWarnings("unused") + default void onNodeDeleted(final int layer, @Nonnull final Tuple primaryKey) { + // nothing + } + /** * Callback method invoked when a neighbor of a specific node is deleted. *

@@ -79,8 +108,31 @@ default void onNeighborDeleted(final int layer, @Nonnull final Node<p>
+     * This is a default method and its base implementation is a no-op. Implementors of the interface can override this
+     * method to react to the deletion of a key, for example, to clean up related resources or update internal
+     * state.
+     * @param layer the layer index where the deletion occurred
+     * @param key the key that was deleted
+     */
     @SuppressWarnings("unused")
-    default void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) {
+    default void onKeyDeleted(final int layer, @Nonnull final byte[] key) {
+        // nothing
+    }
+
+    /**
+     * Callback method invoked when an entire range is deleted.
+     *

+     * This is a default method and its base implementation is a no-op. Implementors of the interface can override this
+     * method to react to the deletion of a range, for example, to clean up related resources or update internal
+     * state.
+     * @param layer the layer index where the deletion occurred
+     * @param range the {@link Range} that was deleted
+     */
+    @SuppressWarnings("unused")
+    default void onRangeDeleted(final int layer, @Nonnull final Range range) {
+        // nothing
+    }
 }
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java
index 4c7296ad45..b219fbdd6e 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java
@@ -93,6 +93,40 @@ interface StorageAdapter {
     @Nonnull
     NodeFactory getNodeFactory();

+    /**
+     * Method that returns {@code true} iff this {@link StorageAdapter} is inlining neighboring vectors (i.e. it is
+     * an {@link InliningStorageAdapter}).
+     * @return {@code true} iff this {@link StorageAdapter} is inlining neighboring vectors.
+     */
+    boolean isInliningStorageAdapter();
+
+    /**
+     * Method that returns {@code this} object as an {@link InliningStorageAdapter} if this {@link StorageAdapter} is
+     * inlining neighboring vectors and is an {@link InliningStorageAdapter}. This method throws an exception if this
+     * storage adapter is any other kind of storage adapter. Callers of this method should ensure prior to calling this
+     * method that the storage adapter actually is of the right kind (by calling {@link #isInliningStorageAdapter()}).
+     * @return {@code this} as an {@link InliningStorageAdapter}
+     */
+    @Nonnull
+    InliningStorageAdapter asInliningStorageAdapter();
+
+    /**
+     * Method that returns {@code true} iff this {@link StorageAdapter} is a compact storage adapter which means it is
+     * not inlining neighboring vectors (i.e. {@code this} is a {@link CompactStorageAdapter}).
+     * @return {@code true} iff this {@link StorageAdapter} is a {@link CompactStorageAdapter}.
+     */
+    boolean isCompactStorageAdapter();
+
+    /**
+     * Method that returns {@code this} object as a {@link CompactStorageAdapter} if this {@link StorageAdapter} is
+     * a {@link CompactStorageAdapter}. This method throws an exception if this storage adapter is any other kind of
+     * storage adapter. Callers of this method should ensure prior to calling this method that the storage adapter
+     * actually is of the right kind (by calling {@link #isCompactStorageAdapter()}).
+     * @return {@code this} as a {@link CompactStorageAdapter}
+     */
+    @Nonnull
+    CompactStorageAdapter asCompactStorageAdapter();
+
     /**
      * Get the subspace used to store this HNSW structure.
      * @return the subspace
@@ -124,6 +158,19 @@ interface StorageAdapter {
     @Nonnull
     OnReadListener getOnReadListener();

+    /**
+     * Method that returns the vector associated with the node information passed in. Note that depending on the storage
+     * layout and therefore the used {@link StorageAdapter}, the vector is either part of the reference
+     * (when using {@link InliningStorageAdapter}) or is part of the {@link AbstractNode} itself (when using
+     * {@link CompactStorageAdapter}). This method hides that detail from the caller and correctly resolves the vector
+     * for both use cases.
+ * @param nodeReference a node reference + * @param node the accompanying node to {@code nodeReference} + * @return the associated vector as {@link Transformed} of {@link RealVector} + */ + @Nonnull + Transformed getVector(@Nonnull N nodeReference, @Nonnull AbstractNode node); + /** * Asynchronously fetches a node from a specific layer, identified by its primary key. *

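The getVector() method introduced above lets callers obtain a node's vector without knowing whether it travels with the reference (inlining layout) or with the node itself (compact layout). A caller-side sketch, assuming the generic shapes StorageAdapter<N> and NodeReferenceAndNode<N, N> from this change and placement in the com.apple.foundationdb.async.hnsw package so the package-private types are visible:

    import com.apple.foundationdb.linear.RealVector;
    import com.apple.foundationdb.linear.Transformed;
    import com.google.common.collect.ImmutableList;

    import java.util.List;

    final class VectorResolutionSketch {
        private VectorResolutionSketch() {
        }

        // Resolves the vectors for a batch of results without caring about the storage layout:
        // the inlining adapter reads the vector off the reference, the compact adapter reads it
        // off the node, and getVector() hides that difference from the caller.
        static <N extends NodeReference> List<Transformed<RealVector>> resolveVectors(
                final StorageAdapter<N> storageAdapter,
                final List<NodeReferenceAndNode<N, N>> batch) {
            final ImmutableList.Builder<Transformed<RealVector>> resultBuilder = ImmutableList.builder();
            for (final NodeReferenceAndNode<N, N> referenceAndNode : batch) {
                resultBuilder.add(storageAdapter.getVector(referenceAndNode.getNodeReference(),
                        referenceAndNode.getNode()));
            }
            return resultBuilder.build();
        }
    }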
@@ -146,18 +193,28 @@ CompletableFuture> fetchNode(@Nonnull ReadTransaction readTransa /** * Writes a node and its neighbor changes to the data store within a given transaction. *

- * This method is responsible for persisting the state of a {@link AbstractNode} and applying any modifications to its + * This method is responsible for persisting the state of a {@link AbstractNode} and applying any modifications to + * its * neighboring nodes as defined in the {@code NeighborsChangeSet}. The entire operation is performed atomically as * part of the provided {@link Transaction}. + * * @param transaction the non-null transaction context for this write operation. * @param quantizer the quantizer to use - * @param node the non-null node to be written to the data store. * @param layer the layer index where the node resides. + * @param node the non-null node to be written to the data store. * @param changeSet the non-null set of changes describing additions or removals of * neighbors for the given {@link AbstractNode}. */ - void writeNode(@Nonnull Transaction transaction, @Nonnull Quantizer quantizer, @Nonnull AbstractNode node, - int layer, @Nonnull NeighborsChangeSet changeSet); + void writeNode(@Nonnull Transaction transaction, @Nonnull Quantizer quantizer, int layer, + @Nonnull AbstractNode node, @Nonnull NeighborsChangeSet changeSet); + + /** + * Deletes a node from a particular layer in the database. + * @param transaction the transaction to use + * @param layer the layer the node should be deleted from + * @param primaryKey the primary key of the node + */ + void deleteNode(@Nonnull Transaction transaction, int layer, @Nonnull Tuple primaryKey); /** * Scans a specified layer of the structure, returning an iterable sequence of nodes. @@ -306,6 +363,21 @@ static void writeAccessInfo(@Nonnull final Transaction transaction, onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value); } + /** + * Deletes the {@link AccessInfo} from the database within a given transaction and subspace. 
+     * @param transaction the database transaction to use for the delete operation
+     * @param subspace the subspace where the access info is stored
+     * @param onWriteListener the listener to be notified after the key is deleted
+     */
+    static void deleteAccessInfo(@Nonnull final Transaction transaction,
+                                 @Nonnull final Subspace subspace,
+                                 @Nonnull final OnWriteListener onWriteListener) {
+        final Subspace entryNodeSubspace = accessInfoSubspace(subspace);
+        final byte[] key = entryNodeSubspace.pack();
+        transaction.clear(key);
+        onWriteListener.onKeyDeleted(-1, key);
+    }
+
     @Nonnull
     static CompletableFuture> consumeSampledVectors(@Nonnull final Transaction transaction,
                                                     @Nonnull final Subspace subspace,
@@ -324,6 +396,7 @@ static CompletableFuture> consumeSampledVectors(@Nonnull
                     final byte[] key = keyValue.getKey();
                     final byte[] value = keyValue.getValue();
                     resultBuilder.add(aggregatedVectorFromRaw(prefixSubspace, key, value));
+                    // this is done to not lock the entire range we just read but just the keys we did read
                     transaction.addReadConflictKey(key);
                     transaction.clear(key);
                     onReadListener.onKeyValueRead(-1, key, value);
@@ -346,12 +419,14 @@ static void appendSampledVector(@Nonnull final Transaction transaction,
         onWriteListener.onKeyValueWritten(-1, prefixKey, value);
     }
-    static void removeAllSampledVectors(@Nonnull final Transaction transaction, @Nonnull final Subspace subspace) {
+    static void deleteAllSampledVectors(@Nonnull final Transaction transaction, @Nonnull final Subspace subspace,
+                                        @Nonnull final OnWriteListener onWriteListener) {
         final Subspace prefixSubspace = samplesSubspace(subspace);
         final byte[] prefixKey = prefixSubspace.pack();
         final Range range = Range.startsWith(prefixKey);
         transaction.clear(range);
+        onWriteListener.onRangeDeleted(-1, range);
     }

     @Nonnull
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java
index 3cac6f4826..27cdec9187 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java
@@ -20,6 +20,7 @@
 package com.apple.foundationdb.async.hnsw;

+import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings;
 import com.apple.foundationdb.linear.AffineOperator;
 import com.apple.foundationdb.linear.FhtKacRotator;
 import com.apple.foundationdb.linear.LinearOperator;
@@ -34,7 +35,10 @@
  * (pre-rotated) centroid. This operator is used inside the HNSW to transform back and forth between the coordinate
  * system of the client and the coordinate system that is currently employed in the HNSW.
*/ +@SpotBugsSuppressWarnings(value = "SING_SINGLETON_HAS_NONPRIVATE_CONSTRUCTOR", justification = "Singleton designation is a false positive") class StorageTransform extends AffineOperator { + private static final StorageTransform IDENTITY_STORAGE_TRANSFORM = new StorageTransform(null, null); + public StorageTransform(final long seed, final int numDimensions, @Nonnull final RealVector translationVector) { this(new FhtKacRotator(seed, numDimensions, 10), translationVector); @@ -67,4 +71,9 @@ public RealVector apply(@Nonnull final RealVector vector) { public RealVector invertedApply(@Nonnull final RealVector vector) { return super.invertedApply(vector); } + + @Nonnull + public static StorageTransform identity() { + return IDENTITY_STORAGE_TRANSFORM; + } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java index 141df55cfe..bf354e9459 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java @@ -32,13 +32,13 @@ void testConfig() { Assertions.assertThat(HNSW.newConfigBuilder().build(768)).isEqualTo(defaultConfig); Assertions.assertThat(defaultConfig.toBuilder().build(768)).isEqualTo(defaultConfig); - final boolean deterministicSeeding = true; final Metric metric = Metric.COSINE_METRIC; final boolean useInlining = true; final int m = Config.DEFAULT_M + 1; final int mMax = Config.DEFAULT_M_MAX + 1; final int mMax0 = Config.DEFAULT_M_MAX_0 + 1; final int efConstruction = Config.DEFAULT_EF_CONSTRUCTION + 1; + final int efRepair = Config.DEFAULT_EF_REPAIR + 1; final boolean extendCandidates = true; final boolean keepPrunedConnections = true; final int statsThreshold = 5000; @@ -50,14 +50,15 @@ void testConfig() { final int maxNumConcurrentNodeFetches = 1; final int maxNumConcurrentNeighborhoodFetches = 2; + final int maxNumConcurrentDeleteFromLayer = Config.DEFAULT_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER + 1; - Assertions.assertThat(defaultConfig.isDeterministicSeeding()).isNotEqualTo(deterministicSeeding); Assertions.assertThat(defaultConfig.getMetric()).isNotSameAs(metric); Assertions.assertThat(defaultConfig.isUseInlining()).isNotEqualTo(useInlining); Assertions.assertThat(defaultConfig.getM()).isNotEqualTo(m); Assertions.assertThat(defaultConfig.getMMax()).isNotEqualTo(mMax); Assertions.assertThat(defaultConfig.getMMax0()).isNotEqualTo(mMax0); Assertions.assertThat(defaultConfig.getEfConstruction()).isNotEqualTo(efConstruction); + Assertions.assertThat(defaultConfig.getEfRepair()).isNotEqualTo(efRepair); Assertions.assertThat(defaultConfig.isExtendCandidates()).isNotEqualTo(extendCandidates); Assertions.assertThat(defaultConfig.isKeepPrunedConnections()).isNotEqualTo(keepPrunedConnections); @@ -70,16 +71,17 @@ void testConfig() { Assertions.assertThat(defaultConfig.getMaxNumConcurrentNodeFetches()).isNotEqualTo(maxNumConcurrentNodeFetches); Assertions.assertThat(defaultConfig.getMaxNumConcurrentNeighborhoodFetches()).isNotEqualTo(maxNumConcurrentNeighborhoodFetches); + Assertions.assertThat(defaultConfig.getMaxNumConcurrentDeleteFromLayer()).isNotEqualTo(maxNumConcurrentDeleteFromLayer); final Config newConfig = defaultConfig.toBuilder() - .setDeterministicSeeding(deterministicSeeding) .setMetric(metric) .setUseInlining(useInlining) .setM(m) .setMMax(mMax) .setMMax0(mMax0) .setEfConstruction(efConstruction) + .setEfRepair(efRepair) 
.setExtendCandidates(extendCandidates) .setKeepPrunedConnections(keepPrunedConnections) .setSampleVectorStatsProbability(sampleVectorStatsProbability) @@ -89,15 +91,16 @@ void testConfig() { .setRaBitQNumExBits(raBitQNumExBits) .setMaxNumConcurrentNodeFetches(maxNumConcurrentNodeFetches) .setMaxNumConcurrentNeighborhoodFetches(maxNumConcurrentNeighborhoodFetches) + .setMaxNumConcurrentDeleteFromLayer(maxNumConcurrentDeleteFromLayer) .build(768); - Assertions.assertThat(newConfig.isDeterministicSeeding()).isEqualTo(deterministicSeeding); Assertions.assertThat(newConfig.getMetric()).isSameAs(metric); Assertions.assertThat(newConfig.isUseInlining()).isEqualTo(useInlining); Assertions.assertThat(newConfig.getM()).isEqualTo(m); Assertions.assertThat(newConfig.getMMax()).isEqualTo(mMax); Assertions.assertThat(newConfig.getMMax0()).isEqualTo(mMax0); Assertions.assertThat(newConfig.getEfConstruction()).isEqualTo(efConstruction); + Assertions.assertThat(newConfig.getEfRepair()).isEqualTo(efRepair); Assertions.assertThat(newConfig.isExtendCandidates()).isEqualTo(extendCandidates); Assertions.assertThat(newConfig.isKeepPrunedConnections()).isEqualTo(keepPrunedConnections); @@ -110,6 +113,7 @@ void testConfig() { Assertions.assertThat(newConfig.getMaxNumConcurrentNodeFetches()).isEqualTo(maxNumConcurrentNodeFetches); Assertions.assertThat(newConfig.getMaxNumConcurrentNeighborhoodFetches()).isEqualTo(maxNumConcurrentNeighborhoodFetches); + Assertions.assertThat(newConfig.getMaxNumConcurrentDeleteFromLayer()).isEqualTo(maxNumConcurrentDeleteFromLayer); } @Test diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java index 6d316103c8..69c4ade7ac 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java @@ -124,6 +124,25 @@ void testResultEntry(final long randomSeed) { assertHashCodeEqualsToString(randomSeed, DataRecordsTest::resultEntry, DataRecordsTest::resultEntry); } + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testNodeReferenceAndNode(final long randomSeed) { + assertToString(randomSeed, DataRecordsTest::nodeReferenceAndNode, DataRecordsTest::nodeReferenceAndNode); + } + + private static void assertToString(final long randomSeed, + @Nonnull final Function createFunction, + @Nonnull final BiFunction createDifferentFunction) { + final Random random = new Random(randomSeed); + final long dependentRandomSeed = random.nextLong(); + final T t1 = createFunction.apply(new Random(dependentRandomSeed)); + final T t1Clone = createFunction.apply(new Random(dependentRandomSeed)); + Assertions.assertThat(t1).hasToString(t1Clone.toString()); + + final T t2 = createDifferentFunction.apply(random, t1); + Assertions.assertThat(t1).doesNotHaveToString(t2.toString()); + } + private static void assertHashCodeEqualsToString(final long randomSeed, @Nonnull final Function createFunction, @Nonnull final BiFunction createDifferentFunction) { @@ -140,6 +159,20 @@ private static void assertHashCodeEqualsToString(final long randomSeed, Assertions.assertThat(t1).doesNotHaveToString(t2.toString()); } + @Nonnull + private static NodeReferenceAndNode + nodeReferenceAndNode(@Nonnull final Random random) { + return new NodeReferenceAndNode<>(nodeReferenceWithDistance(random), inliningNode(random)); + } + + @Nonnull + private 
static NodeReferenceAndNode + nodeReferenceAndNode(@Nonnull final Random random, + @Nonnull final NodeReferenceAndNode original) { + return new NodeReferenceAndNode<>(nodeReferenceWithDistance(random, original.getNodeReference()), + inliningNode(random, original.getNode().asInliningNode())); + } + @Nonnull private static ResultEntry resultEntry(@Nonnull final Random random) { return new ResultEntry(primaryKey(random), rawVector(random), random.nextDouble(), random.nextInt(100)); diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index 726a38902f..961374115b 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -30,6 +30,7 @@ import com.apple.foundationdb.linear.Quantizer; import com.apple.foundationdb.linear.RealVector; import com.apple.foundationdb.linear.StoredVecsIterator; +import com.apple.foundationdb.linear.Transformed; import com.apple.foundationdb.rabitq.EncodedRealVector; import com.apple.foundationdb.test.TestDatabaseExtension; import com.apple.foundationdb.test.TestExecutors; @@ -40,20 +41,26 @@ import com.apple.test.SuperSlow; import com.apple.test.Tags; import com.google.common.base.Verify; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.ObjectArrays; import com.google.common.collect.Sets; -import org.assertj.core.api.Assertions; -import org.assertj.core.util.Lists; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.AfterTestExecutionCallback; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; +import org.junit.jupiter.params.ParameterInfo; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.aggregator.ArgumentsAccessor; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; @@ -61,30 +68,39 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.io.BufferedWriter; import java.io.IOException; import java.nio.channels.FileChannel; +import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.ArrayList; +import java.util.Collection; import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.NavigableSet; import java.util.Objects; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.TreeSet; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; import static 
com.apple.foundationdb.linear.RealVectorTest.createRandomDoubleVector; import static com.apple.foundationdb.linear.RealVectorTest.createRandomHalfVector; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.within; /** @@ -104,6 +120,9 @@ class HNSWTest { @RegisterExtension TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); + @TempDir + Path tempDir; + private Database db; @BeforeEach @@ -113,12 +132,15 @@ public void setUpDb() { @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) - void testCompactSerialization(final long seed) { + void testCompactSerialization(final long seed) throws Exception { final Random random = new Random(seed); final int numDimensions = 768; final CompactStorageAdapter storageAdapter = new CompactStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), CompactNode.factory(), rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); + assertThat(storageAdapter.asCompactStorageAdapter()).isSameAs(storageAdapter); + assertThatThrownBy(storageAdapter::asInliningStorageAdapter).isInstanceOf(VerifyException.class); + final AbstractNode originalNode = db.run(tr -> { final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); @@ -133,11 +155,11 @@ void testCompactSerialization(final long seed) { db.run(tr -> storageAdapter.fetchNode(tr, AffineOperator.identity(), 0, originalNode.getPrimaryKey()) .thenAccept(node -> - Assertions.assertThat(node).satisfies( - n -> Assertions.assertThat(n).isInstanceOf(CompactNode.class), - n -> Assertions.assertThat(n.getKind()).isSameAs(NodeKind.COMPACT), - n -> Assertions.assertThat((Object)n.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), - n -> Assertions.assertThat(n.asCompactNode().getVector()) + assertThat(node).satisfies( + n -> assertThat(n).isInstanceOf(CompactNode.class), + n -> assertThat(n.getKind()).isSameAs(NodeKind.COMPACT), + n -> assertThat((Object)n.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), + n -> assertThat(n.asCompactNode().getVector()) .isEqualTo(originalNode.asCompactNode().getVector()), n -> { final ArrayList neighbors = @@ -146,20 +168,28 @@ void testCompactSerialization(final long seed) { final ArrayList originalNeighbors = Lists.newArrayList(originalNode.getNeighbors()); originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); - Assertions.assertThat(neighbors).isEqualTo(originalNeighbors); + assertThat(neighbors).isEqualTo(originalNeighbors); } )).join()); + + assertThat( + dumpLayer(HNSW.newConfigBuilder() + .build(numDimensions), "debug", 0)) + .isGreaterThan(0); } @ParameterizedTest @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) - void testInliningSerialization(final long seed) { + void testInliningSerialization(final long seed) throws Exception { final Random random = new Random(seed); final int numDimensions = 768; final InliningStorageAdapter storageAdapter = new InliningStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), InliningNode.factory(), rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); + assertThat(storageAdapter.asInliningStorageAdapter()).isSameAs(storageAdapter); + assertThatThrownBy(storageAdapter::asCompactStorageAdapter).isInstanceOf(VerifyException.class); + final Node originalNode = db.run(tr -> { final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); @@ -167,17 +197,17 
@@ void testInliningSerialization(final long seed) { final AbstractNode randomInliningNode = createRandomInliningNode(random, nodeFactory, numDimensions, 16); - writeNode(tr, storageAdapter, randomInliningNode, 0); + writeNode(tr, storageAdapter, randomInliningNode, 1); return randomInliningNode; }); - db.run(tr -> storageAdapter.fetchNode(tr, AffineOperator.identity(), 0, + db.run(tr -> storageAdapter.fetchNode(tr, AffineOperator.identity(), 1, originalNode.getPrimaryKey()) .thenAccept(node -> - Assertions.assertThat(node).satisfies( - n -> Assertions.assertThat(n).isInstanceOf(InliningNode.class), - n -> Assertions.assertThat(n.getKind()).isSameAs(NodeKind.INLINING), - n -> Assertions.assertThat((Object)node.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), + assertThat(node).satisfies( + n -> assertThat(n).isInstanceOf(InliningNode.class), + n -> assertThat(n.getKind()).isSameAs(NodeKind.INLINING), + n -> assertThat((Object)node.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), n -> { final ArrayList neighbors = Lists.newArrayList(node.getNeighbors()); @@ -185,78 +215,74 @@ void testInliningSerialization(final long seed) { final ArrayList originalNeighbors = Lists.newArrayList(originalNode.getNeighbors()); originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); - Assertions.assertThat(neighbors).isEqualTo(originalNeighbors); + assertThat(neighbors).isEqualTo(originalNeighbors); } )).join()); + + assertThat( + dumpLayer(HNSW.newConfigBuilder() + .setUseInlining(true) + .build(numDimensions), "debug", 1)) + .isGreaterThan(0); } - static Stream randomSeedsWithOptions() { + static Stream randomSeedsWithConfig() { return RandomizedTestUtils.randomSeeds(0xdeadc0deL) .flatMap(seed -> Sets.cartesianProduct(ImmutableSet.of(true, false), - ImmutableSet.of(true, false), - ImmutableSet.of(true, false), - ImmutableSet.of(true, false)).stream() - .map(arguments -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + ImmutableSet.of(false, true), + ImmutableSet.of(false, true), + ImmutableSet.of(false, true)).stream() + .map(arguments -> Arguments.of(ObjectArrays.concat(seed, + new Object[] {HNSW.newConfigBuilder() + .setMetric(Metric.EUCLIDEAN_METRIC) + .setUseInlining(arguments.get(0)) + .setEfRepair(64) + .setExtendCandidates(arguments.get(1)) + .setKeepPrunedConnections(arguments.get(2)) + .setUseRaBitQ(arguments.get(3)) + .setRaBitQNumExBits(5) + .setSampleVectorStatsProbability(1.0d) + .setMaintainStatsProbability(0.1d) + .setStatsThreshold(100) + .setM(16) + .setMMax(32) + .setMMax0(64) + .build(128)})))); } - @ParameterizedTest(name = "seed={0} useInlining={1} extendCandidates={2} keepPrunedConnections={3} useRaBitQ={4}") - @MethodSource("randomSeedsWithOptions") - void testBasicInsert(final long seed, final boolean useInlining, final boolean extendCandidates, - final boolean keepPrunedConnections, final boolean useRaBitQ) { + @ExtendWith(HNSWTest.DumpLayersIfFailure.class) + @ParameterizedTest + @MethodSource("randomSeedsWithConfig") + void testBasicInsert(final long seed, final Config config) { final Random random = new Random(seed); - final Metric metric = Metric.EUCLIDEAN_METRIC; + final Metric metric = config.getMetric(); + final int size = 1000; + final TestOnWriteListener onWriteListener = new TestOnWriteListener(); final TestOnReadListener onReadListener = new TestOnReadListener(); - final int numDimensions = 128; - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.newConfigBuilder() - 
.setDeterministicSeeding(true) - .setMetric(metric) - .setUseInlining(useInlining) - .setExtendCandidates(extendCandidates) - .setKeepPrunedConnections(keepPrunedConnections) - .setUseRaBitQ(useRaBitQ) - .setRaBitQNumExBits(5) - .setSampleVectorStatsProbability(1.0d) - .setMaintainStatsProbability(0.1d) - .setStatsThreshold(100) - .setM(32) - .setMMax(32) - .setMMax0(64) - .build(numDimensions), - OnWriteListener.NOOP, onReadListener); + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), config, + onWriteListener, onReadListener); final int k = 50; - final HalfRealVector queryVector = createRandomHalfVector(random, numDimensions); - final TreeSet recordsOrderedByDistance = - new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); + final List insertedData = randomVectors(random, config.getNumDimensions(), size); - for (int i = 0; i < 1000;) { + for (int i = 0; i < size;) { i += basicInsertBatch(hnsw, 100, i, onReadListener, - (tr, nextId) -> { - final var primaryKey = createPrimaryKey(nextId); - final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); - final double distance = metric.distance(dataVector, queryVector); - final PrimaryKeyVectorAndDistance record = - new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); - recordsOrderedByDistance.add(record); - if (recordsOrderedByDistance.size() > k) { - recordsOrderedByDistance.pollLast(); - } - return record; - }); + (tr, nextId) -> insertedData.get(Math.toIntExact(nextId))); } + final HalfRealVector queryVector = createRandomHalfVector(random, config.getNumDimensions()); + // // Attempt to mutate some records by updating them using the same primary keys but different random vectors. // This should not fail but should be silently ignored. If this succeeds, the following searches will all // return records that are not aligned with recordsOrderedByDistance. 
// - for (int i = 0; i < 100;) { + for (int i = 0; i < 100; ) { i += basicInsertBatch(hnsw, 100, 0, onReadListener, (tr, ignored) -> { final var primaryKey = createPrimaryKey(random.nextInt(1000)); - final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); + final HalfRealVector dataVector = createRandomHalfVector(random, config.getNumDimensions()); final double distance = metric.distance(dataVector, queryVector); return new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); }); @@ -270,7 +296,8 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e final long endTs = System.nanoTime(); final ImmutableSet trueNN = - recordsOrderedByDistance.stream() + orderedByDistances(Metric.EUCLIDEAN_METRIC, insertedData, queryVector).stream() + .limit(k) .map(PrimaryKeyVectorAndDistance::getPrimaryKey) .collect(ImmutableSet.toImmutableSet()); @@ -279,7 +306,7 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e logger.info("nodeId ={} at distance={}", resultEntry.getPrimaryKey().getLong(0), resultEntry.getDistance()); if (trueNN.contains(resultEntry.getPrimaryKey())) { - recallCount ++; + recallCount++; } } final double recall = (double)recallCount / (double)k; @@ -287,7 +314,7 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), String.format(Locale.ROOT, "%.2f", recall * 100.0d)); - Assertions.assertThat(recall).isGreaterThan(0.9); + assertThat(recall).isGreaterThan(0.9); final Set insertedIds = LongStream.range(0, 1000) @@ -295,14 +322,117 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e .collect(Collectors.toSet()); final Set readIds = Sets.newHashSet(); - hnsw.scanLayer(db, 0, 100, - node -> Assertions.assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); - Assertions.assertThat(readIds).isEqualTo(insertedIds); + scanLayer(config, 0, 100, + node -> + assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); + assertThat(readIds).isEqualTo(insertedIds); readIds.clear(); - hnsw.scanLayer(db, 1, 100, - node -> Assertions.assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); - Assertions.assertThat(readIds.size()).isBetween(10, 50); + scanLayer(config, 1, 100, + node -> + assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); + assertThat(readIds.size()).isBetween(10, 100); + } + + @ExtendWith(HNSWTest.DumpLayersIfFailure.class) + @ParameterizedTest + @MethodSource("randomSeedsWithConfig") + void testBasicInsertDelete(final long seed, final Config config) { + final Random random = new Random(seed); + final int size = 1000; + final TestOnWriteListener onWriteListener = new TestOnWriteListener(); + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), config, + onWriteListener, onReadListener); + + final int k = 50; + final List insertedData = randomVectors(random, config.getNumDimensions(), size); + + for (int i = 0; i < size;) { + i += basicInsertBatch(hnsw, 100, i, onReadListener, + (tr, nextId) -> insertedData.get(Math.toIntExact(nextId))); + } + + final int numVectorsPerDeleteBatch = 100; + List remainingData = insertedData; + do { + final List toBeDeleted = + pickRandomVectors(random, remainingData, numVectorsPerDeleteBatch); + + onWriteListener.reset(); + 
onReadListener.reset(); + + final long beginTs = System.nanoTime(); + db.run(tr -> { + for (final PrimaryKeyAndVector primaryKeyAndVector : toBeDeleted) { + hnsw.delete(tr, primaryKeyAndVector.getPrimaryKey()).join(); + } + return null; + }); + long endTs = System.nanoTime(); + + assertThat(onWriteListener.getDeleteCountByLayer().get(0)).isEqualTo(toBeDeleted.size()); + + logger.info("delete transaction of {} records after {} records took elapsedTime={}ms; read nodes={}, read bytes={}", + numVectorsPerDeleteBatch, + size - remainingData.size(), + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + + db.run(tr -> { + for (final PrimaryKeyAndVector primaryKeyAndVector : toBeDeleted) { + hnsw.delete(tr, primaryKeyAndVector.getPrimaryKey()).join(); + } + return null; + }); + + final Set deletedSet = toBeDeleted.stream().collect(ImmutableSet.toImmutableSet()); + remainingData = remainingData.stream() + .filter(vector -> !deletedSet.contains(vector)) + .collect(ImmutableList.toImmutableList()); + + if (!remainingData.isEmpty()) { + final HalfRealVector queryVector = createRandomHalfVector(random, config.getNumDimensions()); + final ImmutableSet trueNN = + orderedByDistances(Metric.EUCLIDEAN_METRIC, remainingData, queryVector).stream() + .limit(k) + .map(PrimaryKeyVectorAndDistance::getPrimaryKey) + .collect(ImmutableSet.toImmutableSet()); + + onReadListener.reset(); + + final long beginTsQuery = System.nanoTime(); + final List results = + db.run(tr -> + hnsw.kNearestNeighborsSearch(tr, k, 100, true, queryVector).join()); + final long endTsQuery = System.nanoTime(); + + int recallCount = 0; + for (ResultEntry resultEntry : results) { + if (trueNN.contains(resultEntry.getPrimaryKey())) { + recallCount++; + } + } + final double recall = (double)recallCount / (double)trueNN.size(); + + logger.info("search transaction after delete of {} records took elapsedTime={}ms; read nodes={}, read bytes={}, recall={}", + size - remainingData.size(), + TimeUnit.NANOSECONDS.toMillis(endTsQuery - beginTsQuery), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), + String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + + assertThat(recall).isGreaterThan(0.9); + + final long remainingNumNodes = countNodesOnLayer(config, 0); + assertThat(remainingNumNodes).isEqualTo(remainingData.size()); + } + } while (!remainingData.isEmpty()); + + final var accessInfo = + db.run(transaction -> StorageAdapter.fetchAccessInfo(hnsw.getConfig(), + transaction, hnsw.getSubspace(), OnReadListener.NOOP).join()); + assertThat(accessInfo).isNull(); } @ParameterizedTest() @@ -314,7 +444,6 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { final int numDimensions = 128; final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.newConfigBuilder() - .setDeterministicSeeding(true) .setMetric(metric) .setUseRaBitQ(true) .setRaBitQNumExBits(5) @@ -352,13 +481,13 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { } // - // If we fetch the current state back from the db some vectors are regular vectors and some vectors are + // If we fetch the current state back from the db, some vectors are regular vectors and some vectors are // RaBitQ encoded. 
Since that information is not surfaced through the API, we need to scan layer 0, get // all vectors directly from disk (encoded/not-encoded, transformed/not-transformed) in order to check // that transformations/reconstructions are applied properly. // final Map fromDBMap = Maps.newHashMap(); - hnsw.scanLayer(db, 0, 100, + scanLayer(hnsw.getConfig(), 0, 100, node -> fromDBMap.put(node.getPrimaryKey(), node.asCompactNode().getVector().getUnderlyingVector())); @@ -383,27 +512,27 @@ void testBasicInsertWithRaBitQEncodings(final long seed) { } final RealVector originalVector = dataMap.get(resultEntry.getPrimaryKey()); - Assertions.assertThat(originalVector).isNotNull(); + assertThat(originalVector).isNotNull(); final RealVector fromDBVector = fromDBMap.get(resultEntry.getPrimaryKey()); - Assertions.assertThat(fromDBVector).isNotNull(); + assertThat(fromDBVector).isNotNull(); if (!(fromDBVector instanceof EncodedRealVector)) { - Assertions.assertThat(originalVector).isEqualTo(fromDBVector); + assertThat(originalVector).isEqualTo(fromDBVector); exactVectorCount ++; final double distance = metric.distance(originalVector, Objects.requireNonNull(resultEntry.getVector())); - Assertions.assertThat(distance).isCloseTo(0.0d, within(2E-12)); + assertThat(distance).isCloseTo(0.0d, within(2E-12)); } else { encodedVectorCount ++; final double distance = metric.distance(originalVector, Objects.requireNonNull(resultEntry.getVector()).toDoubleRealVector()); - Assertions.assertThat(distance).isCloseTo(0.0d, within(20.0d)); + assertThat(distance).isCloseTo(0.0d, within(20.0d)); } } final double recall = (double)recallCount / (double)k; - Assertions.assertThat(recall).isGreaterThan(0.9); + assertThat(recall).isGreaterThan(0.9); // must have both kinds - Assertions.assertThat(exactVectorCount).isGreaterThan(0); - Assertions.assertThat(encodedVectorCount).isGreaterThan(0); + assertThat(exactVectorCount).isGreaterThan(0); + assertThat(encodedVectorCount).isGreaterThan(0); } private int basicInsertBatch(final HNSW hnsw, final int batchSize, @@ -436,9 +565,8 @@ void testSIFTInsertSmall() throws Exception { final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.newConfigBuilder() - .setDeterministicSeeding(false) .setUseRaBitQ(true) - .setRaBitQNumExBits(5) + .setRaBitQNumExBits(6) .setMetric(metric) .setM(32) .setMMax(32) @@ -475,7 +603,7 @@ void testSIFTInsertSmall() throws Exception { return new PrimaryKeyAndVector(currentPrimaryKey, currentVector); }); } - Assertions.assertThat(i).isEqualTo(10000); + assertThat(i).isEqualTo(10000); } validateSIFTSmall(hnsw, dataMap, k); @@ -515,15 +643,15 @@ private void validateSIFTSmall(@Nonnull final HNSW hnsw, @Nonnull final Map randomVectors(@Nonnull final Random random, final int numDimensions, + final int numberOfVectors) { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfVectors; i ++) { + final var primaryKey = createPrimaryKey(i); + final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); + resultBuilder.add(new PrimaryKeyAndVector(primaryKey, dataVector)); + } + return resultBuilder.build(); + } + + @Nonnull + private List pickRandomVectors(@Nonnull final Random random, + @Nonnull final Collection vectors, + final int numberOfVectors) { + Verify.verify(numberOfVectors <= vectors.size()); + final List remainingVectors = Lists.newArrayList(vectors); + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfVectors; i 
++) { + resultBuilder.add(remainingVectors.remove(random.nextInt(remainingVectors.size()))); + } + return resultBuilder.build(); + } + + @Nonnull + private NavigableSet orderedByDistances(@Nonnull final Metric metric, + @Nonnull final List vectors, + @Nonnull final HalfRealVector queryVector) { + final TreeSet vectorsOrderedByDistance = + new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); + for (final PrimaryKeyAndVector vector : vectors) { + final double distance = metric.distance(vector.getVector(), queryVector); + final PrimaryKeyVectorAndDistance record = + new PrimaryKeyVectorAndDistance(vector.getPrimaryKey(), vector.getVector(), distance); + vectorsOrderedByDistance.add(record); + } + return vectorsOrderedByDistance; + } + + private long countNodesOnLayer(@Nonnull final Config config, final int layer) { + final AtomicLong counter = new AtomicLong(); + scanLayer(config, layer, 100, node -> counter.incrementAndGet()); + return counter.get(); + } + + private void scanLayer(@Nonnull final Config config, + final int layer, + final int batchSize, + @Nonnull final Consumer> nodeConsumer) { + HNSW.scanLayer(config, rtSubspace.getSubspace(), db, layer, batchSize, nodeConsumer); + } + + private long dumpLayer(@Nonnull final Config config, + @Nonnull final String prefix, final int layer) throws IOException { + final Path verticesFile = tempDir.resolve("vertices-" + prefix + "-" + layer + ".csv"); + final Path edgesFile = tempDir.resolve("edges-" + prefix + "-" + layer + ".csv"); + + final StorageAdapter storageAdapter = + HNSW.storageAdapterForLayer(config, rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP, layer); + + final AtomicLong numReadAtomic = new AtomicLong(0L); + try (final BufferedWriter verticesWriter = Files.newBufferedWriter(verticesFile); + final BufferedWriter edgesWriter = Files.newBufferedWriter(edgesFile)) { + scanLayer(config, layer, 100, node -> { + @Nullable final Transformed vector = + storageAdapter.isCompactStorageAdapter() + ? 
node.asCompactNode().getVector() + : null; + try { + verticesWriter.write(Long.toString(node.getPrimaryKey().getLong(0))); + if (vector != null) { + verticesWriter.write(","); + final RealVector realVector = vector.getUnderlyingVector(); + for (int i = 0; i < realVector.getNumDimensions(); i++) { + if (i != 0) { + verticesWriter.write(","); + } + verticesWriter.write(String.valueOf(realVector.getComponent(i))); + } + } + verticesWriter.newLine(); + + for (final var neighbor : node.getNeighbors()) { + edgesWriter.write(node.getPrimaryKey().getLong(0) + "," + + neighbor.getPrimaryKey().getLong(0)); + edgesWriter.newLine(); + } + numReadAtomic.getAndIncrement(); + } catch (final IOException e) { + throw new RuntimeException("unable to write to file", e); + } + }); + } + return numReadAtomic.get(); + } + private void writeNode(@Nonnull final Transaction transaction, @Nonnull final StorageAdapter storageAdapter, @Nonnull final AbstractNode node, @@ -547,7 +773,7 @@ private void writeNode(@Nonnull final Transaction tran final NeighborsChangeSet insertChangeSet = new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), node.getNeighbors()); - storageAdapter.writeNode(transaction, Quantizer.noOpQuantizer(Metric.EUCLIDEAN_METRIC), node, layer, + storageAdapter.writeNode(transaction, Quantizer.noOpQuantizer(Metric.EUCLIDEAN_METRIC), layer, node, insertChangeSet); } @@ -605,6 +831,63 @@ private static Tuple createPrimaryKey(final long nextId) { return Tuple.from(nextId); } + public static class DumpLayersIfFailure implements AfterTestExecutionCallback { + @Override + public void afterTestExecution(@Nonnull final ExtensionContext context) { + final Optional failure = context.getExecutionException(); + if (failure.isEmpty()) { + return; + } + + final ParameterInfo parameterInfo = ParameterInfo.get(context); + + if (parameterInfo != null) { + final ArgumentsAccessor args = parameterInfo.getArguments(); + + final HNSWTest hnswTest = (HNSWTest)context.getRequiredTestInstance(); + final Config config = (Config)args.get(1); + logger.error("dumping contents of HNSW to {}", hnswTest.tempDir.toString()); + dumpLayers(hnswTest, config); + } else { + logger.error("test failed with no parameterized arguments (non-parameterized test or older JUnit)."); + } + } + + private void dumpLayers(@Nonnull final HNSWTest hnswTest, @Nonnull final Config config) { + int layer = 0; + while (true) { + try { + if (hnswTest.dumpLayer(config, "debug", layer++) == 0) { + break; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + } + + private static class TestOnWriteListener implements OnWriteListener { + final Map deleteCountByLayer; + + public TestOnWriteListener() { + this.deleteCountByLayer = Maps.newConcurrentMap(); + } + + public Map getDeleteCountByLayer() { + return deleteCountByLayer; + } + + public void reset() { + deleteCountByLayer.clear(); + } + + @Override + public void onNodeDeleted(final int layer, @Nonnull final Tuple primaryKey) { + deleteCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 
0 : oldValue) + 1L); + } + } + private static class TestOnReadListener implements OnReadListener { final Map nodeCountByLayer; final Map sumMByLayer; @@ -668,6 +951,20 @@ public Tuple getPrimaryKey() { public RealVector getVector() { return vector; } + + @Override + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + final PrimaryKeyAndVector that = (PrimaryKeyAndVector)o; + return Objects.equals(getPrimaryKey(), that.getPrimaryKey()) && Objects.equals(getVector(), that.getVector()); + } + + @Override + public int hashCode() { + return Objects.hash(getPrimaryKey(), getVector()); + } } private static class PrimaryKeyVectorAndDistance extends PrimaryKeyAndVector { @@ -683,5 +980,22 @@ public PrimaryKeyVectorAndDistance(@Nonnull final Tuple primaryKey, public double getDistance() { return distance; } + + @Override + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + final PrimaryKeyVectorAndDistance that = (PrimaryKeyVectorAndDistance)o; + return Double.compare(getDistance(), that.getDistance()) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), getDistance()); + } } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java index a393debff4..e00bb7a8d4 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java @@ -224,13 +224,6 @@ public class IndexOptions { */ public static final String RTREE_USE_NODE_SLOT_INDEX = "rtreeUseNodeSlotIndex"; - /** - * HNSW-only: The seeding method that is used to see the PRNG that is in turn used to probabilistically determine - * the highest layer of an insert into an HNSW structure. See {@link Config#isDeterministicSeeding()}. The default - * random seed is {@link Config#DEFAULT_DETERMINISTIC_SEEDING}. - */ - public static final String HNSW_DETERMINISTIC_SEEDING = "hnswDeterministicSeeding"; - /** * HNSW-only: The metric that is used to determine distances between vectors. The default metric is * {@link Config#DEFAULT_METRIC}. See {@link Config#getMetric()}. @@ -240,7 +233,7 @@ public class IndexOptions { /** * HNSW-only: The number of dimensions used. All vectors must have exactly this number of dimensions. This option * must be set when interacting with a vector index as it there is no default. - * See {@link Config#getNumDimensions()}. + * @see Config#getNumDimensions() */ public static final String HNSW_NUM_DIMENSIONS = "hnswNumDimensions"; @@ -249,7 +242,8 @@ public class IndexOptions { * persisted as a key/value pair per neighbor which includes the vectors of the neighbors but not for itself. If * inlining is not used, each node is persisted as exactly one key/value pair per node which stores its own vector * but specifically excludes the vectors of the neighbors. The default value is set to - * {@link Config#DEFAULT_USE_INLINING}. See {@link Config#isUseInlining()}. + * {@link Config#DEFAULT_USE_INLINING}. + * @see Config#isUseInlining() */ public static final String HNSW_USE_INLINING = "hnswUseInlining"; @@ -258,7 +252,7 @@ public class IndexOptions { * any layer. 
While by no means enforced or even enforceable, we strive to create and maintain exactly {@code m}
* neighbors for a node. Due to insert/delete operations it is possible that the actual number of neighbors a node
* references is not exactly {@code m} at any given time. The default value is set to {@link Config#DEFAULT_M}.
- * See {@link Config#getM()}.
+ * @see Config#getM()
*/
public static final String HNSW_M = "hnswM";
@@ -267,7 +261,8 @@
* stored on a layer greater than {@code 0}. A node can never have more than {@code mMax} neighbors. That means that
* neighbors of a node are pruned if the actual number of neighbors would otherwise exceed {@code mMax}. Note that
* this option must be greater than or equal to {@link #HNSW_M}. The default value is set to
- * {@link Config#DEFAULT_M_MAX}. See {@link Config#getMMax()}.
+ * {@link Config#DEFAULT_M_MAX}.
+ * @see Config#getMMax()
*/
public static final String HNSW_M_MAX = "hnswMMax";
@@ -276,7 +271,8 @@
* stored on layer {@code 0}. We will never create more than {@code mMax0} neighbors for a node that is stored on
* that layer. That means that we even prune the neighbors of a node if the actual number of neighbors would
* otherwise exceed {@code mMax0}. Note that this option must be greater than or equal to {@link #HNSW_M_MAX}.
- * The default value is set to {@link Config#DEFAULT_M_MAX_0}. See {@link Config#getMMax0()}.
+ * The default value is set to {@link Config#DEFAULT_M_MAX_0}.
+ * @see Config#getMMax0()
*/
public static final String HNSW_M_MAX_0 = "hnswMMax0";
@@ -285,15 +281,25 @@
* of a new node. If {@code HNSW_EF_CONSTRUCTION} is set to {@code 1}, the search naturally follows a greedy
* approach (monotonic descent), whereas a high number for {@code HNSW_EF_CONSTRUCTION} allows for a more nuanced
* search that can tolerate (false) local minima. The default value is set to {@link Config#DEFAULT_EF_CONSTRUCTION}.
- * See {@link Config#getEfConstruction()}.
+ * @see Config#getEfConstruction()
*/
public static final String HNSW_EF_CONSTRUCTION = "hnswEfConstruction";
+ /**
+ * HNSW-only: Maximum number of candidate nodes that are considered when an HNSW layer is locally repaired as part of
+ * a delete operation. A smaller number causes the delete operation to create a smaller set of candidate nodes
+ * which improves repair performance but decreases repair quality; a higher number results in qualitatively
+ * better repairs at the expense of slower performance.
+ * The default value is set to {@link Config#DEFAULT_EF_REPAIR}.
+ * @see Config#getEfRepair()
+ */
+ public static final String HNSW_EF_REPAIR = "hnswEfRepair";
+
/**
* HNSW-only: Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is
* to be extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be
* connected to during the insert operation. The default value is set to {@link Config#DEFAULT_EXTEND_CANDIDATES}.
- * See {@link Config#isExtendCandidates()}. 
+ * @see Config#isExtendCandidates()
*/
public static final String HNSW_EXTEND_CANDIDATES = "hnswExtendCandidates";
@@ -301,7 +307,8 @@
* HNSW-only: Indicator to signal if, during the insertion of a node, candidates that have been discarded due to not
* satisfying the select-neighbor heuristic may get added back in to pad the set of neighbors if the new node would
* otherwise have too few neighbors (see {@link Config#getM()}). The default value is set to
- * {@link Config#DEFAULT_KEEP_PRUNED_CONNECTIONS}. See {@link Config#isKeepPrunedConnections()}.
+ * {@link Config#DEFAULT_KEEP_PRUNED_CONNECTIONS}.
+ * @see Config#isKeepPrunedConnections()
*/
public static final String HNSW_KEEP_PRUNED_CONNECTIONS = "hnswKeepPrunedConnections";
@@ -310,7 +317,7 @@
* represents the probability of a vector being inserted to also be written into the samples subspace of the hnsw
* structure. The vectors in that subspace are continuously aggregated until a total {@link #HNSW_STATS_THRESHOLD}
* has been reached. The default value is set to {@link Config#DEFAULT_SAMPLE_VECTOR_STATS_PROBABILITY}.
- * {@link Config#getSampleVectorStatsProbability()}.
+ * @see Config#getSampleVectorStatsProbability()
*/
public static final String HNSW_SAMPLE_VECTOR_STATS_PROBABILITY = "hnswSampleVectorStatsProbability";
@@ -319,7 +326,8 @@
* represents the probability of the samples subspace to be further aggregated (rolled-up) when a new vector is
* inserted. The vectors in that subspace are continuously aggregated until a total
* {@link #HNSW_STATS_THRESHOLD} has been reached. The default value is set to
- * {@link Config#DEFAULT_MAINTAIN_STATS_PROBABILITY}. See {@link Config#getMaintainStatsProbability()}.
+ * {@link Config#DEFAULT_MAINTAIN_STATS_PROBABILITY}.
+ * @see Config#getMaintainStatsProbability()
*/
public static final String HNSW_MAINTAIN_STATS_PROBABILITY = "hnswMaintainStatsProbability";
@@ -328,14 +336,15 @@
* represents the threshold (being a number of vectors) that when reached causes the stats maintenance logic to
* compute the actual statistics (currently the centroid of the vectors that have been inserted so far). The result
* is then inserted into the access info subspace of the index. The default value is set to
- * {@link Config#DEFAULT_STATS_THRESHOLD}. See {@link Config#getStatsThreshold()}.
+ * {@link Config#DEFAULT_STATS_THRESHOLD}.
+ * @see Config#getStatsThreshold()
*/
public static final String HNSW_STATS_THRESHOLD = "hnswStatsThreshold";
/**
* HNSW-only: Indicator if we should use RaBitQ quantization. See {@link com.apple.foundationdb.rabitq.RaBitQuantizer}
* for more details. The default value is set to {@link Config#DEFAULT_USE_RABITQ}.
- * See {@link Config#isUseRaBitQ()}.
+ * @see Config#isUseRaBitQ()
*/
public static final String HNSW_USE_RABITQ = "hnswUseRaBitQ";
@@ -343,24 +352,32 @@
* HNSW-only: Number of bits per dimension iff {@link #HNSW_USE_RABITQ} is set to {@code "true"}, ignored
* otherwise. If RaBitQ encoding is used, a vector is stored using roughly
* {@code 25 + numDimensions * (numExBits + 1) / 8} bytes. The default value is set to
- * {@link Config#DEFAULT_RABITQ_NUM_EX_BITS}. 
+ * @see Config#getRaBitQNumExBits() */ public static final String HNSW_RABITQ_NUM_EX_BITS = "hnswRaBitQNumExBits"; /** * HNSW-only: Maximum number of concurrent node fetches during search and modification operations. The default value * is set to {@link Config#DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES}. - * See {@link Config#getMaxNumConcurrentNodeFetches()}. + * @see Config#getMaxNumConcurrentNodeFetches() */ public static final String HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES = "hnswMaxNumConcurrentNodeFetches"; /** * HNSW-only: Maximum number of concurrent neighborhood fetches during modification operations when the neighbors * are pruned. The default value is set to {@link Config#DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES}. - * See {@link Config#getMaxNumConcurrentNeighborhoodFetches()}. + * @see Config#getMaxNumConcurrentNeighborhoodFetches() */ public static final String HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES = "hnswMaxNumConcurrentNeighborhoodFetches"; + /** + * HNSW-only: Maximum number of delete operations that can run concurrently in separate layers during the deletion + * of a record. The default value is set to {@link Config#DEFAULT_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER}. + * @see Config#getMaxNumConcurrentDeleteFromLayer() + */ + public static final String HNSW_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER = "hnswMaxNumConcurrentDeleteFromLayer"; + private IndexOptions() { } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java index d2cc648f3e..83f4fff756 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java @@ -54,7 +54,7 @@ * {@link ScanComparisons} for use in a multidimensional index scan. 
*/ @API(API.Status.UNSTABLE) -public class VectorIndexScanComparisons implements IndexScanParameters { +public final class VectorIndexScanComparisons implements IndexScanParameters { @Nonnull private final ScanComparisons prefixScanComparisons; @Nonnull @@ -62,9 +62,9 @@ public class VectorIndexScanComparisons implements IndexScanParameters { @Nonnull private final VectorIndexScanOptions vectorIndexScanOptions; - public VectorIndexScanComparisons(@Nonnull final ScanComparisons prefixScanComparisons, - @Nonnull final DistanceRankValueComparison distanceRankValueComparison, - @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + private VectorIndexScanComparisons(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { this.prefixScanComparisons = prefixScanComparisons; this.distanceRankValueComparison = distanceRankValueComparison; this.vectorIndexScanOptions = vectorIndexScanOptions; @@ -261,9 +261,9 @@ public IndexScanParameters translateCorrelations(@Nonnull final TranslationMap t } @Nonnull - protected VectorIndexScanComparisons withComparisonsAndOptions(@Nonnull final ScanComparisons prefixScanComparisons, - @Nonnull final DistanceRankValueComparison distanceRankValueComparison, - @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + VectorIndexScanComparisons withComparisonsAndOptions(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { return new VectorIndexScanComparisons(prefixScanComparisons, distanceRankValueComparison, vectorIndexScanOptions); } @@ -311,14 +311,12 @@ public static VectorIndexScanComparisons fromProto(@Nonnull final PlanSerializat } @Nonnull - public static VectorIndexScanComparisons byDistance(@Nullable ScanComparisons prefixScanComparisons, + public static VectorIndexScanComparisons byDistance(@Nullable final ScanComparisons prefixScanComparisons, @Nonnull final DistanceRankValueComparison distanceRankValueComparison, - @Nonnull VectorIndexScanOptions vectorIndexScanOptions) { - if (prefixScanComparisons == null) { - prefixScanComparisons = ScanComparisons.EMPTY; - } - - return new VectorIndexScanComparisons(prefixScanComparisons, distanceRankValueComparison, + @Nonnull final VectorIndexScanOptions vectorIndexScanOptions) { + return new VectorIndexScanComparisons( + prefixScanComparisons == null ? 
ScanComparisons.EMPTY : prefixScanComparisons, + distanceRankValueComparison, vectorIndexScanOptions); } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java index a0be8a86c0..6e2c7abf8e 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java @@ -49,10 +49,6 @@ private VectorIndexHelper() { @Nonnull public static Config getConfig(@Nonnull final Index index) { final ConfigBuilder builder = HNSW.newConfigBuilder(); - final String hnswRandomSeedOption = index.getOption(IndexOptions.HNSW_DETERMINISTIC_SEEDING); - if (hnswRandomSeedOption != null) { - builder.setDeterministicSeeding(Boolean.parseBoolean(hnswRandomSeedOption)); - } final String hnswMetricOption = index.getOption(IndexOptions.HNSW_METRIC); if (hnswMetricOption != null) { builder.setMetric(Metric.valueOf(hnswMetricOption)); @@ -84,6 +80,10 @@ public static Config getConfig(@Nonnull final Index index) { if (hnswEfConstructionOption != null) { builder.setEfConstruction(Integer.parseInt(hnswEfConstructionOption)); } + final String hnswEfRepairOption = index.getOption(IndexOptions.HNSW_EF_REPAIR); + if (hnswEfRepairOption != null) { + builder.setEfRepair(Integer.parseInt(hnswEfRepairOption)); + } final String hnswExtendCandidatesOption = index.getOption(IndexOptions.HNSW_EXTEND_CANDIDATES); if (hnswExtendCandidatesOption != null) { builder.setExtendCandidates(Boolean.parseBoolean(hnswExtendCandidatesOption)); @@ -120,6 +120,10 @@ public static Config getConfig(@Nonnull final Index index) { if (hnswMaxNumConcurrentNeighborhoodFetchesOption != null) { builder.setMaxNumConcurrentNeighborhoodFetches(Integer.parseInt(hnswMaxNumConcurrentNeighborhoodFetchesOption)); } + final String hnswMaxNumConcurrentDeleteFromLayerOption = index.getOption(IndexOptions.HNSW_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER); + if (hnswMaxNumConcurrentDeleteFromLayerOption != null) { + builder.setMaxNumConcurrentDeleteFromLayer(Integer.parseInt(hnswMaxNumConcurrentDeleteFromLayerOption)); + } return builder.build(numDimensions); } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java index e5cb5bc996..d63176aab2 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java @@ -346,7 +346,7 @@ protected CompletableFuture updateIndexKeys(@Nonnull f final HNSW hnsw = new HNSW(rtSubspace, getExecutor(), getConfig(), new OnWrite(timer), OnReadListener.NOOP); if (remove) { - throw new UnsupportedOperationException("not implemented"); + return hnsw.delete(state.transaction, trimmedPrimaryKey); } else { return hnsw.insert(state.transaction, trimmedPrimaryKey, RealVector.fromBytes(vectorBytes)); diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java 
b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java index 5e164ba943..2a83fb355a 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java @@ -137,8 +137,6 @@ public void validateChangedOptions(@Nonnull final Index oldIndex, final Config newOptions = VectorIndexHelper.getConfig(index); // do not allow changing any of the following - disallowChange(changedOptions, IndexOptions.HNSW_DETERMINISTIC_SEEDING, - oldOptions, newOptions, Config::isDeterministicSeeding); disallowChange(changedOptions, IndexOptions.HNSW_METRIC, oldOptions, newOptions, Config::getMetric); disallowChange(changedOptions, IndexOptions.HNSW_NUM_DIMENSIONS, @@ -153,6 +151,8 @@ public void validateChangedOptions(@Nonnull final Index oldIndex, oldOptions, newOptions, Config::getMMax0); disallowChange(changedOptions, IndexOptions.HNSW_EF_CONSTRUCTION, oldOptions, newOptions, Config::getEfConstruction); + disallowChange(changedOptions, IndexOptions.HNSW_EF_REPAIR, + oldOptions, newOptions, Config::getEfRepair); disallowChange(changedOptions, IndexOptions.HNSW_EXTEND_CANDIDATES, oldOptions, newOptions, Config::isExtendCandidates); disallowChange(changedOptions, IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS, @@ -168,6 +168,7 @@ public void validateChangedOptions(@Nonnull final Index oldIndex, changedOptions.remove(IndexOptions.HNSW_STATS_THRESHOLD); changedOptions.remove(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES); changedOptions.remove(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES); + changedOptions.remove(IndexOptions.HNSW_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER); } super.validateChangedOptions(oldIndex, changedOptions); } diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/metadata/MetaDataProtoTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/metadata/MetaDataProtoTest.java index 34c64f23e3..b6865c91f9 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/metadata/MetaDataProtoTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/metadata/MetaDataProtoTest.java @@ -58,6 +58,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import javax.annotation.Nonnull; import java.util.Arrays; @@ -376,7 +377,8 @@ public void indexGroupingCompatibility() throws Exception { private static class ArgumentProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) throws Exception { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of(Arguments.of("double parameter", 10.10d, 12, 12d), Arguments.of("float parameter", 11.11f, 13.13f), Arguments.of("long parameter", 42L, 44L), diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchIndexScanTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchIndexScanTest.java index becfdf0e45..0481501131 100644 --- 
a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchIndexScanTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchIndexScanTest.java @@ -66,7 +66,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; -import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER; +import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_WITH_NAMES_PLACEHOLDER; /** * A test for the remote fetch index scan wrapper. diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchMultiColumnKeyTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchMultiColumnKeyTest.java index 32cb1dac6f..a4dc79ff6b 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchMultiColumnKeyTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchMultiColumnKeyTest.java @@ -45,7 +45,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER; +import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_WITH_NAMES_PLACEHOLDER; /** * Remote fetch test with a compound primary key. diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchSplitRecordsTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchSplitRecordsTest.java index 70c8e1ef9b..6be4ea3e0a 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchSplitRecordsTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchSplitRecordsTest.java @@ -33,7 +33,7 @@ import javax.annotation.Nonnull; -import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER; +import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_WITH_NAMES_PLACEHOLDER; /** * A test for the Remote Fetch with large records that are split (more than just the version split). diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchTest.java index 152e5c871e..8096cb282f 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/RemoteFetchTest.java @@ -63,7 +63,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; -import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER; +import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_WITH_NAMES_PLACEHOLDER; /** * A test for the remote fetch feature. 
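The JUnit churn repeated across the test files in this patch follows one pattern: the one-argument ArgumentsProvider#provideArguments(ExtensionContext) overload is replaced by the two-argument overload that also receives ParameterDeclarations, and the display-name placeholder constants are imported from ParameterizedInvocationConstants rather than ParameterizedTest. A minimal sketch of the migrated shape follows; the class, provider, and argument values are hypothetical, and it assumes a JUnit Jupiter version that ships ParameterDeclarations and ParameterizedInvocationConstants (roughly 5.13 and later).

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.support.ParameterDeclarations;

import java.util.stream.Stream;

import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_WITH_NAMES_PLACEHOLDER;

// Sketch only; names are hypothetical and not part of this patch.
class PlaceholderMigrationSketch {
    // Providers implement the newer two-argument overload; the ParameterDeclarations
    // argument is accepted but not used by these simple providers.
    static class SmallProvider implements ArgumentsProvider {
        @Override
        public Stream<? extends Arguments> provideArguments(final ParameterDeclarations parameterDeclarations,
                                                            final ExtensionContext context) {
            return Stream.of(Arguments.of(1, "one"), Arguments.of(2, "two"));
        }
    }

    // The placeholder constant now comes from ParameterizedInvocationConstants instead of ParameterizedTest.
    @ParameterizedTest(name = ARGUMENTS_WITH_NAMES_PLACEHOLDER)
    @ArgumentsSource(SmallProvider.class)
    void exampleTest(final int number, final String word) {
        // assertions elided
    }
}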
diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java index 65132ad37c..9a8dad467c 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTest.java @@ -20,7 +20,6 @@ package com.apple.foundationdb.record.provider.foundationdb.indexes; -import com.apple.foundationdb.async.hnsw.NodeReference; import com.apple.foundationdb.linear.HalfRealVector; import com.apple.foundationdb.linear.Metric; import com.apple.foundationdb.record.Bindings; @@ -59,9 +58,12 @@ import com.apple.foundationdb.record.query.plan.plans.RecordQueryFetchFromPartialRecordPlan; import com.apple.foundationdb.record.query.plan.plans.RecordQueryIndexPlan; import com.apple.foundationdb.record.vector.TestRecordsVectorsProto.VectorRecord; +import com.apple.foundationdb.tuple.Tuple; import com.apple.test.RandomizedTestUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; import com.google.common.collect.ObjectArrays; import com.google.common.collect.Sets; import com.google.common.primitives.Ints; @@ -81,7 +83,7 @@ import java.util.Optional; import java.util.Random; import java.util.Set; -import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.apple.foundationdb.record.metadata.Key.Expressions.concat; @@ -118,7 +120,7 @@ static Stream randomSeedsWithAsyncAndLimit() { void basicWriteReadTest(final long seed, final boolean useAsync) throws Exception { final Random random = new Random(seed); final List> savedRecords = - saveRecords(useAsync, this::addVectorIndexes, random, 1000, 0.3); + saveRandomRecords(useAsync, this::addVectorIndexes, random, 1000, 0.3); try (final FDBRecordContext context = openContext()) { openRecordStore(context, this::addVectorIndexes); for (int l = 0; l < 1000; l ++) { @@ -140,7 +142,7 @@ void basicWriteIndexReadWithContinuationTest(final long seed, final boolean useA final HalfRealVector queryVector = randomHalfVector(random, 128); final List> savedRecords = - saveRecords(useAsync, this::addUngroupedVectorIndex, random, 1000); + saveRandomRecords(useAsync, this::addUngroupedVectorIndex, random, 1000); final Set expectedResults = sortByDistances(savedRecords, queryVector, Metric.EUCLIDEAN_METRIC).stream() @@ -149,9 +151,15 @@ void basicWriteIndexReadWithContinuationTest(final long seed, final boolean useA nodeReferenceWithDistance.getPrimaryKey().getLong(0)) .collect(ImmutableSet.toImmutableSet()); - final var indexPlan = + final RecordQueryIndexPlan indexPlan = createIndexPlan(queryVector, k, "UngroupedVectorIndex"); + checkResults(indexPlan, limit, expectedResults); + } + + private void checkResults(@Nonnull final RecordQueryIndexPlan indexPlan, + final int limit, + @Nonnull final Set expectedResults) throws Exception { verifyRebase(indexPlan); verifySerialization(indexPlan); @@ -188,26 +196,35 @@ void basicWriteIndexReadWithContinuationTest(final long seed, final boolean useA } } } while (continuation != null); - assertThat(allCounter).isEqualTo(k); - assertThat((double)recallCounter / k).isGreaterThan(0.9); + 
assertThat(allCounter).isEqualTo(expectedResults.size()); + assertThat((double)recallCounter / expectedResults.size()).isGreaterThan(0.9); } } @ParameterizedTest @MethodSource("randomSeedsWithAsyncAndLimit") void basicWriteIndexReadGroupedWithContinuationTest(final long seed, final boolean useAsync, final int limit) throws Exception { + final int size = 1000; final int k = 100; final Random random = new Random(seed); final HalfRealVector queryVector = randomHalfVector(random, 128); - final Map> expectedResults = - saveRandomRecords(random, this::addGroupedVectorIndex, useAsync, 1000, queryVector); - final var indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + final List> savedRecords = + saveRandomRecords(useAsync, this::addGroupedVectorIndex, random, size); + final Map> randomRecords = groupAndSortByDistances(savedRecords, queryVector); + final Map> expectedResults = trueTopK(randomRecords, k); + final RecordQueryIndexPlan indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + + checkResultsGrouped(indexPlan, limit, expectedResults); + } + + private void checkResultsGrouped(@Nonnull final RecordQueryIndexPlan indexPlan, final int limit, + @Nonnull final Map> expectedResults) throws Exception { verifyRebase(indexPlan); verifySerialization(indexPlan); - try (FDBRecordContext context = openContext()) { + try (final FDBRecordContext context = openContext()) { openRecordStore(context, this::addGroupedVectorIndex); final int[] allCounters = new int[2]; @@ -240,13 +257,111 @@ void basicWriteIndexReadGroupedWithContinuationTest(final long seed, final boole } } } while (continuation != null); - assertThat(Ints.asList(allCounters)) - .allSatisfy(allCounter -> - assertThat(allCounter).isEqualTo(k)); - assertThat(Ints.asList(recallCounters)) - .allSatisfy(recallCounter -> - assertThat((double)recallCounter / k).isGreaterThan(0.9)); + + IntStream.range(0, allCounters.length) + .forEach(index -> { + assertThat(allCounters[index]) + .as("allCounters[%d]", index) + .satisfies(allCountersAtIndex -> { + assertThat(allCountersAtIndex).isEqualTo( + expectedResults.getOrDefault(index, ImmutableSet.of()).size()); + }); + assertThat(recallCounters[index]) + .as("recallCounters[%d]", index) + .satisfies(recallCountersAtIndex -> { + assertThat((double)recallCountersAtIndex / + expectedResults.getOrDefault(index, ImmutableSet.of()).size()) + .isGreaterThan(0.9); + }); + + }); + } + } + + @ParameterizedTest + @MethodSource("randomSeedsWithAsyncAndLimit") + void insertReadDeleteReadGroupedWithContinuationTest(final long seed, final boolean useAsync, final int limit) throws Exception { + final int size = 1000; + Assertions.assertThat(size % 2).isEqualTo(0); // needs to be even + final int updateBatchSize = 50; + Assertions.assertThat(size % updateBatchSize).isEqualTo(0); // needs to be divisible + + final int k = 100; + final Random random = new Random(seed); + final var savedRecords = saveRandomRecords(useAsync, this::addGroupedVectorIndex, random, size); + + final HalfRealVector queryVector = randomHalfVector(random, 128); + + // + // Artificially create a lot of churn. Take the first record and flip its vector with the 999th vector, + // take the second record and flip it with the 998th and so on. We still know the expected ground truth and + // can compensate for that. 
+ // + for (int i = 0; i < size / 2;) { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addGroupedVectorIndex); + for (int b = 0; b < updateBatchSize; b ++) { + final int nearerGroupId = i % 2; + final FDBStoredRecord nearer = + Objects.requireNonNull(recordStore.loadRecord(Tuple.from(nearerGroupId, i))); + final VectorRecord nearerRecord = + VectorRecord.newBuilder() + .mergeFrom(nearer.getRecord()) + .build(); + final int furtherRecId = size - i - 1; + final int furtherGroupId = furtherRecId % 2; + final FDBStoredRecord further = + Objects.requireNonNull(recordStore.loadRecord(Tuple.from(furtherGroupId, furtherRecId))); + final VectorRecord furtherRecord = VectorRecord.newBuilder() + .mergeFrom(further.getRecord()) + .build(); + + final Message newNearer = VectorRecord.newBuilder() + .setRecNo(nearerRecord.getRecNo()) + .setGroupId(nearerRecord.getGroupId()) + .setVectorData(furtherRecord.getVectorData()) + .build(); + final Message newFurther = VectorRecord.newBuilder() + .setRecNo(furtherRecord.getRecNo()) + .setGroupId(furtherRecord.getGroupId()) + .setVectorData(nearerRecord.getVectorData()) + .build(); + + recordStore.updateRecord(newNearer); + recordStore.updateRecord(newFurther); + i ++; + } + commit(context); + } } + + final List> flippedRecords = + savedRecords + .stream() + .map(storedRecord -> { + final VectorRecord vectorRecord = + VectorRecord.newBuilder() + .mergeFrom(storedRecord.getRecord()) + .build(); + final VectorRecord newVectorRecord = + VectorRecord.newBuilder() + .setGroupId((int)(size - vectorRecord.getRecNo() - 1) % 2) + .setRecNo(size - vectorRecord.getRecNo() - 1) + .setVectorData(vectorRecord.getVectorData()) + .build(); + return FDBStoredRecord.newBuilder() + .setRecord(newVectorRecord) + .setPrimaryKey(storedRecord.getPrimaryKey()) + .setRecordType(storedRecord.getRecordType()) + .build(); + }) + .collect(ImmutableList.toImmutableList()); + final Map> groupedFlippedRecords = groupAndSortByDistances(flippedRecords, queryVector); + final Map> expectedResults = trueTopK(groupedFlippedRecords, k); + + final RecordQueryIndexPlan indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + + checkResultsGrouped(indexPlan, limit, expectedResults); } @ParameterizedTest @@ -256,35 +371,21 @@ void deleteWhereGroupedTest(final long seed, final boolean useAsync) throws Exce final Random random = new Random(seed); final HalfRealVector queryVector = randomHalfVector(random, 128); - final Map> expectedResults = saveRandomRecords(random, this::addGroupedVectorIndex, - useAsync, 200, queryVector); - final var indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + final List> savedRecords = + saveRandomRecords(useAsync, this::addGroupedVectorIndex, random, 200); + final Map> randomRecords = groupAndSortByDistances(savedRecords, queryVector); + final Map> expectedResults = + Maps.filterKeys( + trueTopK(randomRecords, 200), key -> Objects.requireNonNull(key) % 2 != 0); try (FDBRecordContext context = openContext()) { openRecordStore(context, this::addGroupedVectorIndex); recordStore.deleteRecordsWhere(Query.field("group_id").equalsValue(0)); - - final int[] allCounters = new int[2]; - final int[] recallCounters = new int[2]; - try (final RecordCursorIterator> cursor = executeQuery(indexPlan)) { - while (cursor.hasNext()) { - final FDBQueriedRecord rec = cursor.next(); - final VectorRecord record = - VectorRecord.newBuilder() - .mergeFrom(Objects.requireNonNull(rec).getRecord()) - .build(); - 
allCounters[record.getGroupId()] ++; - if (expectedResults.get(record.getGroupId()).contains(record.getRecNo())) { - recallCounters[record.getGroupId()] ++; - } - } - } - assertThat(allCounters[0]).isEqualTo(0); - assertThat(allCounters[1]).isEqualTo(k); - - assertThat((double)recallCounters[0] / k).isEqualTo(0.0); - assertThat((double)recallCounters[1] / k).isGreaterThan(0.9); + commit(context); } + + final RecordQueryIndexPlan indexPlan = createIndexPlan(queryVector, k, "GroupedVectorIndex"); + checkResultsGrouped(indexPlan, Integer.MAX_VALUE, expectedResults); } @Test @@ -304,7 +405,6 @@ void directIndexValidatorTest() throws Exception { validateIndexEvolution(metaDataValidator, index, ImmutableMap.builder() // cannot change those per se but must accept same value - .put(IndexOptions.HNSW_DETERMINISTIC_SEEDING, "false") .put(IndexOptions.HNSW_METRIC, Metric.EUCLIDEAN_METRIC.name()) .put(IndexOptions.HNSW_NUM_DIMENSIONS, "128") .put(IndexOptions.HNSW_USE_INLINING, "false") @@ -312,6 +412,7 @@ void directIndexValidatorTest() throws Exception { .put(IndexOptions.HNSW_M_MAX, "16") .put(IndexOptions.HNSW_M_MAX_0, "32") .put(IndexOptions.HNSW_EF_CONSTRUCTION, "200") + .put(IndexOptions.HNSW_EF_REPAIR, "64") .put(IndexOptions.HNSW_EXTEND_CANDIDATES, "false") .put(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS, "false") .put(IndexOptions.HNSW_USE_RABITQ, "false") @@ -322,11 +423,8 @@ void directIndexValidatorTest() throws Exception { .put(IndexOptions.HNSW_MAINTAIN_STATS_PROBABILITY, "0.78") .put(IndexOptions.HNSW_STATS_THRESHOLD, "500") .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NODE_FETCHES, "17") - .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES, "9").build()); - - Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, - ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", - IndexOptions.HNSW_DETERMINISTIC_SEEDING, "true"))).isInstanceOf(MetaDataException.class); + .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_NEIGHBORHOOD_FETCHES, "9") + .put(IndexOptions.HNSW_MAX_NUM_CONCURRENT_DELETE_FROM_LAYER, "5").build()); Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", @@ -357,6 +455,10 @@ void directIndexValidatorTest() throws Exception { ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", IndexOptions.HNSW_EF_CONSTRUCTION, "500"))).isInstanceOf(MetaDataException.class); + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, + ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", + IndexOptions.HNSW_EF_REPAIR, "500"))).isInstanceOf(MetaDataException.class); + Assertions.assertThatThrownBy(() -> validateIndexEvolution(metaDataValidator, index, ImmutableMap.of(IndexOptions.HNSW_NUM_DIMENSIONS, "128", IndexOptions.HNSW_EXTEND_CANDIDATES, "true"))).isInstanceOf(MetaDataException.class); @@ -413,12 +515,15 @@ void directIndexMaintainerTest() throws Exception { @ParameterizedTest @MethodSource("randomSeedsWithReturnVectors") void directIndexReadGroupedWithContinuationTest(final long seed, final boolean returnVectors) throws Exception { + final int size = 1000; final int k = 100; final Random random = new Random(seed); final HalfRealVector queryVector = randomHalfVector(random, 128); - final Map> expectedResults = - saveRandomRecords(random, this::addGroupedVectorIndex, true, 1000, queryVector); + final List> savedRecords = + saveRandomRecords(true, this::addGroupedVectorIndex, random, size); + final Map> randomRecords = 
groupAndSortByDistances(savedRecords, queryVector); + final Map> expectedResults = trueTopK(randomRecords, k); try (FDBRecordContext context = openContext()) { openRecordStore(context, this::addGroupedVectorIndex); @@ -440,7 +545,6 @@ void directIndexReadGroupedWithContinuationTest(final long seed, final boolean r .setState(ExecuteState.NO_LIMITS) .setReturnedRowLimit(Integer.MAX_VALUE).build().asScanProperties(false); - try (final RecordCursor cursor = indexMaintainer.scan(vectorIndexScanComparisons.bind(recordStore, index, EvaluationContext.empty()), null, scanProperties)) { @@ -458,7 +562,7 @@ void directIndexReadGroupedWithContinuationTest(final long seed, final boolean r assertThat(indexEntry.getValue().get(0) != null).isEqualTo(returnVectors); } if (logger.isInfoEnabled()) { - logger.info("grouped read {} records, allCounters={}, recallCounters={}", numRecords, allCounters, + logger.info("(direct) grouped read {} records, allCounters={}, recallCounters={}", numRecords, allCounters, recallCounters); } } @@ -475,10 +579,10 @@ void directIndexReadGroupedWithContinuationTest(final long seed, final boolean r @Nonnull private static RecordQueryIndexPlan createIndexPlan(@Nonnull final HalfRealVector queryVector, final int k, @Nonnull final String indexName) { - final var vectorIndexScanComparisons = + final VectorIndexScanComparisons vectorIndexScanComparisons = createVectorIndexScanComparisons(queryVector, k, VectorIndexScanOptions.empty()); - final var baseRecordType = + final Type.Record baseRecordType = Type.Record.fromFieldDescriptorsMap( Type.Record.toFieldDescriptorMap(VectorRecord.getDescriptor().getFields())); @@ -499,18 +603,4 @@ private static VectorIndexScanComparisons createVectorIndexScanComparisons(@Nonn return VectorIndexScanComparisons.byDistance(ScanComparisons.EMPTY, distanceRankComparison, vectorIndexScanOptions); } - - @Nonnull - private Map> saveRandomRecords(@Nonnull final Random random, @Nonnull final RecordMetaDataHook hook, - final boolean useAsync, final int numSamples, - @Nonnull final HalfRealVector queryVector) throws Exception { - final List> savedRecords = - saveRecords(useAsync, hook, random, numSamples); - - return sortByDistances(savedRecords, queryVector, Metric.EUCLIDEAN_METRIC) - .stream() - .map(NodeReference::getPrimaryKey) - .map(primaryKey -> primaryKey.getLong(0)) - .collect(Collectors.groupingBy(nodeId -> Math.toIntExact(nodeId) % 2, Collectors.toSet())); - } } diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java index 525d64f593..3fe7ee9535 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java @@ -21,6 +21,7 @@ package com.apple.foundationdb.record.provider.foundationdb.indexes; import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.hnsw.NodeReference; import com.apple.foundationdb.async.hnsw.NodeReferenceWithDistance; import com.apple.foundationdb.half.Half; import com.apple.foundationdb.linear.AffineOperator; @@ -42,6 +43,7 @@ import com.apple.test.Tags; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import 
com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.protobuf.ByteString; import com.google.protobuf.Message; @@ -54,9 +56,12 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import java.util.stream.Collectors; import static com.apple.foundationdb.record.metadata.Key.Expressions.concat; import static com.apple.foundationdb.record.metadata.Key.Expressions.concatenateFields; @@ -133,24 +138,24 @@ protected static HalfRealVector randomHalfVector(final Random random, final int return new HalfRealVector(componentData); } - protected List> saveRecords(final boolean useAsync, - @Nonnull final RecordMetaDataHook hook, - @Nonnull final Random random, - final int numSamples) throws Exception { - return saveRecords(useAsync, hook, random, numSamples, 0.0d); + protected List> saveRandomRecords(final boolean useAsync, + @Nonnull final RecordMetaDataHook hook, + @Nonnull final Random random, + final int numRecords) throws Exception { + return saveRandomRecords(useAsync, hook, random, numRecords, 0.0d); } - protected List> saveRecords(final boolean useAsync, - @Nonnull final RecordMetaDataHook hook, - @Nonnull final Random random, - final int numSamples, - final double nullProbability) throws Exception { + protected List> saveRandomRecords(final boolean useAsync, + @Nonnull final RecordMetaDataHook hook, + @Nonnull final Random random, + final int numRecords, + final double nullProbability) throws Exception { final var recordGenerator = getRecordGenerator(random, nullProbability); if (useAsync) { - return asyncBatch(hook, numSamples, 100, + return asyncBatch(hook, numRecords, 100, recNo -> recordStore.saveRecordAsync(recordGenerator.apply(recNo))); } else { - return batch(hook, numSamples, 100, + return batch(hook, numRecords, 100, recNo -> recordStore.saveRecord(recordGenerator.apply(recNo))); } } @@ -200,6 +205,30 @@ private List> batch(final RecordMetaDataH return records; } + @Nonnull + protected static Map> trueTopK(@Nonnull final Map> sortedByDistances, + final int k) { + return sortedByDistances.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, + entry -> + entry.getValue() + .stream() + .limit(k) + .collect(ImmutableSet.toImmutableSet()))); + } + + @Nonnull + protected static Map> groupAndSortByDistances(@Nonnull final List> savedRecords, + @Nonnull final HalfRealVector queryVector) { + return sortByDistances(savedRecords, queryVector, Metric.EUCLIDEAN_METRIC) + .stream() + .map(NodeReference::getPrimaryKey) + .map(primaryKey -> primaryKey.getLong(0)) + .collect(Collectors.groupingBy(nodeId -> Math.toIntExact(nodeId) % 2, Collectors.toList())); + } + + @Nonnull protected static List sortByDistances(@Nonnull final List> storedRecords, @Nonnull final RealVector queryVector, diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VersionIndexTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VersionIndexTest.java index 4d53b14828..23839bee9d 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VersionIndexTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VersionIndexTest.java @@ -146,7 +146,7 @@ import static 
org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; -import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_PLACEHOLDER; +import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_PLACEHOLDER; /** * Tests for {@code VERSION} type indexes. diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/recordrepair/ValidationTestUtils.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/recordrepair/ValidationTestUtils.java index 6412a4f1a2..9ee88b03da 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/recordrepair/ValidationTestUtils.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/recordrepair/ValidationTestUtils.java @@ -42,7 +42,7 @@ public class ValidationTestUtils { private static final int LONG_RECORD_SPACING = 17; - // A few constants for records that were saved with saveRecords() below + // A few constants for records that were saved with saveRandomRecords() below public static final int RECORD_INDEX_WITH_NO_SPLITS = 1; public static final int RECORD_ID_WITH_NO_SPLITS = RECORD_INDEX_WITH_NO_SPLITS + 1; public static final int RECORD_INDEX_WITH_TWO_SPLITS = 16; diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/ArithmeticValueTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/ArithmeticValueTest.java index c06062d56b..e191f93a1d 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/ArithmeticValueTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/ArithmeticValueTest.java @@ -41,6 +41,7 @@ import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import java.util.List; import java.util.Optional; @@ -76,7 +77,8 @@ class ArithmeticValueTest { static class BinaryPredicateTestProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of(List.of(INT_1, INT_1), new ArithmeticValue.AddFn(), 2, false), Arguments.of(List.of(INT_1, INT_1), new ArithmeticValue.SubFn(), 0, false), diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/BooleanValueTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/BooleanValueTest.java index 39599c9c43..43d419bb3e 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/BooleanValueTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/BooleanValueTest.java @@ -56,6 +56,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import javax.annotation.Nonnull; import java.nio.charset.StandardCharsets; @@ -133,7 +134,8 @@ class 
BooleanValueTest { static class BinaryPredicateTestProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of(List.of(BOOL_TRUE, BOOL_TRUE), new RelOpValue.EqualsFn(), ConstantPredicate.TRUE), Arguments.of(List.of(BOOL_FALSE, BOOL_TRUE), new RelOpValue.EqualsFn(), ConstantPredicate.FALSE), @@ -885,7 +887,8 @@ public Stream provideArguments(final ExtensionContext conte static class LazyBinaryPredicateTestProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( /* lazy evaluation tests */ Arguments.of(List.of(new RelOpValue.NotEqualsFn().encapsulate(List.of(INT_1, INT_1)), diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/LikeOperatorValueTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/LikeOperatorValueTest.java index 3d87a6e26d..ddc952fc91 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/LikeOperatorValueTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/LikeOperatorValueTest.java @@ -44,6 +44,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import javax.annotation.Nonnull; import java.util.Arrays; @@ -69,7 +70,8 @@ class LikeOperatorValueTest { static class InvalidInputArgumentsProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of(INT_1, INT_1, STRING_NULL), Arguments.of(LONG_1, LONG_1, STRING_NULL), @@ -94,7 +96,8 @@ public Stream provideArguments(final ExtensionContext conte static class ValidInputArgumentsProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of(null, null, null, null), Arguments.of("a", null, null, null), diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/TypeTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/TypeTest.java index 1d3ed6b1e2..db5ace10d4 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/TypeTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/TypeTest.java @@ -55,6 +55,7 @@ import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -88,7 +89,8 @@ static class ProtobufRandomMessageProvider implements ArgumentsProvider { 
private static final Random random = new Random(seed); @Override - public Stream provideArguments(final ExtensionContext context) throws Exception { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of( "TestRecords4WrapperProto.RestaurantRecord", TestRecords4WrapperProto.RestaurantRecord.newBuilder() @@ -226,7 +228,8 @@ void recordTypeIsParsable(final String paramTestTitleIgnored, final Message mess static class TypesProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) throws Exception { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) throws Exception { final var listOfNulls = new LinkedList(); listOfNulls.add(null); final var listOfNullsAndNonNulls = new LinkedList(); diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/VariadicFunctionValueTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/VariadicFunctionValueTest.java index ad0142439b..2f241d4f3b 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/VariadicFunctionValueTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/cascades/VariadicFunctionValueTest.java @@ -42,6 +42,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import java.util.List; import java.util.Optional; @@ -138,7 +139,8 @@ private static DynamicMessage getMessageForRecordNamed() { static class BinaryPredicateTestProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( // Greatest Function Arguments.of(List.of(INT_1, INT_1), new VariadicFunctionValue.GreatestFn(), 1, false), diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/plans/ExplodePlanTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/plans/ExplodePlanTest.java index 25870edadb..b4b4676fee 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/plans/ExplodePlanTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/query/plan/plans/ExplodePlanTest.java @@ -32,6 +32,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.ArgumentsProvider; import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.support.ParameterDeclarations; import javax.annotation.Nonnull; import java.util.List; @@ -118,7 +119,8 @@ private static void verifyCursor(@Nonnull final RecordCursor actual private static class ArgumentProvider implements ArgumentsProvider { @Override - public Stream provideArguments(final ExtensionContext context) { + public Stream provideArguments(final ParameterDeclarations parameterDeclarations, + final ExtensionContext context) { return Stream.of( Arguments.of(ExplodeCursorBuilder.instance().withLimit(1), ImmutableList.of(1), true), Arguments.of(ExplodeCursorBuilder.instance().withLimit(4), ImmutableList.of(1, 2, 3, 4), 
true),
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/SqlFunctionTest.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/SqlFunctionTest.java
index b4e14927f5..15fc4e46ca 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/SqlFunctionTest.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/SqlFunctionTest.java
@@ -51,7 +51,7 @@
 import static com.apple.foundationdb.relational.matchers.SchemaTemplateMatchers.routine;
 import static com.apple.foundationdb.relational.utils.RelationalAssertions.assertThrows;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.jupiter.params.ParameterizedTest.ARGUMENTS_PLACEHOLDER;
+import static org.junit.jupiter.params.ParameterizedInvocationConstants.ARGUMENTS_PLACEHOLDER;

 /**
  * Contains a number of tests for creating SQL functions.
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestDescriptor.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestDescriptor.java
index 027c6c1832..023f3b264e 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestDescriptor.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestDescriptor.java
@@ -85,11 +85,6 @@ class AutoTestDescriptor extends ClassTestDescriptor {
         this.configInvoker = configInvoker;
     }

-    @Override
-    public Type getType() {
-        return Type.CONTAINER;
-    }
-
     @Override
     public boolean mayRegisterTests() {
         return true;
@@ -114,7 +109,7 @@ private void invokeBeforeEachCallbacks(JupiterEngineExecutionContext context) t
         MutableExtensionRegistry registry = context.getExtensionRegistry();
         ExtensionContext extensionContext = context.getExtensionContext();
         ThrowableCollector throwableCollector = context.getThrowableCollector();
-        final TestInstances testInstances = context.getTestInstancesProvider().getTestInstances(registry, throwableCollector);
+        final TestInstances testInstances = context.getTestInstancesProvider().getTestInstances(registry, context);
         JunitUtils.setTestInstances(extensionContext, testInstances);

         for (BeforeEachCallback callback : registry.getExtensions(BeforeEachCallback.class)) {
@@ -145,7 +140,7 @@ public void cleanUp(JupiterEngineExecutionContext context) throws Exception {
     private void invokeWorkloadTests(JupiterEngineExecutionContext context, DynamicTestExecutor dynamicTestExecutor,
                                      JupiterConfiguration configuration) throws InterruptedException {
         Object instance = context.getTestInstancesProvider()
-                .getTestInstances(context.getExtensionRegistry(), context.getThrowableCollector())
+                .getTestInstances(context.getExtensionRegistry(), context)
                 .findInstance(getTestClass())
                 .orElseThrow();

@@ -173,7 +168,7 @@ private void invokeWorkloadTests(JupiterEngineExecutionContext context, DynamicT
             Collection queries = queryInvoker.getQueries(instance, workload.getSchema(), context, executableInvoker);
             queries.forEach(querySet -> {
                 UniqueId uid = parentId.append(DYNAMIC_CONTAINER_SEGMENT_TYPE, workload.getDisplayName() + "-" + querySet.getLabel());
-                TestDescriptor descriptor = new WorkloadTestDescriptor(uid, getTestClass(), configuration, workload, querySet);
+                TestDescriptor descriptor = new WorkloadTestDescriptor(uid, getTestClass(), this, configuration, workload, querySet);
                 dynamicTestExecutor.execute(descriptor);
             });
         }));
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestEngine.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestEngine.java
index 322f3872ed..fae1333e80 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestEngine.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/AutoTestEngine.java
@@ -21,12 +21,12 @@
 package com.apple.foundationdb.relational.autotest.engine;

 import com.apple.foundationdb.relational.autotest.AutomatedTest;
-
 import org.junit.jupiter.engine.config.CachingJupiterConfiguration;
 import org.junit.jupiter.engine.config.DefaultJupiterConfiguration;
 import org.junit.jupiter.engine.config.JupiterConfiguration;
 import org.junit.jupiter.engine.descriptor.JupiterEngineDescriptor;
 import org.junit.jupiter.engine.execution.JupiterEngineExecutionContext;
+import org.junit.jupiter.engine.execution.LauncherStoreFacade;
 import org.junit.platform.commons.support.AnnotationSupport;
 import org.junit.platform.commons.support.ReflectionSupport;
 import org.junit.platform.engine.EngineDiscoveryRequest;
@@ -54,7 +54,8 @@ public String getId() {
     @Override
     public TestDescriptor discover(EngineDiscoveryRequest discoveryRequest, UniqueId uniqueId) {
         JupiterConfiguration config = new CachingJupiterConfiguration(
-                new DefaultJupiterConfiguration(discoveryRequest.getConfigurationParameters()));
+                new DefaultJupiterConfiguration(discoveryRequest.getConfigurationParameters(),
+                        discoveryRequest.getOutputDirectoryCreator()));
         TestDescriptor rootDescriptor = new JupiterEngineDescriptor(uniqueId, config);

         discoveryRequest.getSelectorsByType(ClasspathRootSelector.class).forEach(selector ->
@@ -72,7 +73,8 @@ public TestDescriptor discover(EngineDiscoveryRequest discoveryRequest, UniqueId
     protected JupiterEngineExecutionContext createExecutionContext(ExecutionRequest request) {
         JupiterEngineDescriptor engineDescriptor = (JupiterEngineDescriptor) request.getRootTestDescriptor();
         JupiterConfiguration config = engineDescriptor.getConfiguration();
-        return new JupiterEngineExecutionContext(request.getEngineExecutionListener(), config);
+        return new JupiterEngineExecutionContext(request.getEngineExecutionListener(), config,
+                new LauncherStoreFacade(request.getStore()));
     }

     private void appendTestsInClass(Class javaClass, TestDescriptor engineDesc, JupiterConfiguration config) {
@@ -109,5 +111,4 @@ public void execute(ExecutionRequest request) {
         new AutomatedTestExecutor().execute(request,root);
     }
     */
-
 }
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/WorkloadTestDescriptor.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/WorkloadTestDescriptor.java
index ce38cce0a5..8a32958b55 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/WorkloadTestDescriptor.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/autotest/engine/WorkloadTestDescriptor.java
@@ -20,12 +20,12 @@

 package com.apple.foundationdb.relational.autotest.engine;

-import com.apple.foundationdb.relational.api.Row;
-import com.apple.foundationdb.relational.api.StructMetaData;
 import com.apple.foundationdb.relational.api.RelationalConnection;
 import com.apple.foundationdb.relational.api.RelationalResultSet;
 import com.apple.foundationdb.relational.api.RelationalStatement;
 import com.apple.foundationdb.relational.api.RelationalStruct;
+import com.apple.foundationdb.relational.api.Row;
+import com.apple.foundationdb.relational.api.StructMetaData;
 import com.apple.foundationdb.relational.api.exceptions.ErrorCode;
 import com.apple.foundationdb.relational.api.exceptions.RelationalException;
 import com.apple.foundationdb.relational.autotest.Connector;
@@ -40,7 +40,6 @@
 import com.apple.foundationdb.relational.recordlayer.util.ExceptionUtil;
 import com.apple.foundationdb.relational.utils.ReservoirSample;
 import com.apple.foundationdb.relational.utils.ResultSetAssert;
-
 import org.junit.jupiter.api.DynamicTest;
 import org.junit.jupiter.api.extension.ExtensionContext;
 import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
@@ -82,19 +81,15 @@ class WorkloadTestDescriptor extends NestedClassTestDescriptor {

     public WorkloadTestDescriptor(UniqueId uniqueId,
                                   Class testClass,
+                                  TestDescriptor parent,
                                   JupiterConfiguration configuration,
                                   AutoWorkload workload,
                                   QuerySet queries) {
-        super(uniqueId, testClass, configuration);
+        super(uniqueId, testClass, () -> NestedClassTestDescriptor.getEnclosingTestClasses(parent), configuration);
         this.workload = workload;
         this.querySet = queries;
     }

-    @Override
-    public Type getType() {
-        return Type.CONTAINER;
-    }
-
     @Override
     public boolean mayRegisterTests() {
         return true;
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/PlanGenerationStackTest.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/PlanGenerationStackTest.java
index a1de830062..341043d570 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/PlanGenerationStackTest.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/PlanGenerationStackTest.java
@@ -37,6 +37,7 @@
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.ArgumentsProvider;
 import org.junit.jupiter.params.provider.ArgumentsSource;
+import org.junit.jupiter.params.support.ParameterDeclarations;

 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -80,7 +81,8 @@ public PlanGenerationStackTest() {

     static class RandomQueryProvider implements ArgumentsProvider {
         @Override
-        public Stream provideArguments(final ExtensionContext context) throws Exception {
+        public Stream provideArguments(final ParameterDeclarations parameterDeclarations,
+                                       final ExtensionContext context) {
             return Stream.of(
                     Arguments.of(0, "select count(*) from restaurant", null),
                     Arguments.of(1, "select * from restaurant", null),
diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/QueryTypeTests.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/QueryTypeTests.java
index cb6308bd93..c410c5c714 100644
--- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/QueryTypeTests.java
+++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/QueryTypeTests.java
@@ -29,6 +29,7 @@
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.ArgumentsProvider;
 import org.junit.jupiter.params.provider.ArgumentsSource;
+import org.junit.jupiter.params.support.ParameterDeclarations;

 import javax.annotation.Nonnull;
 import java.util.stream.Stream;
@@ -40,7 +41,8 @@ public class QueryTypeTests {

     static class QueriesProvider implements ArgumentsProvider {
         @Override
-        public Stream provideArguments(final ExtensionContext context) throws Exception {
+        public Stream provideArguments(final ParameterDeclarations parameterDeclarations,
+                                       final ExtensionContext context) {
             return Stream.of(
                     Arguments.of("select count(*) from restaurant", ParseTreeInfo.QueryType.SELECT),
                     Arguments.of(" select * from restaurant", ParseTreeInfo.QueryType.SELECT),
diff --git a/fdb-test-utils/src/main/java/com/apple/test/BooleanArgumentsProvider.java b/fdb-test-utils/src/main/java/com/apple/test/BooleanArgumentsProvider.java
index 33e5c853e3..a5524e8348 100644
--- a/fdb-test-utils/src/main/java/com/apple/test/BooleanArgumentsProvider.java
+++ b/fdb-test-utils/src/main/java/com/apple/test/BooleanArgumentsProvider.java
@@ -24,6 +24,7 @@
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.ArgumentsProvider;
 import org.junit.jupiter.params.support.AnnotationConsumer;
+import org.junit.jupiter.params.support.ParameterDeclarations;

 import java.util.Arrays;
 import java.util.stream.Stream;
@@ -42,7 +43,8 @@ public void accept(BooleanSource booleanSource) {
     }

     @Override
-    public Stream provideArguments(ExtensionContext extensionContext) throws Exception {
+    public Stream provideArguments(final ParameterDeclarations parameterDeclarations,
+                                   final ExtensionContext extensionContext) throws Exception {
         if (names.length == 0) {
             throw new IllegalStateException("@BooleanSource has an empty list of names");
         }
diff --git a/fdb-test-utils/src/main/java/com/apple/test/RandomSeedProvider.java b/fdb-test-utils/src/main/java/com/apple/test/RandomSeedProvider.java
index e719ad059b..50bafcb790 100644
--- a/fdb-test-utils/src/main/java/com/apple/test/RandomSeedProvider.java
+++ b/fdb-test-utils/src/main/java/com/apple/test/RandomSeedProvider.java
@@ -24,6 +24,7 @@
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.ArgumentsProvider;
 import org.junit.jupiter.params.support.AnnotationConsumer;
+import org.junit.jupiter.params.support.ParameterDeclarations;

 import java.util.stream.Stream;

@@ -40,7 +41,8 @@ public void accept(final RandomSeedSource annotation) {
     }

     @Override
-    public Stream provideArguments(final ExtensionContext extensionContext) throws Exception {
+    public Stream provideArguments(final ParameterDeclarations parameterDeclarations,
+                                   final ExtensionContext extensionContext) throws Exception {
         return RandomizedTestUtils.randomSeeds(fixedSeeds).map(Arguments::of);
     }
 }
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index c4e6482b97..f9671f2b0d 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -62,8 +62,8 @@ diffutils = "4.12"
 hamcrest = "2.2"
 jcommander = "1.81"
 jline = "3.30.4"
-junit = "5.11.3"
-junit-platform = "1.7.1"
+junit = "5.14.1"
+junit-platform = "1.14.1"
 mockito = "3.7.7"
 snakeyaml = "2.2"

@@ -145,7 +145,7 @@ spotbugs-annotations = { module = "com.github.spotbugs:spotbugs-annotations", ve

 [bundles]
 test-impl = [ "assertj", "hamcrest", "junit-api", "junit-params", "log4j-core", "mockito", "bndtools" ]
-test-runtime = [ "junit-engine", "log4j-slf4jBinding"]
+test-runtime = [ "junit-engine", "junit-platform", "log4j-slf4jBinding"]
 test-compileOnly = [ "autoService", "jsr305" ]

 [plugins]