From fb5f4eedf91ad5f76330d211720478657d7e802b Mon Sep 17 00:00:00 2001
From: Adrien Grand <jpountz@gmail.com>
Date: Mon, 13 Jan 2025 23:27:53 +0100
Subject: [PATCH] Implement #intoBitSet on `IntArrayDocIdSet` and
 `RoaringDocIdSet`.

These doc id sets can implement `#intoBitSet` in a way that auto-vectorizes.

For reference, `RoaringDocIdSet` is used by the query cache, and
`IntArrayDocIdSet` is used by point queries.
---
 .../lucene/search/DocIdSetIterator.java       |  2 +
 .../apache/lucene/util/BitSetIterator.java    |  1 +
 .../apache/lucene/util/IntArrayDocIdSet.java  | 21 +++++++++
 .../apache/lucene/util/RoaringDocIdSet.java   | 46 +++++++++++++++++++
 .../tests/util/BaseDocIdSetTestCase.java      | 37 +++++++++++++++
 5 files changed, 107 insertions(+)

diff --git a/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java b/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
index ee30f627a56b..b550d0e584aa 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
@@ -231,6 +231,8 @@ protected final int slowAdvance(int target) throws IOException {
    *
    * <p><b>Note</b>: It is important not to clear bits from {@code bitSet} that may be already set.
    *
+   * <p><b>Note</b>: {@code offset} may be negative.
+   *
    * @lucene.internal
    */
   public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
diff --git a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
index 4d7c83057cbe..69af64fb3a2a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
+++ b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
@@ -105,6 +105,7 @@ public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset
     if (acceptDocs == null
         && offset < bits.length()
         && bits instanceof FixedBitSet fixedBits
+        && offset >= 0
         // no bits are set between `offset` and `doc`
         && fixedBits.nextSetBit(offset) == doc
         // the whole `bitSet` is getting filled
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java
index d44cc7839233..d8299f578b72 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java
@@ -95,6 +95,27 @@ public int advance(int target) throws IOException {
       return doc = docs[i++];
     }
 
+    @Override
+    public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
+        throws IOException {
+      if (doc >= upTo) {
+        return;
+      }
+
+      if (acceptDocs != null) {
+        super.intoBitSet(acceptDocs, upTo, bitSet, offset);
+        return;
+      }
+
+      int from = i - 1;
+      int to = VectorUtil.findNextGEQ(docs, upTo, from, length);
+      for (int i = from; i < to; ++i) {
+        bitSet.set(docs[i] - offset);
+      }
+      doc = docs[to];
+      i = to + 1;
+    }
+
     @Override
     public long cost() {
       return length;
diff --git a/lucene/core/src/java/org/apache/lucene/util/RoaringDocIdSet.java b/lucene/core/src/java/org/apache/lucene/util/RoaringDocIdSet.java
index ccd92a74250e..171e89d82f38 100644
--- a/lucene/core/src/java/org/apache/lucene/util/RoaringDocIdSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/RoaringDocIdSet.java
@@ -217,6 +217,26 @@ public int advance(int target) throws IOException {
             return doc = docId(i);
           }
         }
+
+        @Override
+        public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
+            throws IOException {
+          if (doc >= upTo) {
+            return;
+          }
+
+          if (acceptDocs != null) {
+            super.intoBitSet(acceptDocs, upTo, bitSet, offset);
+            return;
+          }
+
+          int from = i;
+          advance(upTo);
+          int to = i;
+          for (int i = from; i < to; ++i) {
+            bitSet.set(docId(i) - offset);
+          }
+        }
       };
     }
   }
@@ -312,6 +332,32 @@ private int firstDocFromNextBlock() throws IOException {
       }
     }
 
+    @Override
+    public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
+        throws IOException {
+      if (acceptDocs != null) {
+        super.intoBitSet(acceptDocs, upTo, bitSet, offset);
+        return;
+      }
+
+      for (; ; ) {
+        int subUpto = upTo - (block << 16);
+        if (subUpto < 0) {
+          break;
+        }
+        int subOffset = offset - (block << 16);
+        sub.intoBitSet(null, subUpto, bitSet, subOffset);
+        if (sub.docID() == NO_MORE_DOCS) {
+          if (firstDocFromNextBlock() == NO_MORE_DOCS) {
+            break;
+          }
+        } else {
+          doc = (block << 16) | sub.docID();
+          break;
+        }
+      }
+    }
+
     @Override
     public long cost() {
       return cardinality;
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseDocIdSetTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseDocIdSetTestCase.java
index 2c1dfc72a31a..336fbd3af63f 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseDocIdSetTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseDocIdSetTestCase.java
@@ -24,6 +24,7 @@
 import org.apache.lucene.search.DocIdSet;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
 
 /** Base test class for {@link DocIdSet}s. */
 public abstract class BaseDocIdSetTestCase<T extends DocIdSet> extends LuceneTestCase {
@@ -196,4 +197,40 @@ private long ramBytesUsed(DocIdSet set, int length) throws IOException {
     long bytes2 = RamUsageTester.ramUsed(dummy);
     return bytes1 - bytes2;
   }
+
+  public void testIntoBitSet() throws IOException {
+    Random random = random();
+    final int numBits = TestUtil.nextInt(random, 100, 1 << 20);
+    // test various random sets with various load factors
+    for (float percentSet : new float[] {0f, 0.0001f, random.nextFloat(), 0.9f, 1f}) {
+      final BitSet set = randomSet(numBits, percentSet);
+      final T copy = copyOf(set, numBits);
+      int from = TestUtil.nextInt(random(), 0, numBits - 1);
+      int to = TestUtil.nextInt(random(), from, numBits + 5);
+      FixedBitSet actual = new FixedBitSet(to - from);
+      DocIdSetIterator it1 = copy.iterator();
+      if (it1 == null) {
+        continue;
+      }
+      int fromDoc = it1.advance(from);
+      // No docs to set
+      it1.intoBitSet(null, from, actual, from);
+      assertTrue(actual.scanIsEmpty());
+      assertEquals(fromDoc, it1.docID());
+
+      // Now actually set some bits
+      it1.intoBitSet(null, to, actual, from);
+      FixedBitSet expected = new FixedBitSet(to - from);
+      DocIdSetIterator it2 = copy.iterator();
+      for (int doc = it2.advance(from); doc < to; doc = it2.nextDoc()) {
+        expected.set(doc - from);
+      }
+      assertEquals(expected, actual);
+      // Check if docID() / nextDoc() return the same value after #intoBitSet has been called.
+      assertEquals(it2.docID(), it1.docID());
+      if (it2.docID() != DocIdSetIterator.NO_MORE_DOCS) {
+        assertEquals(it2.nextDoc(), it1.nextDoc());
+      }
+    }
+  }
 }