diff --git a/lucene/core/src/java/org/apache/lucene/search/FilterWeight.java b/lucene/core/src/java/org/apache/lucene/search/FilterWeight.java index 16bb75ef4062..efbd9badd8a2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FilterWeight.java +++ b/lucene/core/src/java/org/apache/lucene/search/FilterWeight.java @@ -67,4 +67,10 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { return in.scorerSupplier(context); } + + @Override + public ScorerSupplier scorerSupplier(IndexSearcher.LeafReaderContextPartition partition) + throws IOException { + return in.scorerSupplier(partition); + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index d1079b69089a..0e5a82a87ec4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -321,7 +321,7 @@ public QueryCachingPolicy getQueryCachingPolicy() { * href="https://github.com/apache/lucene/issues/13745">the corresponding github issue. 
*/ protected LeafSlice[] slices(List leaves) { - return slices(leaves, MAX_DOCS_PER_SLICE, MAX_SEGMENTS_PER_SLICE, false); + return slices(leaves, MAX_DOCS_PER_SLICE, MAX_SEGMENTS_PER_SLICE, true); } /** @@ -828,7 +828,14 @@ protected void searchLeaf( // continue with the following leaf return; } - ScorerSupplier scorerSupplier = weight.scorerSupplier(ctx); + ScorerSupplier scorerSupplier; + if (minDocId == 0 && maxDocId == DocIdSetIterator.NO_MORE_DOCS) { + scorerSupplier = weight.scorerSupplier(ctx); + } else { + LeafReaderContextPartition partition = + LeafReaderContextPartition.createFromAndTo(ctx, minDocId, maxDocId); + scorerSupplier = weight.scorerSupplier(partition); + } if (scorerSupplier != null) { scorerSupplier.setTopLevelScoringClause(); BulkScorer scorer = scorerSupplier.bulkScorer(); diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index c198fecb4b35..0e6947054928 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.function.Predicate; import org.apache.lucene.document.IntPoint; @@ -32,7 +33,7 @@ import org.apache.lucene.index.PointValues.Relation; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ArrayUtil.ByteArrayComparator; -import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.BitDocIdSet; import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IntsRef; @@ -131,6 +132,12 @@ public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, fl return new ConstantScoreWeight(this, boost) { + // Cache to share DocIdSet computation across 
partitions of the same segment + // Key: LeafReaderContext (identifies the segment) + // Value: Lazily-initialized DocIdSet for the entire segment + private final ConcurrentHashMap segmentCache = + new ConcurrentHashMap<>(); + private boolean matches(byte[] packedValue) { int offset = 0; for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { @@ -248,15 +255,76 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { }; } - @Override - public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { - LeafReader reader = context.reader(); + /** + * Helper class that lazily builds and caches a DocIdSet for an entire segment. This allows + * multiple partitions of the same segment to share the BKD traversal work. + */ + final class SegmentDocIdSetSupplier { + private final LeafReaderContext context; + private volatile DocIdSet cachedDocIdSet = null; + private final Object buildLock = new Object(); + + SegmentDocIdSetSupplier(LeafReaderContext context) { + this.context = context; + } + + /** + * Get or build the DocIdSet for the entire segment. Thread-safe: first thread builds, + * others wait and reuse. 
+ */ + DocIdSet getOrBuild() throws IOException { + DocIdSet result = cachedDocIdSet; + if (result == null) { + synchronized (buildLock) { + result = cachedDocIdSet; + if (result == null) { + result = buildDocIdSet(); + cachedDocIdSet = result; + } + } + } + return result; + } + + private DocIdSet buildDocIdSet() throws IOException { + PointValues values = context.reader().getPointValues(field); + LeafReader reader = context.reader(); + // Check if we should use inverse intersection optimization + if (values.getDocCount() == reader.maxDoc() + && values.getDocCount() == values.size() + && estimateCost(values) > reader.maxDoc() / 2) { + // Build inverse bitset (docs that DON'T match) + final FixedBitSet result = new FixedBitSet(reader.maxDoc()); + long[] cost = new long[1]; + values.intersect(getInverseIntersectVisitor(result, cost)); + // Flip to get docs that DO match + result.flip(0, reader.maxDoc()); + cost[0] = Math.max(0, reader.maxDoc() - cost[0]); + return new BitDocIdSet(result, cost[0]); + } else { + // Normal path: build DocIdSet from matching docs + DocIdSetBuilder builder = new DocIdSetBuilder(reader.maxDoc(), values); + IntersectVisitor visitor = getIntersectVisitor(builder); + values.intersect(visitor); + return builder.build(); + } + } + private long estimateCost(PointValues values) throws IOException { + DocIdSetBuilder builder = new DocIdSetBuilder(context.reader().maxDoc(), values); + IntersectVisitor visitor = getIntersectVisitor(builder); + return values.estimateDocCount(visitor); + } + } + + @Override + public ScorerSupplier scorerSupplier(IndexSearcher.LeafReaderContextPartition partition) + throws IOException { + LeafReader reader = partition.ctx.reader(); PointValues values = reader.getPointValues(field); if (checkValidPointValues(values) == false) { return null; } - if (values.getDocCount() == 0) { return null; } else { @@ -274,7 +342,6 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } } } - boolean 
allDocsMatch; if (values.getDocCount() == reader.maxDoc()) { final byte[] fieldPackedLower = values.getMinPackedValue(); @@ -291,49 +358,157 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } else { allDocsMatch = false; } - if (allDocsMatch) { // all docs have a value and all points are within bounds, so everything matches return ConstantScoreScorerSupplier.matchAll(score(), scoreMode, reader.maxDoc()); } else { - return new ConstantScoreScorerSupplier(score(), scoreMode, reader.maxDoc()) { - - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; - - @Override - public DocIdSetIterator iterator(long leadCost) throws IOException { - if (values.getDocCount() == reader.maxDoc() - && values.getDocCount() == values.size() - && cost() > reader.maxDoc() / 2) { - // If all docs have exactly one value and the cost is greater - // than half the leaf size then maybe we can make things faster - // by computing the set of documents that do NOT match the range - final FixedBitSet result = new FixedBitSet(reader.maxDoc()); - long[] cost = new long[1]; - values.intersect(getInverseIntersectVisitor(result, cost)); - // Flip the bit set and cost - result.flip(0, reader.maxDoc()); - cost[0] = Math.max(0, reader.maxDoc() - cost[0]); - return new BitSetIterator(result, cost[0]); - } + // Get or create the cached supplier for this segment + SegmentDocIdSetSupplier segmentSupplier = + segmentCache.computeIfAbsent(partition.ctx, ctx -> new SegmentDocIdSetSupplier(ctx)); + // Each call creates a new PartitionScorerSupplier and all partitions share the same + // SegmentDocIdSetSupplier + return new PartitionScorerSupplier( + segmentSupplier, partition.minDocId, partition.maxDocId, score(), scoreMode); + } + } - values.intersect(visitor); - return result.build().iterator(); - } + /** ScorerSupplier for a partition that filters results from the shared 
segment DocIdSet. */ + final class PartitionScorerSupplier extends ScorerSupplier { + private final SegmentDocIdSetSupplier segmentSupplier; + private final int minDocId; + private final int maxDocId; + private final float score; + private final ScoreMode scoreMode; + + PartitionScorerSupplier( + SegmentDocIdSetSupplier segmentSupplier, + int minDocId, + int maxDocId, + float score, + ScoreMode scoreMode) { + this.segmentSupplier = segmentSupplier; + this.minDocId = minDocId; + this.maxDocId = maxDocId; + this.score = score; + this.scoreMode = scoreMode; + } - @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; - } - return cost; - } - }; + @Override + public Scorer get(long leadCost) throws IOException { + DocIdSetIterator iterator = getIterator(); + if (iterator == null) { + return null; + } + return new ConstantScoreScorer(score, scoreMode, iterator); } + + private DocIdSetIterator getIterator() throws IOException { + // Get the shared DocIdSet (built once per segment) + // The underlying FixedBitSet/int[] buffer is shared across all partitions, + // but each partition gets its own iterator with its own position state. 
+ DocIdSet docIdSet = segmentSupplier.getOrBuild(); + DocIdSetIterator fullIterator = docIdSet.iterator(); + if (fullIterator == null) { + return null; + } + // Check if this is a full segment (no partition filtering needed) + boolean isFullSegment = (minDocId == 0 && maxDocId == DocIdSetIterator.NO_MORE_DOCS); + if (isFullSegment) { + return fullIterator; + } + // Wrap iterator to filter to partition range + return new PartitionFilteredDocIdSetIterator(fullIterator, minDocId, maxDocId); + } + + @Override + public long cost() { + DocIdSet docIdSet; + try { + docIdSet = segmentSupplier.getOrBuild(); + } catch (IOException e) { + throw new RuntimeException(e); + } + long totalCost = docIdSet.iterator().cost(); + boolean isFullSegment = (minDocId == 0 && maxDocId == DocIdSetIterator.NO_MORE_DOCS); + if (isFullSegment) { + return totalCost; + } + int segmentSize = segmentSupplier.context.reader().maxDoc(); + int partitionSize = maxDocId - minDocId; + return (totalCost * partitionSize) / segmentSize; + } + + @Override + public BulkScorer bulkScorer() throws IOException { + Scorer scorer = get(Long.MAX_VALUE); + if (scorer == null) { + return null; + } + return new Weight.DefaultBulkScorer(scorer); + } + } + + /** + * Iterator that filters a delegate iterator to only return docs within a partition range. + * Used to restrict a full-segment DocIdSetIterator to a specific partition's boundaries. 
+ */ + static final class PartitionFilteredDocIdSetIterator extends DocIdSetIterator { + private final DocIdSetIterator delegate; + private final int minDocId; + private final int maxDocId; + private int doc = -1; + + PartitionFilteredDocIdSetIterator(DocIdSetIterator delegate, int minDocId, int maxDocId) { + this.delegate = delegate; + this.minDocId = minDocId; + this.maxDocId = maxDocId; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + if (doc == -1) { + // First call: advance to minDocId + doc = delegate.advance(minDocId); + } else { + doc = delegate.nextDoc(); + } + // Stop if we've exceeded the partition range + if (doc >= maxDocId) { + doc = NO_MORE_DOCS; + } + return doc; + } + + @Override + public int advance(int target) throws IOException { + if (target >= maxDocId) { + return doc = NO_MORE_DOCS; + } + // Ensure target is at least minDocId + target = Math.max(target, minDocId); + doc = delegate.advance(target); + if (doc >= maxDocId) { + doc = NO_MORE_DOCS; + } + return doc; + } + + @Override + public long cost() { + // Conservative estimate based on partition size + return Math.min(delegate.cost(), maxDocId - minDocId); + } + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return scorerSupplier( + IndexSearcher.LeafReaderContextPartition.createForEntireSegment(context)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/Weight.java b/lucene/core/src/java/org/apache/lucene/search/Weight.java index 341dd3cadf6a..f1cce197d7b4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Weight.java +++ b/lucene/core/src/java/org/apache/lucene/search/Weight.java @@ -149,6 +149,31 @@ public final Scorer scorer(LeafReaderContext context) throws IOException { */ public abstract ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException; + /** + * Returns a {@link ScorerSupplier}, which can then be used 
to get a {@link Scorer} for a + * partition of a leaf reader context. + * + * <p>This method allows queries to optimize for intra-segment concurrency by knowing the specific + * doc ID range being searched within the segment. The default implementation delegates to {@link + * #scorerSupplier(LeafReaderContext)}, ignoring the partition bounds. Queries that can benefit + * from partition awareness (e.g., by creating smaller data structures scoped to the partition) + * should override this method. + * + * <p>
A scorer supplier for the same {@link LeafReaderContext} instance may be requested multiple + * times as part of a single search call, potentially from different threads searching different + * doc ID ranges concurrently. + * + * @param partition the leaf reader context partition containing the context and doc ID range + * @return a {@link ScorerSupplier} providing the scorer, or null if scorer is null + * @throws IOException if an IOException occurs + * @see IndexSearcher.LeafReaderContextPartition + * @since 10.1 + */ + public ScorerSupplier scorerSupplier(IndexSearcher.LeafReaderContextPartition partition) + throws IOException { + return scorerSupplier(partition.ctx); + } + /** * Helper method that delegates to {@link #scorerSupplier(LeafReaderContext)}. It is implemented * as diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestQueryProfilerWeight.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestQueryProfilerWeight.java index 41dc054a0756..5bb3191804a3 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestQueryProfilerWeight.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestQueryProfilerWeight.java @@ -177,7 +177,7 @@ public void testPropagateTopLevelScoringClause() throws IOException { Weight fakeWeight = new FakeWeight(query); QueryProfilerBreakdown profile = new QueryProfilerBreakdown(); QueryProfilerWeight profileWeight = new QueryProfilerWeight(fakeWeight, profile); - ScorerSupplier scorerSupplier = profileWeight.scorerSupplier(null); + ScorerSupplier scorerSupplier = profileWeight.scorerSupplier((LeafReaderContext) null); scorerSupplier.setTopLevelScoringClause(); assertEquals(42, scorerSupplier.cost()); }