From bef2bc162a6879efe9a5e52555a65c99fb17141d Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 30 Oct 2025 17:33:24 +0100 Subject: [PATCH 01/11] Small preliminary refactoring. --- src/collector/top_collector.rs | 2 +- src/collector/top_score_collector.rs | 61 ++++++++++++--- src/collector/tweak_score_top_collector.rs | 86 ++++++++++++++++------ 3 files changed, 115 insertions(+), 34 deletions(-) diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 29ff086005..7451a82488 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -152,7 +152,7 @@ where T: PartialOrd + Clone pub(crate) struct TopSegmentCollector { /// We reverse the order of the feature in order to /// have top-semantics instead of bottom semantics. - topn_computer: TopNComputer, + pub(crate) topn_computer: TopNComputer, segment_ord: u32, } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 33c5df59e1..aebd8e4dda 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -690,14 +690,13 @@ impl TopDocs { /// /// # See also /// - [custom_score(...)](TopDocs::custom_score) - pub fn tweak_score( + pub fn tweak_score( self, score_tweaker: TScoreTweaker, ) -> impl Collector> where TScore: 'static + Send + Sync + Clone + PartialOrd, - TScoreSegmentTweaker: ScoreSegmentTweaker + 'static, - TScoreTweaker: ScoreTweaker + Send + Sync, + TScoreTweaker: ScoreTweaker + Send + Sync, { TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) } @@ -1003,17 +1002,15 @@ where self.threshold = Some(median); } - // This is faster since it avoids the buffer resizing to be inlined from vec.push() - // (this is in the hot path) - // TODO: Replace with `push_within_capacity` when it's stabilized - let uninit = self.buffer.spare_capacity_mut(); // This cannot panic, because we truncate_median will at least remove one element, since // the min capacity is 2. 
- uninit[0].write(ComparableDoc { doc, feature }); - // This is safe because it would panic in the line above - unsafe { - self.buffer.set_len(self.buffer.len() + 1); - } + self.append_within_capacity(ComparableDoc { doc, feature }); + } + + // ONLY CALL THIS FUNCTION WHEN YOU KNOW THE BUFFER HAS ENOUGH CAPACITY. + #[inline(always)] + fn append_within_capacity(&mut self, comparable_doc: ComparableDoc) { + push_within_capacity(comparable_doc, &mut self.buffer); } #[inline(never)] @@ -1048,6 +1045,46 @@ where } } +impl TopNComputer +where TScore: PartialOrd + Clone +{ + #[inline(always)] + pub(crate) fn push_lazy( + &mut self, + doc: DocId, + score: Score, + score_tweaker: &mut impl ScoreSegmentTweaker, + ) { + if let Some(threshold) = self.threshold.as_ref() { + if let Some((_cmp, feature)) = + score_tweaker.accept_score_lazy::(doc, score, threshold) + { + self.append_within_capacity(ComparableDoc { feature, doc }); + } + } else { + let feature = score_tweaker.score(doc, score); + self.append_within_capacity(ComparableDoc { feature, doc }); + } + } +} + +// Push an element provided there is enough capacity to do so. +// TODO replace me when push_within_capacity is stabilized. +#[inline(always)] +fn push_within_capacity(el: T, buf: &mut Vec) { + let prev_len = buf.len(); + if prev_len == buf.capacity() { + return; + } + // This is mimicking the current (non-stabilized) implementation in std. + // SAFETY: we just checked we have enough capacity. 
+ unsafe { + let end = buf.as_mut_ptr().add(prev_len); + std::ptr::write(end, el); + buf.set_len(prev_len + 1); + } +} + #[cfg(test)] mod tests { use proptest::prelude::*; diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index e7e8d1547a..b1f92f9840 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; use crate::collector::{Collector, SegmentCollector}; use crate::{DocAddress, DocId, Result, Score, SegmentReader}; @@ -25,9 +27,47 @@ where TScore: Clone + PartialOrd /// for a given document belonging to a specific segment. /// /// It is the segment local version of the [`ScoreTweaker`]. -pub trait ScoreSegmentTweaker: 'static { +pub trait ScoreSegmentTweaker: 'static { + /// Score used by at the segment level by the `ScoreSegmentTweaker`. + type SegmentScore: 'static + PartialOrd + Clone + Send + Sync; + /// Tweak the given `score` for the document `doc`. - fn score(&mut self, doc: DocId, score: Score) -> TScore; + fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore; + + /// Returns true if the `ScoreSegmentTweaker` is a good candidate for the lazy evaluation optimization. + /// See [`ScoreSegmentTweaker::accept_score_lazy`]. + fn is_lazy() -> bool { + false + } + + /// Implementing this method makes it possible to avoid computing + /// a score entirely if we can assess that it won't pass a threshold + /// with a partial computation. + /// + /// This is currently used for lexicographic sorting. + /// + /// If REVERSE_ORDER is false (resp. true), + /// - we return None if the score is below the threshold (resp. above to the threshold) + /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. 
below or equal to) + fn accept_score_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentScore, + ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { + let excluded_ordering = if REVERSE_ORDER { + Ordering::Greater + } else { + Ordering::Less + }; + let score = self.score(doc_id, score); + let cmp = score.partial_cmp(threshold).unwrap_or(excluded_ordering); + if cmp == excluded_ordering { + return None; + } else { + return Some((cmp, score)); + } + } } /// `ScoreTweaker` makes it possible to tweak the score @@ -38,7 +78,7 @@ pub trait ScoreSegmentTweaker: 'static { /// the score at a segment scale. pub trait ScoreTweaker: Sync { /// Type of the associated [`ScoreSegmentTweaker`]. - type Child: ScoreSegmentTweaker; + type Child: ScoreSegmentTweaker; /// Builds a child tweaker for a specific segment. The child scorer is associated with /// a specific segment. @@ -52,7 +92,7 @@ where { type Fruit = Vec<(TScore, DocAddress)>; - type Child = TopTweakedScoreSegmentCollector; + type Child = TopTweakedScoreSegmentCollector; fn for_segment( &self, @@ -76,29 +116,30 @@ where } } -pub struct TopTweakedScoreSegmentCollector -where - TScore: 'static + PartialOrd + Clone + Send + Sync + Sized, - TSegmentScoreTweaker: ScoreSegmentTweaker, +pub struct TopTweakedScoreSegmentCollector +where TSegmentScoreTweaker: ScoreSegmentTweaker { - segment_collector: TopSegmentCollector, + segment_collector: TopSegmentCollector, segment_scorer: TSegmentScoreTweaker, } -impl SegmentCollector - for TopTweakedScoreSegmentCollector -where - TScore: 'static + PartialOrd + Clone + Send + Sync, - TSegmentScoreTweaker: 'static + ScoreSegmentTweaker, +impl SegmentCollector + for TopTweakedScoreSegmentCollector +where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker { - type Fruit = Vec<(TScore, DocAddress)>; + type Fruit = Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)>; fn collect(&mut self, doc: DocId, score: Score) { - let score = self.segment_scorer.score(doc, 
score); - self.segment_collector.collect(doc, score); + // Thanks to generics, this if-statement is free. + if TSegmentScoreTweaker::is_lazy() { + self.segment_collector.topn_computer.push_lazy(doc, score, &mut self.segment_scorer); + } else { + let score = self.segment_scorer.score(doc, score); + self.segment_collector.collect(doc, score); + } } - fn harvest(self) -> Vec<(TScore, DocAddress)> { + fn harvest(self) -> Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)> { self.segment_collector.harvest() } } @@ -106,7 +147,7 @@ where impl ScoreTweaker for F where F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker, - TSegmentScoreTweaker: ScoreSegmentTweaker, + TSegmentScoreTweaker: ScoreSegmentTweaker, { type Child = TSegmentScoreTweaker; @@ -115,9 +156,12 @@ where } } -impl ScoreSegmentTweaker for F -where F: 'static + FnMut(DocId, Score) -> TScore +impl ScoreSegmentTweaker for F +where + F: 'static + FnMut(DocId, Score) -> TScore, + TScore: 'static + PartialOrd + Clone + Send + Sync, { + type SegmentScore = TScore; fn score(&mut self, doc: DocId, score: Score) -> TScore { (self)(doc, score) } From 76c1c08a83db421b6bcb89b7def8e95230a000e9 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 30 Oct 2025 19:49:23 +0100 Subject: [PATCH 02/11] Refactoring of the score tweaker to unlock two features. - Allow lazy evaluation of score. As soon as we have identified that a doc won't reach the topK threshold, we can stop the evaluation. - Allow for a different segment level score and global level score, and their conversion. This PR breaks public API, but fixing code is straightforward. 
--- src/collector/top_score_collector.rs | 6 +- src/collector/tweak_score_top_collector.rs | 220 +++++++++++++++++++-- 2 files changed, 207 insertions(+), 19 deletions(-) diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index aebd8e4dda..391b3a4a4d 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -690,13 +690,13 @@ impl TopDocs { /// /// # See also /// - [custom_score(...)](TopDocs::custom_score) - pub fn tweak_score( + pub fn tweak_score( self, score_tweaker: TScoreTweaker, ) -> impl Collector> where - TScore: 'static + Send + Sync + Clone + PartialOrd, - TScoreTweaker: ScoreTweaker + Send + Sync, + TScore: 'static + Clone + Send + Sync + PartialOrd, + TScoreTweaker: ScoreTweaker + Send + Sync, { TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) } diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index b1f92f9840..deadcd11a8 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -4,7 +4,7 @@ use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; use crate::collector::{Collector, SegmentCollector}; use crate::{DocAddress, DocId, Result, Score, SegmentReader}; -pub(crate) struct TweakedScoreTopCollector { +pub(crate) struct TweakedScoreTopCollector { score_tweaker: TScoreTweaker, collector: TopCollector, } @@ -28,14 +28,20 @@ where TScore: Clone + PartialOrd /// /// It is the segment local version of the [`ScoreTweaker`]. pub trait ScoreSegmentTweaker: 'static { + /// The final score being emitted. + type Score: 'static + PartialOrd + Send + Sync + Clone; + /// Score used by at the segment level by the `ScoreSegmentTweaker`. - type SegmentScore: 'static + PartialOrd + Clone + Send + Sync; + /// + /// It is typically small like a `u64`, and is meant to be converted + /// to the final score at the end of the collection of the segment. 
+ type SegmentScore: 'static + PartialOrd + Clone + Send + Sync + Clone; /// Tweak the given `score` for the document `doc`. fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore; - /// Returns true if the `ScoreSegmentTweaker` is a good candidate for the lazy evaluation optimization. - /// See [`ScoreSegmentTweaker::accept_score_lazy`]. + /// Returns true if the `ScoreSegmentTweaker` is a good candidate for the lazy evaluation + /// optimization. See [`ScoreSegmentTweaker::accept_score_lazy`]. fn is_lazy() -> bool { false } @@ -48,7 +54,8 @@ pub trait ScoreSegmentTweaker: 'static { /// /// If REVERSE_ORDER is false (resp. true), /// - we return None if the score is below the threshold (resp. above to the threshold) - /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. below or equal to) + /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. + /// below or equal to) fn accept_score_lazy( &mut self, doc_id: DocId, @@ -68,6 +75,9 @@ pub trait ScoreSegmentTweaker: 'static { return Some((cmp, score)); } } + + /// Convert a segment level score into the global level score. + fn convert_score(&self, score: Self::SegmentScore) -> Self::Score; } /// `ScoreTweaker` makes it possible to tweak the score @@ -76,9 +86,11 @@ pub trait ScoreSegmentTweaker: 'static { /// The `ScoreTweaker` itself does not make much of the computation itself. /// Instead, it helps constructing `Self::Child` instances that will compute /// the score at a segment scale. -pub trait ScoreTweaker: Sync { +pub trait ScoreTweaker: Sync { + /// The actual score emitted by the Tweaker. + type Score: 'static + Send + Sync + PartialOrd + Clone; /// Type of the associated [`ScoreSegmentTweaker`]. - type Child: ScoreSegmentTweaker; + type Child: ScoreSegmentTweaker; /// Builds a child tweaker for a specific segment. The child scorer is associated with /// a specific segment. 
@@ -87,10 +99,10 @@ pub trait ScoreTweaker: Sync { impl Collector for TweakedScoreTopCollector where - TScoreTweaker: ScoreTweaker + Send + Sync, - TScore: 'static + PartialOrd + Clone + Send + Sync, + TScoreTweaker: ScoreTweaker + Send + Sync, + TScore: 'static + Send + PartialOrd + Sync + Clone, { - type Fruit = Vec<(TScore, DocAddress)>; + type Fruit = Vec<(TScoreTweaker::Score, DocAddress)>; type Child = TopTweakedScoreSegmentCollector; @@ -127,28 +139,37 @@ impl SegmentCollector for TopTweakedScoreSegmentCollector where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker { - type Fruit = Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)>; + type Fruit = Vec<(TSegmentScoreTweaker::Score, DocAddress)>; fn collect(&mut self, doc: DocId, score: Score) { // Thanks to generics, this if-statement is free. if TSegmentScoreTweaker::is_lazy() { - self.segment_collector.topn_computer.push_lazy(doc, score, &mut self.segment_scorer); + self.segment_collector + .topn_computer + .push_lazy(doc, score, &mut self.segment_scorer); } else { let score = self.segment_scorer.score(doc, score); self.segment_collector.collect(doc, score); } } - fn harvest(self) -> Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)> { - self.segment_collector.harvest() + fn harvest(self) -> Self::Fruit { + let segment_hits: Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)> = + self.segment_collector.harvest(); + segment_hits + .into_iter() + .map(|(score, doc)| (self.segment_scorer.convert_score(score), doc)) + .collect() } } -impl ScoreTweaker for F +impl ScoreTweaker for F where F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker, - TSegmentScoreTweaker: ScoreSegmentTweaker, + TSegmentScoreTweaker: ScoreSegmentTweaker, { + + type Score = TSegmentScoreTweaker::Score; type Child = TSegmentScoreTweaker; fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { @@ -156,13 +177,180 @@ where } } + impl ScoreSegmentTweaker for F where F: 'static + FnMut(DocId, Score) 
-> TScore, TScore: 'static + PartialOrd + Clone + Send + Sync, { + type Score = TScore; type SegmentScore = TScore; + fn score(&mut self, doc: DocId, score: Score) -> TScore { (self)(doc, score) } + + /// Convert a segment level score into the global level score. + fn convert_score(&self, score: Self::SegmentScore) -> Self::Score { + score + } +} + +impl ScoreTweaker for (HeadScoreTweaker, TailScoreTweaker) +where + HeadScoreTweaker: ScoreTweaker, + TailScoreTweaker: ScoreTweaker, +{ + type Score = (::Score, ::Score); + type Child = (HeadScoreTweaker::Child, TailScoreTweaker::Child); + + fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { + Ok(( + self.0.segment_tweaker(segment_reader)?, + self.1.segment_tweaker(segment_reader)?, + )) + } +} + +impl ScoreSegmentTweaker + for (HeadScoreSegmentTweaker, TailScoreSegmentTweaker) +where + HeadScoreSegmentTweaker: ScoreSegmentTweaker, + TailScoreSegmentTweaker: ScoreSegmentTweaker, +{ + type Score = ( + HeadScoreSegmentTweaker::Score, + TailScoreSegmentTweaker::Score, + ); + type SegmentScore = ( + HeadScoreSegmentTweaker::SegmentScore, + TailScoreSegmentTweaker::SegmentScore, + ); + + fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore { + let head_score = self.0.score(doc, score); + let tail_score = self.1.score(doc, score); + (head_score, tail_score) + } + + fn accept_score_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentScore, + ) -> Option<(Ordering, Self::SegmentScore)> { + let (head_threshold, tail_threshold) = threshold; + let (head_cmp, head_score) = + self.0 + .accept_score_lazy::(doc_id, score, head_threshold)?; + if head_cmp == Ordering::Equal { + let (tail_cmp, tail_score) = + self.1 + .accept_score_lazy::(doc_id, score, tail_threshold)?; + Some((tail_cmp, (head_score, tail_score))) + } else { + let tail_score = self.1.score(doc_id, score); + Some((head_cmp, (head_score, tail_score))) + } + } + + fn is_lazy() -> bool { + true + } + + fn 
convert_score(&self, score: Self::SegmentScore) -> Self::Score { + let (head_score, tail_score) = score; + ( + self.0.convert_score(head_score), + self.1.convert_score(tail_score), + ) + } +} + +/// This struct is used as an adapter to take a segment score tweaker and map its score to another new score. +pub struct MappedSegmentScoreTweaker { + tweaker: T, + map: fn(PreviousScore) -> NewScore, +} + +impl ScoreSegmentTweaker for MappedSegmentScoreTweaker +where + T: ScoreSegmentTweaker, + PreviousScore: 'static + Clone + Send + Sync + PartialOrd, + NewScore: 'static + Clone + Send + Sync + PartialOrd, +{ + type Score = NewScore; + type SegmentScore = T::SegmentScore; + + fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore { + self.tweaker.score(doc, score) + } + + fn accept_score_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentScore, + ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { + self.tweaker.accept_score_lazy::(doc_id, score, threshold) + } + + fn is_lazy() -> bool { + T::is_lazy() + } + + fn convert_score(&self, score: Self::SegmentScore) -> Self::Score { + (self.map)(self.tweaker.convert_score(score)) + } +} + + +// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, ...) 
as the chain (a, (b, (c, ...))) + +impl ScoreTweaker for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3) +where + ScoreTweaker1: ScoreTweaker, + ScoreTweaker2: ScoreTweaker, + ScoreTweaker3: ScoreTweaker, +{ + type Child = MappedSegmentScoreTweaker<<(ScoreTweaker1, (ScoreTweaker2, ScoreTweaker3)) as ScoreTweaker>::Child, (ScoreTweaker1::Score, (ScoreTweaker2::Score, ScoreTweaker3::Score)), Self::Score>; + type Score = (ScoreTweaker1::Score, ScoreTweaker2::Score, ScoreTweaker3::Score); + + fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { + let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; + let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; + let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; + Ok( + MappedSegmentScoreTweaker { + tweaker: (score_tweaker1, (score_tweaker2, score_tweaker3)), + map: | (score1, (score2, score3))| (score1, score2, score3), + } + ) + + } +} + +impl ScoreTweaker for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3, ScoreTweaker4) +where + ScoreTweaker1: ScoreTweaker, + ScoreTweaker2: ScoreTweaker, + ScoreTweaker3: ScoreTweaker, + ScoreTweaker4: ScoreTweaker, +{ + type Child = MappedSegmentScoreTweaker<<(ScoreTweaker1, (ScoreTweaker2, (ScoreTweaker3, ScoreTweaker4))) as ScoreTweaker>::Child, (ScoreTweaker1::Score, (ScoreTweaker2::Score, (ScoreTweaker3::Score, ScoreTweaker4::Score))), Self::Score>; + type Score = (ScoreTweaker1::Score, ScoreTweaker2::Score, ScoreTweaker3::Score, ScoreTweaker4::Score); + + fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { + let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; + let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; + let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; + let score_tweaker4 = self.3.segment_tweaker(segment_reader)?; + Ok( + MappedSegmentScoreTweaker { + tweaker: (score_tweaker1, (score_tweaker2, (score_tweaker3, score_tweaker4))), + map: | (score1, (score2, (score3, score4)))| 
(score1, score2, score3, score4), + } + ) + + } } From 176e984be92043f0f3d26922f2eba4375c19b395 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 10:08:51 +0100 Subject: [PATCH 03/11] CR comment --- src/collector/top_collector.rs | 8 +++++++- src/collector/top_score_collector.rs | 10 ++++++++-- src/collector/tweak_score_top_collector.rs | 10 +--------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 7451a82488..cf5b5eb167 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; use super::top_score_collector::TopNComputer; +use crate::collector::ScoreSegmentTweaker; use crate::index::SegmentReader; use crate::{DocAddress, DocId, SegmentOrdinal}; @@ -152,7 +153,7 @@ where T: PartialOrd + Clone pub(crate) struct TopSegmentCollector { /// We reverse the order of the feature in order to /// have top-semantics instead of bottom semantics. 
- pub(crate) topn_computer: TopNComputer, + topn_computer: TopNComputer, segment_ord: u32, } @@ -191,6 +192,11 @@ impl TopSegmentCollector { pub fn collect(&mut self, doc: DocId, feature: T) { self.topn_computer.push(feature, doc); } + + #[inline] + pub fn collect_lazy(&mut self, doc: DocId, score: crate::Score, segment_scorer: &mut impl ScoreSegmentTweaker) { + self.topn_computer.push_lazy(doc, score, segment_scorer); + } } #[cfg(test)] diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 391b3a4a4d..a951cb10c6 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -1049,12 +1049,18 @@ impl TopNComputer>( &mut self, doc: DocId, score: Score, - score_tweaker: &mut impl ScoreSegmentTweaker, + score_tweaker: &mut TScoreSegmentTweaker, ) { + if !TScoreSegmentTweaker::is_lazy() { + let feature = score_tweaker.score(doc, score); + self.push(feature, doc); + return; + } + if let Some(threshold) = self.threshold.as_ref() { if let Some((_cmp, feature)) = score_tweaker.accept_score_lazy::(doc, score, threshold) diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index deadcd11a8..dc89f12c80 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -142,15 +142,7 @@ where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker type Fruit = Vec<(TSegmentScoreTweaker::Score, DocAddress)>; fn collect(&mut self, doc: DocId, score: Score) { - // Thanks to generics, this if-statement is free. 
- if TSegmentScoreTweaker::is_lazy() { - self.segment_collector - .topn_computer - .push_lazy(doc, score, &mut self.segment_scorer); - } else { - let score = self.segment_scorer.score(doc, score); - self.segment_collector.collect(doc, score); - } + self.segment_collector.collect_lazy(doc, score, &mut self.segment_scorer); } fn harvest(self) -> Self::Fruit { From 88d48266c68847672bb0fe498ded273fd7bf690b Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 10:15:51 +0100 Subject: [PATCH 04/11] Cargo fmt and renaming --- src/collector/top_collector.rs | 7 +- src/collector/top_score_collector.rs | 2 +- src/collector/tweak_score_top_collector.rs | 122 ++++++++++++++------- 3 files changed, 87 insertions(+), 44 deletions(-) diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index cf5b5eb167..1fcfe5d204 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -194,7 +194,12 @@ impl TopSegmentCollector { } #[inline] - pub fn collect_lazy(&mut self, doc: DocId, score: crate::Score, segment_scorer: &mut impl ScoreSegmentTweaker) { + pub fn collect_lazy( + &mut self, + doc: DocId, + score: crate::Score, + segment_scorer: &mut impl ScoreSegmentTweaker, + ) { self.topn_computer.push_lazy(doc, score, segment_scorer); } } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index a951cb10c6..1e7188a75f 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -696,7 +696,7 @@ impl TopDocs { ) -> impl Collector> where TScore: 'static + Clone + Send + Sync + PartialOrd, - TScoreTweaker: ScoreTweaker + Send + Sync, + TScoreTweaker: ScoreTweaker + Send + Sync, { TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) } diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index dc89f12c80..f7825ebed7 100644 --- a/src/collector/tweak_score_top_collector.rs +++ 
b/src/collector/tweak_score_top_collector.rs @@ -77,7 +77,7 @@ pub trait ScoreSegmentTweaker: 'static { } /// Convert a segment level score into the global level score. - fn convert_score(&self, score: Self::SegmentScore) -> Self::Score; + fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score; } /// `ScoreTweaker` makes it possible to tweak the score @@ -142,7 +142,8 @@ where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker type Fruit = Vec<(TSegmentScoreTweaker::Score, DocAddress)>; fn collect(&mut self, doc: DocId, score: Score) { - self.segment_collector.collect_lazy(doc, score, &mut self.segment_scorer); + self.segment_collector + .collect_lazy(doc, score, &mut self.segment_scorer); } fn harvest(self) -> Self::Fruit { @@ -150,7 +151,12 @@ where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker self.segment_collector.harvest(); segment_hits .into_iter() - .map(|(score, doc)| (self.segment_scorer.convert_score(score), doc)) + .map(|(score, doc)| { + ( + self.segment_scorer.convert_segment_score_to_score(score), + doc, + ) + }) .collect() } } @@ -160,7 +166,6 @@ where F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker, TSegmentScoreTweaker: ScoreSegmentTweaker, { - type Score = TSegmentScoreTweaker::Score; type Child = TSegmentScoreTweaker; @@ -169,7 +174,6 @@ where } } - impl ScoreSegmentTweaker for F where F: 'static + FnMut(DocId, Score) -> TScore, @@ -183,7 +187,7 @@ where } /// Convert a segment level score into the global level score. 
- fn convert_score(&self, score: Self::SegmentScore) -> Self::Score { + fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { score } } @@ -193,7 +197,10 @@ where HeadScoreTweaker: ScoreTweaker, TailScoreTweaker: ScoreTweaker, { - type Score = (::Score, ::Score); + type Score = ( + ::Score, + ::Score, + ); type Child = (HeadScoreTweaker::Child, TailScoreTweaker::Child); fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { @@ -250,22 +257,24 @@ where true } - fn convert_score(&self, score: Self::SegmentScore) -> Self::Score { + fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { let (head_score, tail_score) = score; ( - self.0.convert_score(head_score), - self.1.convert_score(tail_score), + self.0.convert_segment_score_to_score(head_score), + self.1.convert_segment_score_to_score(tail_score), ) } } -/// This struct is used as an adapter to take a segment score tweaker and map its score to another new score. +/// This struct is used as an adapter to take a segment score tweaker and map its score to another +/// new score. 
pub struct MappedSegmentScoreTweaker { tweaker: T, map: fn(PreviousScore) -> NewScore, } -impl ScoreSegmentTweaker for MappedSegmentScoreTweaker +impl ScoreSegmentTweaker + for MappedSegmentScoreTweaker where T: ScoreSegmentTweaker, PreviousScore: 'static + Clone + Send + Sync + PartialOrd, @@ -279,70 +288,99 @@ where } fn accept_score_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentScore, - ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { - self.tweaker.accept_score_lazy::(doc_id, score, threshold) + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentScore, + ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { + self.tweaker + .accept_score_lazy::(doc_id, score, threshold) } fn is_lazy() -> bool { T::is_lazy() } - fn convert_score(&self, score: Self::SegmentScore) -> Self::Score { - (self.map)(self.tweaker.convert_score(score)) + fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { + (self.map)(self.tweaker.convert_segment_score_to_score(score)) } } +// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, +// ...) as the chain (a, (b, (c, ...))) -// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, ...) 
as the chain (a, (b, (c, ...))) - -impl ScoreTweaker for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3) +impl ScoreTweaker + for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3) where ScoreTweaker1: ScoreTweaker, ScoreTweaker2: ScoreTweaker, ScoreTweaker3: ScoreTweaker, { - type Child = MappedSegmentScoreTweaker<<(ScoreTweaker1, (ScoreTweaker2, ScoreTweaker3)) as ScoreTweaker>::Child, (ScoreTweaker1::Score, (ScoreTweaker2::Score, ScoreTweaker3::Score)), Self::Score>; - type Score = (ScoreTweaker1::Score, ScoreTweaker2::Score, ScoreTweaker3::Score); + type Child = MappedSegmentScoreTweaker< + <(ScoreTweaker1, (ScoreTweaker2, ScoreTweaker3)) as ScoreTweaker>::Child, + ( + ScoreTweaker1::Score, + (ScoreTweaker2::Score, ScoreTweaker3::Score), + ), + Self::Score, + >; + type Score = ( + ScoreTweaker1::Score, + ScoreTweaker2::Score, + ScoreTweaker3::Score, + ); fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; - Ok( - MappedSegmentScoreTweaker { - tweaker: (score_tweaker1, (score_tweaker2, score_tweaker3)), - map: | (score1, (score2, score3))| (score1, score2, score3), - } - ) - + Ok(MappedSegmentScoreTweaker { + tweaker: (score_tweaker1, (score_tweaker2, score_tweaker3)), + map: |(score1, (score2, score3))| (score1, score2, score3), + }) } } -impl ScoreTweaker for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3, ScoreTweaker4) +impl ScoreTweaker + for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3, ScoreTweaker4) where ScoreTweaker1: ScoreTweaker, ScoreTweaker2: ScoreTweaker, ScoreTweaker3: ScoreTweaker, ScoreTweaker4: ScoreTweaker, { - type Child = MappedSegmentScoreTweaker<<(ScoreTweaker1, (ScoreTweaker2, (ScoreTweaker3, ScoreTweaker4))) as ScoreTweaker>::Child, (ScoreTweaker1::Score, (ScoreTweaker2::Score, (ScoreTweaker3::Score, ScoreTweaker4::Score))), Self::Score>; - type 
Score = (ScoreTweaker1::Score, ScoreTweaker2::Score, ScoreTweaker3::Score, ScoreTweaker4::Score); + type Child = MappedSegmentScoreTweaker< + <( + ScoreTweaker1, + (ScoreTweaker2, (ScoreTweaker3, ScoreTweaker4)), + ) as ScoreTweaker>::Child, + ( + ScoreTweaker1::Score, + ( + ScoreTweaker2::Score, + (ScoreTweaker3::Score, ScoreTweaker4::Score), + ), + ), + Self::Score, + >; + type Score = ( + ScoreTweaker1::Score, + ScoreTweaker2::Score, + ScoreTweaker3::Score, + ScoreTweaker4::Score, + ); fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; let score_tweaker4 = self.3.segment_tweaker(segment_reader)?; - Ok( - MappedSegmentScoreTweaker { - tweaker: (score_tweaker1, (score_tweaker2, (score_tweaker3, score_tweaker4))), - map: | (score1, (score2, (score3, score4)))| (score1, score2, score3, score4), - } - ) - + Ok(MappedSegmentScoreTweaker { + tweaker: ( + score_tweaker1, + (score_tweaker2, (score_tweaker3, score_tweaker4)), + ), + map: |(score1, (score2, (score3, score4)))| (score1, score2, score3, score4), + }) } } From 485ee063fcc8fa51f77e05d2cbb0baced387b8c7 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 12:03:45 +0100 Subject: [PATCH 05/11] refactoring --- src/aggregation/metric/top_hits.rs | 22 +- src/collector/mod.rs | 4 +- src/collector/sort_key_top_collector.rs | 398 +++++++++++++++++++++ src/collector/top_collector.rs | 16 +- src/collector/top_score_collector.rs | 124 +++---- src/collector/tweak_score_top_collector.rs | 386 -------------------- 6 files changed, 481 insertions(+), 469 deletions(-) create mode 100644 src/collector/sort_key_top_collector.rs delete mode 100644 src/collector/tweak_score_top_collector.rs diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs index 8156a1b667..bd3bf4c512 100644 --- 
a/src/aggregation/metric/top_hits.rs +++ b/src/aggregation/metric/top_hits.rs @@ -482,7 +482,7 @@ impl TopHitsTopNComputer { pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> { for doc in other_fruit.top_n.into_vec() { - self.collect(doc.feature, doc.doc); + self.collect(doc.sort_key, doc.doc); } Ok(()) } @@ -494,9 +494,9 @@ impl TopHitsTopNComputer { .into_sorted_vec() .into_iter() .map(|doc| TopHitsVecEntry { - sort: doc.feature.sorts.iter().map(|f| f.value).collect(), + sort: doc.sort_key.sorts.iter().map(|f| f.value).collect(), doc_value_fields: doc - .feature + .sort_key .doc_value_fields .into_iter() .map(|(k, v)| (k, v.into())) @@ -544,7 +544,7 @@ impl TopHitsSegmentCollector { let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id); top_hits_computer.collect( DocSortValuesAndFields { - sorts: res.feature, + sorts: res.sort_key, doc_value_fields, }, res.doc, @@ -779,7 +779,7 @@ mod tests { segment_ord: 0, doc_id: 0, }, - feature: DocSortValuesAndFields { + sort_key: DocSortValuesAndFields { sorts: vec![DocValueAndOrder { value: Some(1), order: Order::Asc, @@ -792,7 +792,7 @@ mod tests { segment_ord: 0, doc_id: 2, }, - feature: DocSortValuesAndFields { + sort_key: DocSortValuesAndFields { sorts: vec![DocValueAndOrder { value: Some(3), order: Order::Asc, @@ -805,7 +805,7 @@ mod tests { segment_ord: 0, doc_id: 1, }, - feature: DocSortValuesAndFields { + sort_key: DocSortValuesAndFields { sorts: vec![DocValueAndOrder { value: Some(5), order: Order::Asc, @@ -817,7 +817,7 @@ mod tests { let mut collector = collector_with_capacity(3); for doc in docs.clone() { - collector.collect(doc.feature, doc.doc); + collector.collect(doc.sort_key, doc.doc); } let res = collector.into_final_result(); @@ -827,15 +827,15 @@ mod tests { super::TopHitsMetricResult { hits: vec![ super::TopHitsVecEntry { - sort: vec![docs[0].feature.sorts[0].value], + sort: vec![docs[0].sort_key.sorts[0].value], doc_value_fields: 
Default::default(), }, super::TopHitsVecEntry { - sort: vec![docs[1].feature.sorts[0].value], + sort: vec![docs[1].sort_key.sorts[0].value], doc_value_fields: Default::default(), }, super::TopHitsVecEntry { - sort: vec![docs[2].feature.sorts[0].value], + sort: vec![docs[2].sort_key.sorts[0].value], doc_value_fields: Default::default(), }, ] diff --git a/src/collector/mod.rs b/src/collector/mod.rs index a31754316e..da336ec3b0 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -103,8 +103,8 @@ pub use self::top_score_collector::{TopDocs, TopNComputer}; mod custom_score_top_collector; pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer}; -mod tweak_score_top_collector; -pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker}; +mod sort_key_top_collector; +pub use self::sort_key_top_collector::{SegmentSortKeyComputer, SortKeyComputer}; mod facet_collector; pub use self::facet_collector::{FacetCollector, FacetCounts}; use crate::query::Weight; diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs new file mode 100644 index 0000000000..4064b4f009 --- /dev/null +++ b/src/collector/sort_key_top_collector.rs @@ -0,0 +1,398 @@ +use std::cmp::Ordering; + +use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; +use crate::collector::{Collector, SegmentCollector}; +use crate::{DocAddress, DocId, Result, Score, SegmentReader}; + +pub(crate) struct TopBySortKeyCollector { + sort_key_computer: TSortKeyComputer, + collector: TopCollector, +} + +impl TopBySortKeyCollector +where TSortKey: Clone + PartialOrd +{ + pub fn new( + sort_key_computer: TSortKeyComputer, + collector: TopCollector, + ) -> TopBySortKeyCollector { + TopBySortKeyCollector { + sort_key_computer, + collector, + } + } +} + +/// A `SegmentSortKeyComputer` makes it possible to modify the default score +/// for a given document belonging to a specific segment. 
+///
+/// It is the segment local version of the [`SortKeyComputer`].
+pub trait SegmentSortKeyComputer: 'static {
+    /// The final sort key being emitted.
+    type SortKey: 'static + PartialOrd + Send + Sync + Clone;
+
+    /// Sort key used at the segment level by the `SegmentSortKeyComputer`.
+    ///
+    /// It is typically small like a `u64`, and is meant to be converted
+    /// to the final sort key at the end of the collection of the segment.
+    type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone;
+
+    /// Computes the sort key for the given document and score.
+    fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey;
+
+    /// Returns true if the `SegmentSortKeyComputer` is a good candidate for the lazy evaluation
+    /// optimization. See [`SegmentSortKeyComputer::accept_sort_key_lazy`].
+    fn is_lazy() -> bool {
+        false
+    }
+
+    /// Implementing this method makes it possible to avoid computing
+    /// a sort_key entirely if we can assess that it won't pass a threshold
+    /// with a partial computation.
+    ///
+    /// This is currently used for lexicographic sorting.
+    ///
+    /// If REVERSE_ORDER is false (resp. true),
+    /// - we return None if the score is below the threshold (resp. above the threshold)
+    /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp.
+    ///   below or equal to)
+    fn accept_sort_key_lazy(
+        &mut self,
+        doc_id: DocId,
+        score: Score,
+        threshold: &Self::SegmentSortKey,
+    ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> {
+        let excluded_ordering = if REVERSE_ORDER {
+            Ordering::Greater
+        } else {
+            Ordering::Less
+        };
+        let sort_key = self.sort_key(doc_id, score);
+        let cmp = sort_key.partial_cmp(threshold).unwrap_or(excluded_ordering);
+        if cmp == excluded_ordering {
+            return None;
+        } else {
+            return Some((cmp, sort_key));
+        }
+    }
+
+    /// Convert a segment level sort key into the global sort key.
+ fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey; +} + +/// `SortKeyComputer` defines the sort key to be used by a TopK Collector. +/// +/// The `SortKeyComputer` itself does not make much of the computation itself. +/// Instead, it helps constructing `Self::Child` instances that will compute +/// the sort key at a segment scale. +pub trait SortKeyComputer: Sync { + /// The sort key type. + type SortKey: 'static + Send + Sync + PartialOrd + Clone; + /// Type of the associated [`SegmentSortKeyComputer`]. + type Child: SegmentSortKeyComputer; + + /// Builds a child sort key computer for a specific segment. + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result; +} + +impl Collector for TopBySortKeyCollector +where + TSortKeyComputer: SortKeyComputer + Send + Sync, + TSortKey: 'static + Send + PartialOrd + Sync + Clone, +{ + type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>; + + type Child = TopBySortKeySegmentCollector; + + fn for_segment( + &self, + segment_local_id: u32, + segment_reader: &SegmentReader, + ) -> Result { + let segment_sort_key_computer = self + .sort_key_computer + .segment_sort_key_computer(segment_reader)?; + let segment_collector = self.collector.for_segment(segment_local_id, segment_reader); + Ok(TopBySortKeySegmentCollector { + segment_collector, + segment_sort_key_computer, + }) + } + + fn requires_scoring(&self) -> bool { + true + } + + fn merge_fruits(&self, segment_fruits: Vec) -> Result { + self.collector.merge_fruits(segment_fruits) + } +} + +pub struct TopBySortKeySegmentCollector +where TSegmentSortKeyComputer: SegmentSortKeyComputer +{ + segment_collector: TopSegmentCollector, + segment_sort_key_computer: TSegmentSortKeyComputer, +} + +impl SegmentCollector + for TopBySortKeySegmentCollector +where TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer +{ + type Fruit = Vec<(TSegmentSortKeyComputer::SortKey, DocAddress)>; + + fn collect(&mut self, doc: DocId, 
score: Score) { + self.segment_collector + .collect_lazy(doc, score, &mut self.segment_sort_key_computer); + } + + fn harvest(self) -> Self::Fruit { + let segment_hits: Vec<(TSegmentSortKeyComputer::SegmentSortKey, DocAddress)> = + self.segment_collector.harvest(); + segment_hits + .into_iter() + .map(|(sort_key, doc)| { + ( + self.segment_sort_key_computer + .convert_segment_sort_key(sort_key), + doc, + ) + }) + .collect() + } +} + +impl SortKeyComputer for F +where + F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentSortKeyComputer, + TSegmentSortKeyComputer: SegmentSortKeyComputer, +{ + type SortKey = TSegmentSortKeyComputer::SortKey; + type Child = TSegmentSortKeyComputer; + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok((self)(segment_reader)) + } +} + +impl SegmentSortKeyComputer for F +where + F: 'static + FnMut(DocId, Score) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; + + fn sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { + (self)(doc, score) + } + + /// Convert a segment level score into the global level score. 
+ fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key + } +} + +impl SortKeyComputer + for (HeadSortKeyComputer, TailSortKeyComputer) +where + HeadSortKeyComputer: SortKeyComputer, + TailSortKeyComputer: SortKeyComputer, +{ + type SortKey = ( + ::SortKey, + ::SortKey, + ); + type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok(( + self.0.segment_sort_key_computer(segment_reader)?, + self.1.segment_sort_key_computer(segment_reader)?, + )) + } +} + +impl SegmentSortKeyComputer + for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) +where + HeadSegmentSortKeyComputer: SegmentSortKeyComputer, + TailSegmentSortKeyComputer: SegmentSortKeyComputer, +{ + type SortKey = ( + HeadSegmentSortKeyComputer::SortKey, + TailSegmentSortKeyComputer::SortKey, + ); + type SegmentSortKey = ( + HeadSegmentSortKeyComputer::SegmentSortKey, + TailSegmentSortKeyComputer::SegmentSortKey, + ); + + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + let head_sort_key = self.0.sort_key(doc, score); + let tail_sort_key = self.1.sort_key(doc, score); + (head_sort_key, tail_sort_key) + } + + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentSortKey, + ) -> Option<(Ordering, Self::SegmentSortKey)> { + let (head_threshold, tail_threshold) = threshold; + let (head_cmp, head_sort_key) = + self.0 + .accept_sort_key_lazy::(doc_id, score, head_threshold)?; + if head_cmp == Ordering::Equal { + let (tail_cmp, tail_sort_key) = + self.1 + .accept_sort_key_lazy::(doc_id, score, tail_threshold)?; + Some((tail_cmp, (head_sort_key, tail_sort_key))) + } else { + let tail_sort_key = self.1.sort_key(doc_id, score); + Some((head_cmp, (head_sort_key, tail_sort_key))) + } + } + + fn is_lazy() -> bool { + true + } + + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> 
Self::SortKey { + let (head_sort_key, tail_sort_key) = sort_key; + ( + self.0.convert_segment_sort_key(head_sort_key), + self.1.convert_segment_sort_key(tail_sort_key), + ) + } +} + +/// This struct is used as an adapter to take a sort key computer and map its score to another +/// new sort key. +pub struct MappedSegmentSortKeyComputer { + sort_key_computer: T, + map: fn(PreviousSortKey) -> NewSortKey, +} + +impl SegmentSortKeyComputer + for MappedSegmentSortKeyComputer +where + T: SegmentSortKeyComputer, + PreviousScore: 'static + Clone + Send + Sync + PartialOrd, + NewScore: 'static + Clone + Send + Sync + PartialOrd, +{ + type SortKey = NewScore; + type SegmentSortKey = T::SegmentSortKey; + + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + self.sort_key_computer.sort_key(doc, score) + } + + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentSortKey, + ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { + self.sort_key_computer + .accept_sort_key_lazy::(doc_id, score, threshold) + } + + fn is_lazy() -> bool { + T::is_lazy() + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey { + (self.map)( + self.sort_key_computer + .convert_segment_sort_key(segment_sort_key), + ) + } +} + +// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, +// ...) 
as the chain (a, (b, (c, ...))) + +impl SortKeyComputer + for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, +{ + type Child = MappedSegmentSortKeyComputer< + <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), + ), + Self::SortKey, + >; + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + ); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), + map: |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3), + }) + } +} + +impl SortKeyComputer + for ( + SortKeyComputer1, + SortKeyComputer2, + SortKeyComputer3, + SortKeyComputer4, + ) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, + SortKeyComputer4: SortKeyComputer, +{ + type Child = MappedSegmentSortKeyComputer< + <( + SortKeyComputer1, + (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), + ) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + ( + SortKeyComputer2::SortKey, + (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), + ), + ), + Self::SortKey, + >; + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + SortKeyComputer4::SortKey, + ); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = 
self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: ( + sort_key_computer1, + (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), + ), + map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { + (sort_key1, sort_key2, sort_key3, sort_key4) + }, + }) + } +} diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 1fcfe5d204..6d939dd2fb 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -4,7 +4,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; use super::top_score_collector::TopNComputer; -use crate::collector::ScoreSegmentTweaker; +use crate::collector::SegmentSortKeyComputer; use crate::index::SegmentReader; use crate::{DocAddress, DocId, SegmentOrdinal}; @@ -20,7 +20,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal}; pub struct ComparableDoc { /// The feature of the document. In practice, this is /// is any type that implements `PartialOrd`. - pub feature: T, + pub sort_key: T, /// The document address. In practice, this is any /// type that implements `PartialOrd`, and is guaranteed /// to be unique for each document. 
@@ -31,7 +31,7 @@ impl std::fmt::Debug { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str()) - .field("feature", &self.feature) + .field("feature", &self.sort_key) .field("doc", &self.doc) .finish() } @@ -47,8 +47,8 @@ impl Ord for ComparableDoc #[inline] fn cmp(&self, other: &Self) -> Ordering { let by_feature = self - .feature - .partial_cmp(&other.feature) + .sort_key + .partial_cmp(&other.sort_key) .map(|ord| if R { ord.reverse() } else { ord }) .unwrap_or(Ordering::Equal); @@ -118,7 +118,7 @@ where T: PartialOrd + Clone .into_sorted_vec() .into_iter() .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) + .map(|cdoc| (cdoc.sort_key, cdoc.doc)) .collect()) } @@ -174,7 +174,7 @@ impl TopSegmentCollector { .into_iter() .map(|comparable_doc| { ( - comparable_doc.feature, + comparable_doc.sort_key, DocAddress { segment_ord, doc_id: comparable_doc.doc, @@ -198,7 +198,7 @@ impl TopSegmentCollector { &mut self, doc: DocId, score: crate::Score, - segment_scorer: &mut impl ScoreSegmentTweaker, + segment_scorer: &mut impl SegmentSortKeyComputer, ) { self.topn_computer.push_lazy(doc, score, segment_scorer); } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 1e7188a75f..defa418896 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -9,10 +9,10 @@ use super::Collector; use crate::collector::custom_score_top_collector::{ CustomScoreTopCollector, CustomScoreTopSegmentCollector, }; +use crate::collector::sort_key_top_collector::TopBySortKeyCollector; use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; -use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::{ - CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, + CustomScorer, CustomSegmentScorer, SegmentCollector, SegmentSortKeyComputer, 
SortKeyComputer, }; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; use crate::query::Weight; @@ -155,7 +155,7 @@ impl Collector for StringConvertCollector { .into_sorted_vec() .into_iter() .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) + .map(|cdoc| (cdoc.sort_key, cdoc.doc)) .collect()) } else { let mut top_collector: TopNComputer<_, _, false> = @@ -170,7 +170,7 @@ impl Collector for StringConvertCollector { .into_sorted_vec() .into_iter() .skip(self.offset) - .map(|cdoc| (cdoc.feature, cdoc.doc)) + .map(|cdoc| (cdoc.sort_key, cdoc.doc)) .collect()) } } @@ -598,7 +598,7 @@ impl TopDocs { /// /// This method offers a convenient way to tweak or replace /// the documents score. As suggested by the prototype you can - /// manually define your own [`ScoreTweaker`] + /// manually define your own [`SortKeyComputer`] /// and pass it as an argument, but there is a much simpler way to /// tweak your score: you can use a closure as in the following /// example. @@ -690,15 +690,15 @@ impl TopDocs { /// /// # See also /// - [custom_score(...)](TopDocs::custom_score) - pub fn tweak_score( + pub fn by_sort_key( self, - score_tweaker: TScoreTweaker, - ) -> impl Collector> + sort_key_computer: impl SortKeyComputer + Send + Sync, + ) -> impl Collector> where - TScore: 'static + Clone + Send + Sync + PartialOrd, - TScoreTweaker: ScoreTweaker + Send + Sync, + TSortKey: 'static + Clone + Send + Sync + PartialOrd, + TSortKeyComputer: SortKeyComputer + Send + Sync, { - TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) + TopBySortKeyCollector::new(sort_key_computer, self.0.into_tscore()) } /// Ranks the documents using a custom score. 
@@ -872,7 +872,7 @@ impl Collector for TopDocs { .into_iter() .map(|cid| { ( - cid.feature, + cid.sort_key, DocAddress { segment_ord, doc_id: cid.doc, @@ -969,9 +969,9 @@ impl From> for TopNCompu } } -impl TopNComputer +impl TopNComputer where - Score: PartialOrd + Clone, + TSortKey: PartialOrd + Clone, D: Ord, { /// Create a new `TopNComputer`. @@ -988,12 +988,12 @@ where /// Push a new document to the top n. /// If the document is below the current threshold, it will be ignored. #[inline] - pub fn push(&mut self, feature: Score, doc: D) { + pub fn push(&mut self, sort_key: TSortKey, doc: D) { if let Some(last_median) = self.threshold.clone() { - if !REVERSE_ORDER && feature > last_median { + if !REVERSE_ORDER && sort_key > last_median { return; } - if REVERSE_ORDER && feature < last_median { + if REVERSE_ORDER && sort_key < last_median { return; } } @@ -1004,21 +1004,16 @@ where // This cannot panic, because we truncate_median will at least remove one element, since // the min capacity is 2. - self.append_within_capacity(ComparableDoc { doc, feature }); - } - - // ONLY CALL THIS FUNCTION WHEN YOU KNOW THE BUFFER HAS ENOUGH CAPACITY. - #[inline(always)] - fn append_within_capacity(&mut self, comparable_doc: ComparableDoc) { - push_within_capacity(comparable_doc, &mut self.buffer); + let comparable_doc = ComparableDoc { doc, sort_key }; + push_assuming_capacity(comparable_doc, &mut self.buffer); } #[inline(never)] - fn truncate_top_n(&mut self) -> Score { + fn truncate_top_n(&mut self) -> TSortKey { // Use select_nth_unstable to find the top nth score let (_, median_el, _) = self.buffer.select_nth_unstable(self.top_n); - let median_score = median_el.feature.clone(); + let median_score = median_el.sort_key.clone(); // Remove all elements below the top_n self.buffer.truncate(self.top_n); @@ -1026,7 +1021,7 @@ where } /// Returns the top n elements in sorted order. 
- pub fn into_sorted_vec(mut self) -> Vec> { + pub fn into_sorted_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -1037,7 +1032,7 @@ where /// Returns the top n elements in stored order. /// Useful if you do not need the elements in sorted order, /// for example when merging the results of multiple segments. - pub fn into_vec(mut self) -> Vec> { + pub fn into_vec(mut self) -> Vec> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } @@ -1049,39 +1044,44 @@ impl TopNComputer>( + pub(crate) fn push_lazy< + TSegmentSortKeyComputer: SegmentSortKeyComputer, + >( &mut self, doc: DocId, score: Score, - score_tweaker: &mut TScoreSegmentTweaker, + score_tweaker: &mut TSegmentSortKeyComputer, ) { - if !TScoreSegmentTweaker::is_lazy() { - let feature = score_tweaker.score(doc, score); - self.push(feature, doc); - return; - } - - if let Some(threshold) = self.threshold.as_ref() { - if let Some((_cmp, feature)) = - score_tweaker.accept_score_lazy::(doc, score, threshold) - { - self.append_within_capacity(ComparableDoc { feature, doc }); + if TSegmentSortKeyComputer::is_lazy() { + if let Some(threshold) = self.threshold.as_ref() { + let Some((_cmp, feature)) = + score_tweaker.accept_sort_key_lazy::(doc, score, threshold) + else { + return; + }; + push_assuming_capacity( + ComparableDoc { + sort_key: feature, + doc, + }, + &mut self.buffer, + ); + return; } - } else { - let feature = score_tweaker.score(doc, score); - self.append_within_capacity(ComparableDoc { feature, doc }); } + let feature = score_tweaker.sort_key(doc, score); + self.push(feature, doc); + return; } } // Push an element provided there is enough capacity to do so. -// TODO replace me when push_within_capacity is stabilized. +// +// Panics if there is not enough capacity to add an element. 
#[inline(always)]
-fn push_within_capacity(el: T, buf: &mut Vec) {
+fn push_assuming_capacity(el: T, buf: &mut Vec) {
     let prev_len = buf.len();
-    if prev_len == buf.capacity() {
-        return;
-    }
+    assert!(prev_len < buf.capacity());
     // This is mimicking the current (non-stabilized) implementation in std.
     // SAFETY: we just checked we have enough capacity.
     unsafe {
@@ -1141,7 +1141,7 @@ mod tests {
         assert_eq!(
             computer.into_sorted_vec(),
             &[ComparableDoc {
-                feature: 1u32,
+                sort_key: 1u32,
                 doc: 0u32,
             },]
         );
@@ -1169,11 +1169,11 @@ mod tests {
             computer.into_sorted_vec(),
             &[
                 ComparableDoc {
-                    feature: 3u32,
+                    sort_key: 3u32,
                     doc: 3u32,
                 },
                 ComparableDoc {
-                    feature: 2u32,
+                    sort_key: 2u32,
                     doc: 2u32,
                 }
             ]
@@ -1202,7 +1202,7 @@ mod tests {
             for (feature, doc) in &docs {
                 computer.push(*feature, *doc);
             }
-            let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::>();
+            let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { sort_key: feature, doc }).collect::>();
             comparable_docs.sort();
             comparable_docs.truncate(limit);
             prop_assert_eq!(
@@ -1220,7 +1220,7 @@ mod tests {
             for (feature, doc) in &docs {
                 computer.push(*feature, *doc);
             }
-            let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc }).collect::>();
+            let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { sort_key: feature, doc }).collect::>();
             comparable_docs.sort();
             comparable_docs.truncate(limit);
             prop_assert_eq!(
@@ -1835,14 +1835,14 @@ mod tests {
         // offset, and then taking the limit. 
let sorted_docs: Vec<_> = if order.is_desc() {
                let mut comparable_docs: Vec> =
-                    all_results.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc}).collect();
+                    all_results.into_iter().map(|(feature, doc)| ComparableDoc { sort_key: feature, doc}).collect();
                 comparable_docs.sort();
-                comparable_docs.into_iter().map(|cd| (cd.feature, cd.doc)).collect()
+                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
             } else {
                 let mut comparable_docs: Vec> =
-                    all_results.into_iter().map(|(feature, doc)| ComparableDoc { feature, doc}).collect();
+                    all_results.into_iter().map(|(feature, doc)| ComparableDoc { sort_key: feature, doc}).collect();
                 comparable_docs.sort();
-                comparable_docs.into_iter().map(|cd| (cd.feature, cd.doc)).collect()
+                comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
             };
             let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::>();
             prop_assert_eq!(
@@ -1912,12 +1912,12 @@ mod tests {
     }
 
     #[test]
-    fn test_tweak_score_top_collector_with_offset() -> crate::Result<()> {
+    fn test_sort_key_top_collector_with_offset() -> crate::Result<()> {
         let index = make_index()?;
         let field = index.schema().get_field("text").unwrap();
         let query_parser = QueryParser::for_index(&index, vec![field]);
         let text_query = query_parser.parse_query("droopy tax")?;
-        let collector = TopDocs::with_limit(2).and_offset(1).tweak_score(
+        let collector = TopDocs::with_limit(2).and_offset(1).by_sort_key(
             move |_segment_reader: &SegmentReader| move |doc: DocId, _original_score: Score| doc,
         );
         let score_docs: Vec<(u32, DocAddress)> =
@@ -2026,11 +2026,11 @@ mod tests {
             computer.into_sorted_vec(),
             &[
                 ComparableDoc {
-                    feature: 1u32,
+                    sort_key: 1u32,
                     doc: 1u32,
                 },
                 ComparableDoc {
-                    feature: 1u32,
+                    sort_key: 1u32,
                     doc: 6u32,
                 }
             ]
diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs
deleted file mode 100644
index f7825ebed7..0000000000
--- a/src/collector/tweak_score_top_collector.rs
+++ /dev/null
@@ -1,386 
+0,0 @@ -use std::cmp::Ordering; - -use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; -use crate::collector::{Collector, SegmentCollector}; -use crate::{DocAddress, DocId, Result, Score, SegmentReader}; - -pub(crate) struct TweakedScoreTopCollector { - score_tweaker: TScoreTweaker, - collector: TopCollector, -} - -impl TweakedScoreTopCollector -where TScore: Clone + PartialOrd -{ - pub fn new( - score_tweaker: TScoreTweaker, - collector: TopCollector, - ) -> TweakedScoreTopCollector { - TweakedScoreTopCollector { - score_tweaker, - collector, - } - } -} - -/// A `ScoreSegmentTweaker` makes it possible to modify the default score -/// for a given document belonging to a specific segment. -/// -/// It is the segment local version of the [`ScoreTweaker`]. -pub trait ScoreSegmentTweaker: 'static { - /// The final score being emitted. - type Score: 'static + PartialOrd + Send + Sync + Clone; - - /// Score used by at the segment level by the `ScoreSegmentTweaker`. - /// - /// It is typically small like a `u64`, and is meant to be converted - /// to the final score at the end of the collection of the segment. - type SegmentScore: 'static + PartialOrd + Clone + Send + Sync + Clone; - - /// Tweak the given `score` for the document `doc`. - fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore; - - /// Returns true if the `ScoreSegmentTweaker` is a good candidate for the lazy evaluation - /// optimization. See [`ScoreSegmentTweaker::accept_score_lazy`]. - fn is_lazy() -> bool { - false - } - - /// Implementing this method makes it possible to avoid computing - /// a score entirely if we can assess that it won't pass a threshold - /// with a partial computation. - /// - /// This is currently used for lexicographic sorting. - /// - /// If REVERSE_ORDER is false (resp. true), - /// - we return None if the score is below the threshold (resp. 
above to the threshold) - /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. - /// below or equal to) - fn accept_score_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentScore, - ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { - let excluded_ordering = if REVERSE_ORDER { - Ordering::Greater - } else { - Ordering::Less - }; - let score = self.score(doc_id, score); - let cmp = score.partial_cmp(threshold).unwrap_or(excluded_ordering); - if cmp == excluded_ordering { - return None; - } else { - return Some((cmp, score)); - } - } - - /// Convert a segment level score into the global level score. - fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score; -} - -/// `ScoreTweaker` makes it possible to tweak the score -/// emitted by the scorer into another one. -/// -/// The `ScoreTweaker` itself does not make much of the computation itself. -/// Instead, it helps constructing `Self::Child` instances that will compute -/// the score at a segment scale. -pub trait ScoreTweaker: Sync { - /// The actual score emitted by the Tweaker. - type Score: 'static + Send + Sync + PartialOrd + Clone; - /// Type of the associated [`ScoreSegmentTweaker`]. - type Child: ScoreSegmentTweaker; - - /// Builds a child tweaker for a specific segment. The child scorer is associated with - /// a specific segment. 
- fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result; -} - -impl Collector for TweakedScoreTopCollector -where - TScoreTweaker: ScoreTweaker + Send + Sync, - TScore: 'static + Send + PartialOrd + Sync + Clone, -{ - type Fruit = Vec<(TScoreTweaker::Score, DocAddress)>; - - type Child = TopTweakedScoreSegmentCollector; - - fn for_segment( - &self, - segment_local_id: u32, - segment_reader: &SegmentReader, - ) -> Result { - let segment_scorer = self.score_tweaker.segment_tweaker(segment_reader)?; - let segment_collector = self.collector.for_segment(segment_local_id, segment_reader); - Ok(TopTweakedScoreSegmentCollector { - segment_collector, - segment_scorer, - }) - } - - fn requires_scoring(&self) -> bool { - true - } - - fn merge_fruits(&self, segment_fruits: Vec) -> Result { - self.collector.merge_fruits(segment_fruits) - } -} - -pub struct TopTweakedScoreSegmentCollector -where TSegmentScoreTweaker: ScoreSegmentTweaker -{ - segment_collector: TopSegmentCollector, - segment_scorer: TSegmentScoreTweaker, -} - -impl SegmentCollector - for TopTweakedScoreSegmentCollector -where TSegmentScoreTweaker: 'static + ScoreSegmentTweaker -{ - type Fruit = Vec<(TSegmentScoreTweaker::Score, DocAddress)>; - - fn collect(&mut self, doc: DocId, score: Score) { - self.segment_collector - .collect_lazy(doc, score, &mut self.segment_scorer); - } - - fn harvest(self) -> Self::Fruit { - let segment_hits: Vec<(TSegmentScoreTweaker::SegmentScore, DocAddress)> = - self.segment_collector.harvest(); - segment_hits - .into_iter() - .map(|(score, doc)| { - ( - self.segment_scorer.convert_segment_score_to_score(score), - doc, - ) - }) - .collect() - } -} - -impl ScoreTweaker for F -where - F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentScoreTweaker, - TSegmentScoreTweaker: ScoreSegmentTweaker, -{ - type Score = TSegmentScoreTweaker::Score; - type Child = TSegmentScoreTweaker; - - fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { - 
Ok((self)(segment_reader)) - } -} - -impl ScoreSegmentTweaker for F -where - F: 'static + FnMut(DocId, Score) -> TScore, - TScore: 'static + PartialOrd + Clone + Send + Sync, -{ - type Score = TScore; - type SegmentScore = TScore; - - fn score(&mut self, doc: DocId, score: Score) -> TScore { - (self)(doc, score) - } - - /// Convert a segment level score into the global level score. - fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { - score - } -} - -impl ScoreTweaker for (HeadScoreTweaker, TailScoreTweaker) -where - HeadScoreTweaker: ScoreTweaker, - TailScoreTweaker: ScoreTweaker, -{ - type Score = ( - ::Score, - ::Score, - ); - type Child = (HeadScoreTweaker::Child, TailScoreTweaker::Child); - - fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { - Ok(( - self.0.segment_tweaker(segment_reader)?, - self.1.segment_tweaker(segment_reader)?, - )) - } -} - -impl ScoreSegmentTweaker - for (HeadScoreSegmentTweaker, TailScoreSegmentTweaker) -where - HeadScoreSegmentTweaker: ScoreSegmentTweaker, - TailScoreSegmentTweaker: ScoreSegmentTweaker, -{ - type Score = ( - HeadScoreSegmentTweaker::Score, - TailScoreSegmentTweaker::Score, - ); - type SegmentScore = ( - HeadScoreSegmentTweaker::SegmentScore, - TailScoreSegmentTweaker::SegmentScore, - ); - - fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore { - let head_score = self.0.score(doc, score); - let tail_score = self.1.score(doc, score); - (head_score, tail_score) - } - - fn accept_score_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentScore, - ) -> Option<(Ordering, Self::SegmentScore)> { - let (head_threshold, tail_threshold) = threshold; - let (head_cmp, head_score) = - self.0 - .accept_score_lazy::(doc_id, score, head_threshold)?; - if head_cmp == Ordering::Equal { - let (tail_cmp, tail_score) = - self.1 - .accept_score_lazy::(doc_id, score, tail_threshold)?; - Some((tail_cmp, (head_score, tail_score))) - } else { - let 
tail_score = self.1.score(doc_id, score); - Some((head_cmp, (head_score, tail_score))) - } - } - - fn is_lazy() -> bool { - true - } - - fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { - let (head_score, tail_score) = score; - ( - self.0.convert_segment_score_to_score(head_score), - self.1.convert_segment_score_to_score(tail_score), - ) - } -} - -/// This struct is used as an adapter to take a segment score tweaker and map its score to another -/// new score. -pub struct MappedSegmentScoreTweaker { - tweaker: T, - map: fn(PreviousScore) -> NewScore, -} - -impl ScoreSegmentTweaker - for MappedSegmentScoreTweaker -where - T: ScoreSegmentTweaker, - PreviousScore: 'static + Clone + Send + Sync + PartialOrd, - NewScore: 'static + Clone + Send + Sync + PartialOrd, -{ - type Score = NewScore; - type SegmentScore = T::SegmentScore; - - fn score(&mut self, doc: DocId, score: Score) -> Self::SegmentScore { - self.tweaker.score(doc, score) - } - - fn accept_score_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentScore, - ) -> Option<(std::cmp::Ordering, Self::SegmentScore)> { - self.tweaker - .accept_score_lazy::(doc_id, score, threshold) - } - - fn is_lazy() -> bool { - T::is_lazy() - } - - fn convert_segment_score_to_score(&self, score: Self::SegmentScore) -> Self::Score { - (self.map)(self.tweaker.convert_segment_score_to_score(score)) - } -} - -// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, -// ...) 
as the chain (a, (b, (c, ...))) - -impl ScoreTweaker - for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3) -where - ScoreTweaker1: ScoreTweaker, - ScoreTweaker2: ScoreTweaker, - ScoreTweaker3: ScoreTweaker, -{ - type Child = MappedSegmentScoreTweaker< - <(ScoreTweaker1, (ScoreTweaker2, ScoreTweaker3)) as ScoreTweaker>::Child, - ( - ScoreTweaker1::Score, - (ScoreTweaker2::Score, ScoreTweaker3::Score), - ), - Self::Score, - >; - type Score = ( - ScoreTweaker1::Score, - ScoreTweaker2::Score, - ScoreTweaker3::Score, - ); - - fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { - let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; - let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; - let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; - Ok(MappedSegmentScoreTweaker { - tweaker: (score_tweaker1, (score_tweaker2, score_tweaker3)), - map: |(score1, (score2, score3))| (score1, score2, score3), - }) - } -} - -impl ScoreTweaker - for (ScoreTweaker1, ScoreTweaker2, ScoreTweaker3, ScoreTweaker4) -where - ScoreTweaker1: ScoreTweaker, - ScoreTweaker2: ScoreTweaker, - ScoreTweaker3: ScoreTweaker, - ScoreTweaker4: ScoreTweaker, -{ - type Child = MappedSegmentScoreTweaker< - <( - ScoreTweaker1, - (ScoreTweaker2, (ScoreTweaker3, ScoreTweaker4)), - ) as ScoreTweaker>::Child, - ( - ScoreTweaker1::Score, - ( - ScoreTweaker2::Score, - (ScoreTweaker3::Score, ScoreTweaker4::Score), - ), - ), - Self::Score, - >; - type Score = ( - ScoreTweaker1::Score, - ScoreTweaker2::Score, - ScoreTweaker3::Score, - ScoreTweaker4::Score, - ); - - fn segment_tweaker(&self, segment_reader: &SegmentReader) -> Result { - let score_tweaker1 = self.0.segment_tweaker(segment_reader)?; - let score_tweaker2 = self.1.segment_tweaker(segment_reader)?; - let score_tweaker3 = self.2.segment_tweaker(segment_reader)?; - let score_tweaker4 = self.3.segment_tweaker(segment_reader)?; - Ok(MappedSegmentScoreTweaker { - tweaker: ( - score_tweaker1, - (score_tweaker2, 
(score_tweaker3, score_tweaker4)), - ), - map: |(score1, (score2, (score3, score4)))| (score1, score2, score3, score4), - }) - } -} From 1a677e6eae79447ea512c4103be2cf437b321560 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 12:07:53 +0100 Subject: [PATCH 06/11] refactoring --- examples/faceted_search_with_tweaked_score.rs | 2 +- src/collector/top_score_collector.rs | 120 +----------------- 2 files changed, 4 insertions(+), 118 deletions(-) diff --git a/examples/faceted_search_with_tweaked_score.rs b/examples/faceted_search_with_tweaked_score.rs index d21a1c3d4c..7200c9bd54 100644 --- a/examples/faceted_search_with_tweaked_score.rs +++ b/examples/faceted_search_with_tweaked_score.rs @@ -65,7 +65,7 @@ fn main() -> tantivy::Result<()> { ); let top_docs_by_custom_score = // Call TopDocs with a custom tweak score - TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| { + TopDocs::with_limit(2).order_by(move |segment_reader: &SegmentReader| { let ingredient_reader = segment_reader.facet_reader("ingredient").unwrap(); let facet_dict = ingredient_reader.facet_dict(); diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index defa418896..2ae3edcdaa 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -594,7 +594,7 @@ impl TopDocs { } } - /// Ranks the documents using a custom score. + /// Ranks the documents using a sort key. /// /// This method offers a convenient way to tweak or replace /// the documents score. 
As suggested by the prototype you can @@ -690,129 +690,15 @@ impl TopDocs { /// /// # See also /// - [custom_score(...)](TopDocs::custom_score) - pub fn by_sort_key( + pub fn order_by( self, sort_key_computer: impl SortKeyComputer + Send + Sync, ) -> impl Collector> where TSortKey: 'static + Clone + Send + Sync + PartialOrd, - TSortKeyComputer: SortKeyComputer + Send + Sync, { TopBySortKeyCollector::new(sort_key_computer, self.0.into_tscore()) } - - /// Ranks the documents using a custom score. - /// - /// This method offers a convenient way to use a different score. - /// - /// As suggested by the prototype you can manually define your own [`CustomScorer`] - /// and pass it as an argument, but there is a much simpler way to - /// tweak your score: you can use a closure as in the following - /// example. - /// - /// # Limitation - /// - /// This method only makes it possible to compute the score from a given - /// `DocId`, fastfield values for the doc and any information you could - /// have precomputed beforehand. It does not make it possible for instance - /// to compute something like TfIdf as it does not have access to the list of query - /// terms present in the document, nor the term frequencies for the different terms. - /// - /// It can be used if your search engine relies on a learning-to-rank model for instance, - /// which does not rely on the term frequencies or positions as features. 
- /// - /// # Example - /// - /// ```rust - /// # use tantivy::schema::{Schema, FAST, TEXT}; - /// # use tantivy::{doc, Index, DocAddress, DocId}; - /// # use tantivy::query::QueryParser; - /// use tantivy::SegmentReader; - /// use tantivy::collector::TopDocs; - /// use tantivy::schema::Field; - /// - /// # fn create_schema() -> Schema { - /// # let mut schema_builder = Schema::builder(); - /// # schema_builder.add_text_field("product_name", TEXT); - /// # schema_builder.add_u64_field("popularity", FAST); - /// # schema_builder.add_u64_field("boosted", FAST); - /// # schema_builder.build() - /// # } - /// # - /// # fn main() -> tantivy::Result<()> { - /// # let schema = create_schema(); - /// # let index = Index::create_in_ram(schema); - /// # let mut index_writer = index.writer_with_num_threads(1, 20_000_000)?; - /// # let product_name = index.schema().get_field("product_name").unwrap(); - /// # - /// let popularity: Field = index.schema().get_field("popularity").unwrap(); - /// let boosted: Field = index.schema().get_field("boosted").unwrap(); - /// # index_writer.add_document(doc!(boosted=>1u64, product_name => "The Diary of Muadib", popularity => 1u64))?; - /// # index_writer.add_document(doc!(boosted=>0u64, product_name => "A Dairy Cow", popularity => 10u64))?; - /// # index_writer.add_document(doc!(boosted=>0u64, product_name => "The Diary of a Young Girl", popularity => 15u64))?; - /// # index_writer.commit()?; - /// // ... - /// # let user_query = "diary"; - /// # let query = QueryParser::for_index(&index, vec![product_name]).parse_query(user_query)?; - /// - /// // This is where we build our collector with our custom score. - /// let top_docs_by_custom_score = TopDocs - /// ::with_limit(10) - /// .custom_score(move |segment_reader: &SegmentReader| { - /// // The argument is a function that returns our scoring - /// // function. 
- /// // - /// // The point of this "mother" function is to gather all - /// // of the segment level information we need for scoring. - /// // Typically, fast_fields. - /// // - /// // In our case, we will get a reader for the popularity - /// // fast field and a boosted field. - /// // - /// // We want to get boosted items score, and when we get - /// // a tie, return the item with the highest popularity. - /// // - /// // Note that this is implemented by using a `(u64, u64)` - /// // as a score. - /// let popularity_reader = - /// segment_reader.fast_fields().u64("popularity").unwrap().first_or_default_col(0); - /// let boosted_reader = - /// segment_reader.fast_fields().u64("boosted").unwrap().first_or_default_col(0); - /// - /// // We can now define our actual scoring function - /// move |doc: DocId| { - /// let popularity: u64 = popularity_reader.get_val(doc); - /// let boosted: u64 = boosted_reader.get_val(doc); - /// // Score do not have to be `f64` in tantivy. - /// // Here we return a couple to get lexicographical order - /// // for free. - /// (boosted, popularity) - /// } - /// }); - /// # let reader = index.reader()?; - /// # let searcher = reader.searcher(); - /// // ... and here are our documents. Note this is a simple vec. - /// // The `Score` in the pair is our tweaked score. 
- /// let resulting_docs: Vec<((u64, u64), DocAddress)> = - /// searcher.search(&*query, &top_docs_by_custom_score)?; - /// - /// # Ok(()) - /// # } - /// ``` - /// - /// # See also - /// - [tweak_score(...)](TopDocs::tweak_score) - pub fn custom_score( - self, - custom_score: TCustomScorer, - ) -> impl Collector> - where - TScore: 'static + Send + Sync + Clone + PartialOrd, - TCustomSegmentScorer: CustomSegmentScorer + 'static, - TCustomScorer: CustomScorer + Send + Sync, - { - CustomScoreTopCollector::new(custom_score, self.0.into_tscore()) - } } impl Collector for TopDocs { @@ -1917,7 +1803,7 @@ mod tests { let field = index.schema().get_field("text").unwrap(); let query_parser = QueryParser::for_index(&index, vec![field]); let text_query = query_parser.parse_query("droopy tax")?; - let collector = TopDocs::with_limit(2).and_offset(1).by_sort_key( + let collector = TopDocs::with_limit(2).and_offset(1).order_by( move |_segment_reader: &SegmentReader| move |doc: DocId, _original_score: Score| doc, ); let score_docs: Vec<(u32, DocAddress)> = From b2cb883d0c48fab3f9501116a9ed456257fb258d Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 12:31:46 +0100 Subject: [PATCH 07/11] Refactoring --- src/collector/custom_score_top_collector.rs | 121 ------ src/collector/mod.rs | 6 +- src/collector/sort_key.rs | 317 +++++++++++++++ src/collector/sort_key_top_collector.rs | 312 +-------------- src/collector/top_collector.rs | 4 +- src/collector/top_score_collector.rs | 417 ++++++++++---------- 6 files changed, 538 insertions(+), 639 deletions(-) delete mode 100644 src/collector/custom_score_top_collector.rs create mode 100644 src/collector/sort_key.rs diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs deleted file mode 100644 index 54d42469eb..0000000000 --- a/src/collector/custom_score_top_collector.rs +++ /dev/null @@ -1,121 +0,0 @@ -use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; 
-use crate::collector::{Collector, SegmentCollector}; -use crate::{DocAddress, DocId, Score, SegmentReader}; - -pub(crate) struct CustomScoreTopCollector { - custom_scorer: TCustomScorer, - collector: TopCollector, -} - -impl CustomScoreTopCollector -where TScore: Clone + PartialOrd -{ - pub(crate) fn new( - custom_scorer: TCustomScorer, - collector: TopCollector, - ) -> CustomScoreTopCollector { - CustomScoreTopCollector { - custom_scorer, - collector, - } - } -} - -/// A custom segment scorer makes it possible to define any kind of score -/// for a given document belonging to a specific segment. -/// -/// It is the segment local version of the [`CustomScorer`]. -pub trait CustomSegmentScorer: 'static { - /// Computes the score of a specific `doc`. - fn score(&mut self, doc: DocId) -> TScore; -} - -/// `CustomScorer` makes it possible to define any kind of score. -/// -/// The `CustomerScorer` itself does not make much of the computation itself. -/// Instead, it helps constructing `Self::Child` instances that will compute -/// the score at a segment scale. -pub trait CustomScorer: Sync { - /// Type of the associated [`CustomSegmentScorer`]. - type Child: CustomSegmentScorer; - /// Builds a child scorer for a specific segment. The child scorer is associated with - /// a specific segment. 
- fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result; -} - -impl Collector for CustomScoreTopCollector -where - TCustomScorer: CustomScorer + Send + Sync, - TScore: 'static + PartialOrd + Clone + Send + Sync, -{ - type Fruit = Vec<(TScore, DocAddress)>; - - type Child = CustomScoreTopSegmentCollector; - - fn for_segment( - &self, - segment_local_id: u32, - segment_reader: &SegmentReader, - ) -> crate::Result { - let segment_collector = self.collector.for_segment(segment_local_id, segment_reader); - let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?; - Ok(CustomScoreTopSegmentCollector { - segment_collector, - segment_scorer, - }) - } - - fn requires_scoring(&self) -> bool { - false - } - - fn merge_fruits(&self, segment_fruits: Vec) -> crate::Result { - self.collector.merge_fruits(segment_fruits) - } -} - -pub struct CustomScoreTopSegmentCollector -where - TScore: 'static + PartialOrd + Clone + Send + Sync + Sized, - T: CustomSegmentScorer, -{ - segment_collector: TopSegmentCollector, - segment_scorer: T, -} - -impl SegmentCollector for CustomScoreTopSegmentCollector -where - TScore: 'static + PartialOrd + Clone + Send + Sync, - T: 'static + CustomSegmentScorer, -{ - type Fruit = Vec<(TScore, DocAddress)>; - - fn collect(&mut self, doc: DocId, _score: Score) { - let score = self.segment_scorer.score(doc); - self.segment_collector.collect(doc, score); - } - - fn harvest(self) -> Vec<(TScore, DocAddress)> { - self.segment_collector.harvest() - } -} - -impl CustomScorer for F -where - F: 'static + Send + Sync + Fn(&SegmentReader) -> T, - T: CustomSegmentScorer, -{ - type Child = T; - - fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result { - Ok((self)(segment_reader)) - } -} - -impl CustomSegmentScorer for F -where F: 'static + FnMut(DocId) -> TScore -{ - fn score(&mut self, doc: DocId) -> TScore { - (self)(doc) - } -} diff --git a/src/collector/mod.rs b/src/collector/mod.rs index da336ec3b0..c7a4a1826b 
100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -100,11 +100,9 @@ mod top_score_collector; pub use self::top_collector::ComparableDoc; pub use self::top_score_collector::{TopDocs, TopNComputer}; -mod custom_score_top_collector; -pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer}; - +mod sort_key; mod sort_key_top_collector; -pub use self::sort_key_top_collector::{SegmentSortKeyComputer, SortKeyComputer}; +pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer}; mod facet_collector; pub use self::facet_collector::{FacetCollector, FacetCounts}; use crate::query::Weight; diff --git a/src/collector/sort_key.rs b/src/collector/sort_key.rs new file mode 100644 index 0000000000..9b62a55fa0 --- /dev/null +++ b/src/collector/sort_key.rs @@ -0,0 +1,317 @@ +use std::cmp::Ordering; + +use crate::{DocId, Result, Score, SegmentReader}; + +/// A `SegmentSortKeyComputer` makes it possible to modify the default score +/// for a given document belonging to a specific segment. +/// +/// It is the segment local version of the [`SortKeyComputer`]. +pub trait SegmentSortKeyComputer: 'static { + /// The final score being emitted. + type SortKey: 'static + PartialOrd + Send + Sync + Clone; + + /// Sort key used by at the segment level by the `SegmentSortKeyComputer`. + /// + /// It is typically small like a `u64`, and is meant to be converted + /// to the final score at the end of the collection of the segment. + type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; + + /// Computes the sort key for the given document and score. + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; + + /// Returns true if the `SegmentSortKeyComputer` is a good candidate for the lazy evaluation + /// optimization. See [`SegmentSortKeyComputer::accept_score_lazy`]. 
+ fn is_lazy() -> bool { + false + } + + /// Implementing this method makes it possible to avoid computing + /// a sort_key entirely if we can assess that it won't pass a threshold + /// with a partial computation. + /// + /// This is currently used for lexicographic sorting. + /// + /// If REVERSE_ORDER is false (resp. true), + /// - we return None if the score is below the threshold (resp. above to the threshold) + /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. + /// below or equal to) + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentSortKey, + ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { + let excluded_ordering = if REVERSE_ORDER { + Ordering::Greater + } else { + Ordering::Less + }; + let sort_key = self.sort_key(doc_id, score); + let cmp = sort_key.partial_cmp(threshold).unwrap_or(excluded_ordering); + if cmp == excluded_ordering { + return None; + } else { + return Some((cmp, sort_key)); + } + } + + /// Convert a segment level sort key into the global sort key. + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey; +} + +/// `SortKeyComputer` defines the sort key to be used by a TopK Collector. +/// +/// The `SortKeyComputer` itself does not make much of the computation itself. +/// Instead, it helps constructing `Self::Child` instances that will compute +/// the sort key at a segment scale. +pub trait SortKeyComputer: Sync { + /// The sort key type. + type SortKey: 'static + Send + Sync + PartialOrd + Clone; + /// Type of the associated [`SegmentSortKeyComputer`]. + type Child: SegmentSortKeyComputer; + + /// Indicates whether the sort key actually uses the similarity score (by default BM25). + /// If set to false, the similary score might not be computed (as an optimization), + /// and the score fed in the segment sort key computer could take any value. 
+ fn requires_scoring(&self) -> bool { + false + } + + /// Builds a child sort key computer for a specific segment. + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result; +} + +impl SortKeyComputer + for (HeadSortKeyComputer, TailSortKeyComputer) +where + HeadSortKeyComputer: SortKeyComputer, + TailSortKeyComputer: SortKeyComputer, +{ + type SortKey = ( + ::SortKey, + ::SortKey, + ); + type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok(( + self.0.segment_sort_key_computer(segment_reader)?, + self.1.segment_sort_key_computer(segment_reader)?, + )) + } +} + +impl SegmentSortKeyComputer + for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) +where + HeadSegmentSortKeyComputer: SegmentSortKeyComputer, + TailSegmentSortKeyComputer: SegmentSortKeyComputer, +{ + type SortKey = ( + HeadSegmentSortKeyComputer::SortKey, + TailSegmentSortKeyComputer::SortKey, + ); + type SegmentSortKey = ( + HeadSegmentSortKeyComputer::SegmentSortKey, + TailSegmentSortKeyComputer::SegmentSortKey, + ); + + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + let head_sort_key = self.0.sort_key(doc, score); + let tail_sort_key = self.1.sort_key(doc, score); + (head_sort_key, tail_sort_key) + } + + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentSortKey, + ) -> Option<(Ordering, Self::SegmentSortKey)> { + let (head_threshold, tail_threshold) = threshold; + let (head_cmp, head_sort_key) = + self.0 + .accept_sort_key_lazy::(doc_id, score, head_threshold)?; + if head_cmp == Ordering::Equal { + let (tail_cmp, tail_sort_key) = + self.1 + .accept_sort_key_lazy::(doc_id, score, tail_threshold)?; + Some((tail_cmp, (head_sort_key, tail_sort_key))) + } else { + let tail_sort_key = self.1.sort_key(doc_id, score); + Some((head_cmp, (head_sort_key, tail_sort_key))) + } + } + + fn is_lazy() -> 
bool { + true + } + + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + let (head_sort_key, tail_sort_key) = sort_key; + ( + self.0.convert_segment_sort_key(head_sort_key), + self.1.convert_segment_sort_key(tail_sort_key), + ) + } +} + +/// This struct is used as an adapter to take a sort key computer and map its score to another +/// new sort key. +pub struct MappedSegmentSortKeyComputer { + sort_key_computer: T, + map: fn(PreviousSortKey) -> NewSortKey, +} + +impl SegmentSortKeyComputer + for MappedSegmentSortKeyComputer +where + T: SegmentSortKeyComputer, + PreviousScore: 'static + Clone + Send + Sync + PartialOrd, + NewScore: 'static + Clone + Send + Sync + PartialOrd, +{ + type SortKey = NewScore; + type SegmentSortKey = T::SegmentSortKey; + + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + self.sort_key_computer.sort_key(doc, score) + } + + fn accept_sort_key_lazy( + &mut self, + doc_id: DocId, + score: Score, + threshold: &Self::SegmentSortKey, + ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { + self.sort_key_computer + .accept_sort_key_lazy::(doc_id, score, threshold) + } + + fn is_lazy() -> bool { + T::is_lazy() + } + + fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey { + (self.map)( + self.sort_key_computer + .convert_segment_sort_key(segment_sort_key), + ) + } +} + +// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, +// ...) 
as the chain (a, (b, (c, ...))) + +impl SortKeyComputer + for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, +{ + type Child = MappedSegmentSortKeyComputer< + <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), + ), + Self::SortKey, + >; + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + ); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), + map: |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3), + }) + } +} + +impl SortKeyComputer + for ( + SortKeyComputer1, + SortKeyComputer2, + SortKeyComputer3, + SortKeyComputer4, + ) +where + SortKeyComputer1: SortKeyComputer, + SortKeyComputer2: SortKeyComputer, + SortKeyComputer3: SortKeyComputer, + SortKeyComputer4: SortKeyComputer, +{ + type Child = MappedSegmentSortKeyComputer< + <( + SortKeyComputer1, + (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), + ) as SortKeyComputer>::Child, + ( + SortKeyComputer1::SortKey, + ( + SortKeyComputer2::SortKey, + (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), + ), + ), + Self::SortKey, + >; + type SortKey = ( + SortKeyComputer1::SortKey, + SortKeyComputer2::SortKey, + SortKeyComputer3::SortKey, + SortKeyComputer4::SortKey, + ); + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + let sort_key_computer1 = 
self.0.segment_sort_key_computer(segment_reader)?; + let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; + let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; + let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; + Ok(MappedSegmentSortKeyComputer { + sort_key_computer: ( + sort_key_computer1, + (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), + ), + map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { + (sort_key1, sort_key2, sort_key3, sort_key4) + }, + }) + } +} + +impl SortKeyComputer for F +where + F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentSortKeyComputer, + TSegmentSortKeyComputer: SegmentSortKeyComputer, +{ + type SortKey = TSegmentSortKeyComputer::SortKey; + type Child = TSegmentSortKeyComputer; + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok((self)(segment_reader)) + } +} + +impl SegmentSortKeyComputer for F +where + F: 'static + FnMut(DocId, Score) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; + + fn sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { + (self)(doc, score) + } + + /// Convert a segment level score into the global level score. 
+ fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key + } +} diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 4064b4f009..f6a6309e28 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -1,5 +1,4 @@ -use std::cmp::Ordering; - +use crate::collector::sort_key::{SegmentSortKeyComputer, SortKeyComputer}; use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; use crate::collector::{Collector, SegmentCollector}; use crate::{DocAddress, DocId, Result, Score, SegmentReader}; @@ -23,78 +22,6 @@ where TSortKey: Clone + PartialOrd } } -/// A `SegmentSortKeyComputer` makes it possible to modify the default score -/// for a given document belonging to a specific segment. -/// -/// It is the segment local version of the [`SortKeyComputer`]. -pub trait SegmentSortKeyComputer: 'static { - /// The final score being emitted. - type SortKey: 'static + PartialOrd + Send + Sync + Clone; - - /// Sort key used by at the segment level by the `SegmentSortKeyComputer`. - /// - /// It is typically small like a `u64`, and is meant to be converted - /// to the final score at the end of the collection of the segment. - type SegmentSortKey: 'static + PartialOrd + Clone + Send + Sync + Clone; - - /// Computes the sort key for the given document and score. - fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey; - - /// Returns true if the `SegmentSortKeyComputer` is a good candidate for the lazy evaluation - /// optimization. See [`SegmentSortKeyComputer::accept_score_lazy`]. - fn is_lazy() -> bool { - false - } - - /// Implementing this method makes it possible to avoid computing - /// a sort_key entirely if we can assess that it won't pass a threshold - /// with a partial computation. - /// - /// This is currently used for lexicographic sorting. - /// - /// If REVERSE_ORDER is false (resp. 
true), - /// - we return None if the score is below the threshold (resp. above to the threshold) - /// - we return Some(ordering, score) if the score is above or equal to the threshold (resp. - /// below or equal to) - fn accept_sort_key_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentSortKey, - ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { - let excluded_ordering = if REVERSE_ORDER { - Ordering::Greater - } else { - Ordering::Less - }; - let sort_key = self.sort_key(doc_id, score); - let cmp = sort_key.partial_cmp(threshold).unwrap_or(excluded_ordering); - if cmp == excluded_ordering { - return None; - } else { - return Some((cmp, sort_key)); - } - } - - /// Convert a segment level sort key into the global sort key. - fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey; -} - -/// `SortKeyComputer` defines the sort key to be used by a TopK Collector. -/// -/// The `SortKeyComputer` itself does not make much of the computation itself. -/// Instead, it helps constructing `Self::Child` instances that will compute -/// the sort key at a segment scale. -pub trait SortKeyComputer: Sync { - /// The sort key type. - type SortKey: 'static + Send + Sync + PartialOrd + Clone; - /// Type of the associated [`SegmentSortKeyComputer`]. - type Child: SegmentSortKeyComputer; - - /// Builds a child sort key computer for a specific segment. 
- fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result; -} - impl Collector for TopBySortKeyCollector where TSortKeyComputer: SortKeyComputer + Send + Sync, @@ -120,7 +47,7 @@ where } fn requires_scoring(&self) -> bool { - true + self.sort_key_computer.requires_scoring() } fn merge_fruits(&self, segment_fruits: Vec) -> Result { @@ -161,238 +88,3 @@ where TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer .collect() } } - -impl SortKeyComputer for F -where - F: 'static + Send + Sync + Fn(&SegmentReader) -> TSegmentSortKeyComputer, - TSegmentSortKeyComputer: SegmentSortKeyComputer, -{ - type SortKey = TSegmentSortKeyComputer::SortKey; - type Child = TSegmentSortKeyComputer; - - fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok((self)(segment_reader)) - } -} - -impl SegmentSortKeyComputer for F -where - F: 'static + FnMut(DocId, Score) -> TSortKey, - TSortKey: 'static + PartialOrd + Clone + Send + Sync, -{ - type SortKey = TSortKey; - type SegmentSortKey = TSortKey; - - fn sort_key(&mut self, doc: DocId, score: Score) -> TSortKey { - (self)(doc, score) - } - - /// Convert a segment level score into the global level score. 
- fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { - sort_key - } -} - -impl SortKeyComputer - for (HeadSortKeyComputer, TailSortKeyComputer) -where - HeadSortKeyComputer: SortKeyComputer, - TailSortKeyComputer: SortKeyComputer, -{ - type SortKey = ( - ::SortKey, - ::SortKey, - ); - type Child = (HeadSortKeyComputer::Child, TailSortKeyComputer::Child); - - fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok(( - self.0.segment_sort_key_computer(segment_reader)?, - self.1.segment_sort_key_computer(segment_reader)?, - )) - } -} - -impl SegmentSortKeyComputer - for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) -where - HeadSegmentSortKeyComputer: SegmentSortKeyComputer, - TailSegmentSortKeyComputer: SegmentSortKeyComputer, -{ - type SortKey = ( - HeadSegmentSortKeyComputer::SortKey, - TailSegmentSortKeyComputer::SortKey, - ); - type SegmentSortKey = ( - HeadSegmentSortKeyComputer::SegmentSortKey, - TailSegmentSortKeyComputer::SegmentSortKey, - ); - - fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { - let head_sort_key = self.0.sort_key(doc, score); - let tail_sort_key = self.1.sort_key(doc, score); - (head_sort_key, tail_sort_key) - } - - fn accept_sort_key_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentSortKey, - ) -> Option<(Ordering, Self::SegmentSortKey)> { - let (head_threshold, tail_threshold) = threshold; - let (head_cmp, head_sort_key) = - self.0 - .accept_sort_key_lazy::(doc_id, score, head_threshold)?; - if head_cmp == Ordering::Equal { - let (tail_cmp, tail_sort_key) = - self.1 - .accept_sort_key_lazy::(doc_id, score, tail_threshold)?; - Some((tail_cmp, (head_sort_key, tail_sort_key))) - } else { - let tail_sort_key = self.1.sort_key(doc_id, score); - Some((head_cmp, (head_sort_key, tail_sort_key))) - } - } - - fn is_lazy() -> bool { - true - } - - fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> 
Self::SortKey { - let (head_sort_key, tail_sort_key) = sort_key; - ( - self.0.convert_segment_sort_key(head_sort_key), - self.1.convert_segment_sort_key(tail_sort_key), - ) - } -} - -/// This struct is used as an adapter to take a sort key computer and map its score to another -/// new sort key. -pub struct MappedSegmentSortKeyComputer { - sort_key_computer: T, - map: fn(PreviousSortKey) -> NewSortKey, -} - -impl SegmentSortKeyComputer - for MappedSegmentSortKeyComputer -where - T: SegmentSortKeyComputer, - PreviousScore: 'static + Clone + Send + Sync + PartialOrd, - NewScore: 'static + Clone + Send + Sync + PartialOrd, -{ - type SortKey = NewScore; - type SegmentSortKey = T::SegmentSortKey; - - fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { - self.sort_key_computer.sort_key(doc, score) - } - - fn accept_sort_key_lazy( - &mut self, - doc_id: DocId, - score: Score, - threshold: &Self::SegmentSortKey, - ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { - self.sort_key_computer - .accept_sort_key_lazy::(doc_id, score, threshold) - } - - fn is_lazy() -> bool { - T::is_lazy() - } - - fn convert_segment_sort_key(&self, segment_sort_key: Self::SegmentSortKey) -> Self::SortKey { - (self.map)( - self.sort_key_computer - .convert_segment_sort_key(segment_sort_key), - ) - } -} - -// We then re-use our (head, tail) implement and our mapper by seeing mapping any tuple (a, b, c, -// ...) 
as the chain (a, (b, (c, ...))) - -impl SortKeyComputer - for (SortKeyComputer1, SortKeyComputer2, SortKeyComputer3) -where - SortKeyComputer1: SortKeyComputer, - SortKeyComputer2: SortKeyComputer, - SortKeyComputer3: SortKeyComputer, -{ - type Child = MappedSegmentSortKeyComputer< - <(SortKeyComputer1, (SortKeyComputer2, SortKeyComputer3)) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - (SortKeyComputer2::SortKey, SortKeyComputer3::SortKey), - ), - Self::SortKey, - >; - type SortKey = ( - SortKeyComputer1::SortKey, - SortKeyComputer2::SortKey, - SortKeyComputer3::SortKey, - ); - - fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?; - let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; - let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; - Ok(MappedSegmentSortKeyComputer { - sort_key_computer: (sort_key_computer1, (sort_key_computer2, sort_key_computer3)), - map: |(sort_key1, (sort_key2, sort_key3))| (sort_key1, sort_key2, sort_key3), - }) - } -} - -impl SortKeyComputer - for ( - SortKeyComputer1, - SortKeyComputer2, - SortKeyComputer3, - SortKeyComputer4, - ) -where - SortKeyComputer1: SortKeyComputer, - SortKeyComputer2: SortKeyComputer, - SortKeyComputer3: SortKeyComputer, - SortKeyComputer4: SortKeyComputer, -{ - type Child = MappedSegmentSortKeyComputer< - <( - SortKeyComputer1, - (SortKeyComputer2, (SortKeyComputer3, SortKeyComputer4)), - ) as SortKeyComputer>::Child, - ( - SortKeyComputer1::SortKey, - ( - SortKeyComputer2::SortKey, - (SortKeyComputer3::SortKey, SortKeyComputer4::SortKey), - ), - ), - Self::SortKey, - >; - type SortKey = ( - SortKeyComputer1::SortKey, - SortKeyComputer2::SortKey, - SortKeyComputer3::SortKey, - SortKeyComputer4::SortKey, - ); - - fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - let sort_key_computer1 = 
self.0.segment_sort_key_computer(segment_reader)?; - let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?; - let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?; - let sort_key_computer4 = self.3.segment_sort_key_computer(segment_reader)?; - Ok(MappedSegmentSortKeyComputer { - sort_key_computer: ( - sort_key_computer1, - (sort_key_computer2, (sort_key_computer3, sort_key_computer4)), - ), - map: |(sort_key1, (sort_key2, (sort_key3, sort_key4)))| { - (sort_key1, sort_key2, sort_key3, sort_key4) - }, - }) - } -} diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 6d939dd2fb..eeeff7473d 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -135,7 +135,9 @@ where T: PartialOrd + Clone /// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits /// to fail. #[doc(hidden)] - pub(crate) fn into_tscore(self) -> TopCollector { + pub(crate) fn into_different_sort_key_type( + self, + ) -> TopCollector { TopCollector { limit: self.limit, offset: self.offset, diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 2ae3edcdaa..215233eea2 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -2,21 +2,15 @@ use std::fmt; use std::marker::PhantomData; use std::sync::Arc; -use columnar::{ColumnValues, StrColumn}; +use columnar::ColumnValues; use serde::{Deserialize, Serialize}; use super::Collector; -use crate::collector::custom_score_top_collector::{ - CustomScoreTopCollector, CustomScoreTopSegmentCollector, -}; use crate::collector::sort_key_top_collector::TopBySortKeyCollector; use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; -use crate::collector::{ - CustomScorer, CustomSegmentScorer, SegmentCollector, SegmentSortKeyComputer, SortKeyComputer, -}; +use crate::collector::{SegmentCollector, SegmentSortKeyComputer, 
SortKeyComputer}; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; use crate::query::Weight; -use crate::termdict::TermOrdinal; use crate::{DocAddress, DocId, Order, Score, SegmentOrdinal, SegmentReader, TantivyError}; struct FastFieldConvertCollector< @@ -86,162 +80,162 @@ where } } -struct StringConvertCollector { - pub collector: CustomScoreTopCollector, - pub field: String, - order: Order, - limit: usize, - offset: usize, -} - -impl Collector for StringConvertCollector { - type Fruit = Vec<(String, DocAddress)>; - - type Child = StringConvertSegmentCollector; - - fn for_segment( - &self, - segment_local_id: crate::SegmentOrdinal, - segment: &SegmentReader, - ) -> crate::Result { - let schema = segment.schema(); - let field = schema.get_field(&self.field)?; - let field_entry = schema.get_field_entry(field); - if !field_entry.is_fast() { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is not a fast field.", - field_entry.name() - ))); - } - let requested_type = crate::schema::Type::Str; - let schema_type = field_entry.field_type().value_type(); - if schema_type != requested_type { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is of type {schema_type:?}!={requested_type:?}", - field_entry.name() - ))); - } - let ff = segment - .fast_fields() - .str(&self.field)? 
- .expect("ff should be a str field"); - Ok(StringConvertSegmentCollector { - collector: self.collector.for_segment(segment_local_id, segment)?, - ff, - order: self.order.clone(), - }) - } - - fn requires_scoring(&self) -> bool { - self.collector.requires_scoring() - } - - fn merge_fruits( - &self, - child_fruits: Vec<::Fruit>, - ) -> crate::Result { - if self.limit == 0 { - return Ok(Vec::new()); - } - if self.order.is_desc() { - let mut top_collector: TopNComputer<_, _, true> = - TopNComputer::new(self.limit + self.offset); - for child_fruit in child_fruits { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.sort_key, cdoc.doc)) - .collect()) - } else { - let mut top_collector: TopNComputer<_, _, false> = - TopNComputer::new(self.limit + self.offset); - for child_fruit in child_fruits { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.sort_key, cdoc.doc)) - .collect()) - } - } -} - -struct StringConvertSegmentCollector { - pub collector: CustomScoreTopSegmentCollector, - ff: StrColumn, - order: Order, -} - -impl SegmentCollector for StringConvertSegmentCollector { - type Fruit = Vec<(String, DocAddress)>; - - fn collect(&mut self, doc: DocId, score: Score) { - self.collector.collect(doc, score); - } - - fn harvest(self) -> Vec<(String, DocAddress)> { - let top_ordinals: Vec<(TermOrdinal, DocAddress)> = self.collector.harvest(); - - // Collect terms. 
- let mut terms: Vec = Vec::with_capacity(top_ordinals.len()); - let result = if self.order.is_asc() { - self.ff.dictionary().sorted_ords_to_term_cb( - top_ordinals.iter().map(|(term_ord, _)| u64::MAX - term_ord), - |term| { - terms.push( - std::str::from_utf8(term) - .expect("Failed to decode term as unicode") - .to_owned(), - ); - Ok(()) - }, - ) - } else { - self.ff.dictionary().sorted_ords_to_term_cb( - top_ordinals.iter().rev().map(|(term_ord, _)| *term_ord), - |term| { - terms.push( - std::str::from_utf8(term) - .expect("Failed to decode term as unicode") - .to_owned(), - ); - Ok(()) - }, - ) - }; - - assert!( - result.expect("Failed to read terms from term dictionary"), - "Not all terms were matched in segment." - ); - - // Zip them back with their docs. - if self.order.is_asc() { - terms - .into_iter() - .zip(top_ordinals) - .map(|(term, (_, doc))| (term, doc)) - .collect() - } else { - terms - .into_iter() - .rev() - .zip(top_ordinals) - .map(|(term, (_, doc))| (term, doc)) - .collect() - } - } -} +// struct StringConvertCollector { +// pub collector: CustomScoreTopCollector, +// pub field: String, +// order: Order, +// limit: usize, +// offset: usize, +// } + +// impl Collector for StringConvertCollector { +// type Fruit = Vec<(String, DocAddress)>; + +// type Child = StringConvertSegmentCollector; + +// fn for_segment( +// &self, +// segment_local_id: crate::SegmentOrdinal, +// segment: &SegmentReader, +// ) -> crate::Result { +// let schema = segment.schema(); +// let field = schema.get_field(&self.field)?; +// let field_entry = schema.get_field_entry(field); +// if !field_entry.is_fast() { +// return Err(TantivyError::SchemaError(format!( +// "Field {:?} is not a fast field.", +// field_entry.name() +// ))); +// } +// let requested_type = crate::schema::Type::Str; +// let schema_type = field_entry.field_type().value_type(); +// if schema_type != requested_type { +// return Err(TantivyError::SchemaError(format!( +// "Field {:?} is of type 
{schema_type:?}!={requested_type:?}", +// field_entry.name() +// ))); +// } +// let ff = segment +// .fast_fields() +// .str(&self.field)? +// .expect("ff should be a str field"); +// Ok(StringConvertSegmentCollector { +// collector: self.collector.for_segment(segment_local_id, segment)?, +// ff, +// order: self.order.clone(), +// }) +// } + +// fn requires_scoring(&self) -> bool { +// self.collector.requires_scoring() +// } + +// fn merge_fruits( +// &self, +// child_fruits: Vec<::Fruit>, +// ) -> crate::Result { +// if self.limit == 0 { +// return Ok(Vec::new()); +// } +// if self.order.is_desc() { +// let mut top_collector: TopNComputer<_, _, true> = +// TopNComputer::new(self.limit + self.offset); +// for child_fruit in child_fruits { +// for (feature, doc) in child_fruit { +// top_collector.push(feature, doc); +// } +// } +// Ok(top_collector +// .into_sorted_vec() +// .into_iter() +// .skip(self.offset) +// .map(|cdoc| (cdoc.sort_key, cdoc.doc)) +// .collect()) +// } else { +// let mut top_collector: TopNComputer<_, _, false> = +// TopNComputer::new(self.limit + self.offset); +// for child_fruit in child_fruits { +// for (feature, doc) in child_fruit { +// top_collector.push(feature, doc); +// } +// } + +// Ok(top_collector +// .into_sorted_vec() +// .into_iter() +// .skip(self.offset) +// .map(|cdoc| (cdoc.sort_key, cdoc.doc)) +// .collect()) +// } +// } +// } + +// struct StringConvertSegmentCollector { +// pub collector: CustomScoreTopSegmentCollector, +// ff: StrColumn, +// order: Order, +// } + +// impl SegmentCollector for StringConvertSegmentCollector { +// type Fruit = Vec<(String, DocAddress)>; + +// fn collect(&mut self, doc: DocId, score: Score) { +// self.collector.collect(doc, score); +// } + +// fn harvest(self) -> Vec<(String, DocAddress)> { +// let top_ordinals: Vec<(TermOrdinal, DocAddress)> = self.collector.harvest(); + +// // Collect terms. 
+// let mut terms: Vec = Vec::with_capacity(top_ordinals.len()); +// let result = if self.order.is_asc() { +// self.ff.dictionary().sorted_ords_to_term_cb( +// top_ordinals.iter().map(|(term_ord, _)| u64::MAX - term_ord), +// |term| { +// terms.push( +// std::str::from_utf8(term) +// .expect("Failed to decode term as unicode") +// .to_owned(), +// ); +// Ok(()) +// }, +// ) +// } else { +// self.ff.dictionary().sorted_ords_to_term_cb( +// top_ordinals.iter().rev().map(|(term_ord, _)| *term_ord), +// |term| { +// terms.push( +// std::str::from_utf8(term) +// .expect("Failed to decode term as unicode") +// .to_owned(), +// ); +// Ok(()) +// }, +// ) +// }; + +// assert!( +// result.expect("Failed to read terms from term dictionary"), +// "Not all terms were matched in segment." +// ); + +// // Zip them back with their docs. +// if self.order.is_asc() { +// terms +// .into_iter() +// .zip(top_ordinals) +// .map(|(term, (_, doc))| (term, doc)) +// .collect() +// } else { +// terms +// .into_iter() +// .rev() +// .zip(top_ordinals) +// .map(|(term, (_, doc))| (term, doc)) +// .collect() +// } +// } +// } /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. 
@@ -299,18 +293,30 @@ impl fmt::Debug for TopDocs { } } -struct ScorerByFastFieldReader { +struct SortKeyByFastFieldReader { sort_column: Arc>, order: Order, } -impl CustomSegmentScorer for ScorerByFastFieldReader { - fn score(&mut self, doc: DocId) -> u64 { - let value = self.sort_column.get_val(doc); - if self.order.is_desc() { - value +impl SegmentSortKeyComputer for SortKeyByFastFieldReader { + type SortKey = u64; + + type SegmentSortKey = u64; + + fn sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { + let val = self.sort_column.get_val(doc); + if self.order == Order::Desc { + u64::MAX - val + } else { + val + } + } + + fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + if self.order == Order::Desc { + u64::MAX - sort_key } else { - u64::MAX - value + sort_key } } } @@ -320,10 +326,15 @@ struct ScorerByField { order: Order, } -impl CustomScorer for ScorerByField { - type Child = ScorerByFastFieldReader; +impl SortKeyComputer for ScorerByField { + type Child = SortKeyByFastFieldReader; - fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result { + type SortKey = u64; + + fn segment_sort_key_computer( + &self, + segment_reader: &SegmentReader, + ) -> crate::Result { // We interpret this field as u64, regardless of its type, that way, // we avoid needless conversion. Regardless of the fast field type, the // mapping is monotonic, so it is sufficient to compute our top-K docs. 
@@ -338,7 +349,7 @@ impl CustomScorer for ScorerByField { if self.order.is_asc() { default_value = u64::MAX; } - Ok(ScorerByFastFieldReader { + Ok(SortKeyByFastFieldReader { sort_column: sort_column.first_or_default_col(default_value), order: self.order.clone(), }) @@ -475,12 +486,12 @@ impl TopDocs { field: impl ToString, order: Order, ) -> impl Collector> { - CustomScoreTopCollector::new( + TopBySortKeyCollector::new( ScorerByField { field: field.to_string(), order, }, - self.0.into_tscore(), + self.0.into_different_sort_key_type(), ) } @@ -571,28 +582,28 @@ impl TopDocs { } /// Like `order_by_fast_field`, but for a `String` fast field. - pub fn order_by_string_fast_field( - self, - fast_field: impl ToString, - order: Order, - ) -> impl Collector> { - let limit = self.0.limit; - let offset = self.0.offset; - let u64_collector = CustomScoreTopCollector::new( - ScorerByField { - field: fast_field.to_string(), - order: order.clone(), - }, - self.0.into_tscore(), - ); - StringConvertCollector { - collector: u64_collector, - field: fast_field.to_string(), - order, - limit, - offset, - } - } + // pub fn order_by_string_fast_field( + // self, + // fast_field: impl ToString, + // order: Order, + // ) -> impl Collector> { + // let limit = self.0.limit; + // let offset = self.0.offset; + // let u64_collector = CustomScoreTopCollector::new( + // ScorerByField { + // field: fast_field.to_string(), + // order: order.clone(), + // }, + // self.0.into_different_sort_key_type(), + // ); + // StringConvertCollector { + // collector: u64_collector, + // field: fast_field.to_string(), + // order, + // limit, + // offset, + // } + // } /// Ranks the documents using a sort key. 
/// @@ -697,7 +708,7 @@ impl TopDocs { where TSortKey: 'static + Clone + Send + Sync + PartialOrd, { - TopBySortKeyCollector::new(sort_key_computer, self.0.into_tscore()) + TopBySortKeyCollector::new(sort_key_computer, self.0.into_different_sort_key_type()) } } @@ -1088,7 +1099,7 @@ mod tests { for (feature, doc) in &docs { computer.push(*feature, *doc); } - let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { sort_key, doc }).collect::>(); + let mut comparable_docs = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::>(); comparable_docs.sort(); comparable_docs.truncate(limit); prop_assert_eq!( @@ -1721,12 +1732,12 @@ mod tests { // offset, and then taking the limit. let sorted_docs: Vec<_> = if order.is_desc() { let mut comparable_docs: Vec> = - all_results.into_iter().map(|(feature, doc)| ComparableDoc { sort_key, doc}).collect(); + all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); comparable_docs.sort(); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() } else { let mut comparable_docs: Vec> = - all_results.into_iter().map(|(feature, doc)| ComparableDoc { sort_key, doc}).collect(); + all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect(); comparable_docs.sort(); comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect() }; From 78ef833e45ffefa7c8c7fe6d37b3c829db42650f Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Fri, 31 Oct 2025 18:19:31 +0100 Subject: [PATCH 08/11] unification --- examples/warmer.rs | 3 +- src/collector/mod.rs | 2 +- src/collector/sort_key/mod.rs | 171 +++++++++++++ .../sort_key_computer.rs} | 49 ++++ src/collector/top_score_collector.rs | 231 +++--------------- src/index/index_meta.rs | 3 +- 6 files changed, 256 insertions(+), 203 deletions(-) create mode 100644 src/collector/sort_key/mod.rs rename src/collector/{sort_key.rs => sort_key/sort_key_computer.rs} (88%) diff --git 
a/examples/warmer.rs b/examples/warmer.rs index 1cae9d349a..43237c3578 100644 --- a/examples/warmer.rs +++ b/examples/warmer.rs @@ -2,6 +2,7 @@ use std::cmp::Reverse; use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock, Weak}; +use tantivy::collector::sort_key::NoScoreFn; use tantivy::collector::TopDocs; use tantivy::index::SegmentId; use tantivy::query::QueryParser; @@ -164,7 +165,7 @@ fn main() -> tantivy::Result<()> { move |doc_id: DocId| Reverse(price[doc_id as usize]) }; - let most_expensive_first = TopDocs::with_limit(10).custom_score(score_by_price); + let most_expensive_first = TopDocs::with_limit(10).order_by(NoScoreFn(score_by_price)); let hits = searcher.search(&query, &most_expensive_first)?; assert_eq!( diff --git a/src/collector/mod.rs b/src/collector/mod.rs index c7a4a1826b..d0584e1fea 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -100,7 +100,7 @@ mod top_score_collector; pub use self::top_collector::ComparableDoc; pub use self::top_score_collector::{TopDocs, TopNComputer}; -mod sort_key; +pub mod sort_key; mod sort_key_top_collector; pub use self::sort_key::{SegmentSortKeyComputer, SortKeyComputer}; mod facet_collector; diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs new file mode 100644 index 0000000000..b44b7f13bc --- /dev/null +++ b/src/collector/sort_key/mod.rs @@ -0,0 +1,171 @@ +mod sort_key_computer; + +use columnar::StrColumn; +pub use sort_key_computer::{NoScoreFn, SegmentSortKeyComputer, SortKeyComputer}; + +use crate::termdict::TermOrdinal; +use crate::{DocId, Order, Score}; + +impl SortKeyComputer for (TSortKeyComputer, Order) +where + TSortKeyComputer: SortKeyComputer, + (TSortKeyComputer::Child, Order): SegmentSortKeyComputer, +{ + type SortKey = TSortKeyComputer::SortKey; + + type Child = (TSortKeyComputer::Child, Order); + + fn requires_scoring(&self) -> bool { + self.0.requires_scoring() + } + + fn segment_sort_key_computer( + &self, + segment_reader: 
&crate::SegmentReader, + ) -> crate::Result { + let child = self.0.segment_sort_key_computer(segment_reader)?; + Ok((child, self.1)) + } +} + +impl SegmentSortKeyComputer + for (TSegmentSortKeyComputer, Order) +where + TSegmentSortKeyComputer: SegmentSortKeyComputer, + TSegmentSortKey: BizarroWorldInvolution + PartialOrd + Clone + 'static + Sync + Send, +{ + type SortKey = TSegmentSortKeyComputer::SortKey; + type SegmentSortKey = TSegmentSortKey; + + fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { + let sort_key = self.0.sort_key(doc, score); + sort_key.involution_if_asc(self.1) + } + + fn convert_segment_sort_key(&self, bizarro_sort_key: Self::SegmentSortKey) -> Self::SortKey { + let sort_key = bizarro_sort_key.involution_if_asc(self.1); + self.0.convert_segment_sort_key(sort_key) + } +} + +// BizarroWorldInvolution is a transformation that flips the order of a value. +// +// It is useful when we want to sort things in a way that is given to use dynamically +// (unknown at compile time). +// +// That way the segment score keeps having the same type regardless of the order, +// and we do not rely on an enum. +trait BizarroWorldInvolution: Copy { + fn involution(&self) -> Self; + fn involution_if_asc(&self, order: Order) -> Self { + match order { + Order::Asc => self.involution(), + Order::Desc => *self, + } + } +} + +impl BizarroWorldInvolution for u64 { + fn involution(&self) -> Self { + u64::MAX - self + } +} + +// The point here is that for Option, we do not want None values to come on top +// when running a Asc query. +impl BizarroWorldInvolution for Option { + #[inline] + fn involution(&self) -> Self { + self.map(|val| val.involution()) + } +} + +/// Sort by similarity score. 
+#[derive(Clone, Debug, Copy)] +pub struct ByScore; + +impl SortKeyComputer for ByScore { + type SortKey = Score; + + type Child = ByScore; + + fn requires_scoring(&self) -> bool { + false + } + + fn segment_sort_key_computer( + &self, + _segment_reader: &crate::SegmentReader, + ) -> crate::Result { + Ok(ByScore) + } +} + +impl SegmentSortKeyComputer for ByScore { + type SortKey = Score; + + type SegmentSortKey = Score; + + fn sort_key(&mut self, _doc: DocId, score: Score) -> Score { + score + } + + fn convert_segment_sort_key(&self, score: Score) -> Score { + score + } +} + +/// Sort by a string column +pub struct ByStringColumn { + column_name: String, +} + +impl ByStringColumn { + pub fn with_column_name(column_name: String) -> Self { + ByStringColumn { column_name } + } +} + +impl SortKeyComputer for ByStringColumn { + type SortKey = Option; + + type Child = ByStringColumnSegmentSortKeyComputer; + + fn requires_scoring(&self) -> bool { + false + } + + fn segment_sort_key_computer( + &self, + segment_reader: &crate::SegmentReader, + ) -> crate::Result { + let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?; + Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt }) + } +} + +pub struct ByStringColumnSegmentSortKeyComputer { + str_column_opt: Option, +} + +impl SegmentSortKeyComputer for ByStringColumnSegmentSortKeyComputer { + type SortKey = Option; + + type SegmentSortKey = Option; + + fn sort_key(&mut self, doc: DocId, _score: Score) -> Option { + let str_column = self.str_column_opt.as_ref()?; + str_column.ords().first(doc) + } + + fn convert_segment_sort_key(&self, term_ord_opt: Option) -> Option { + let term_ord = term_ord_opt?; + let str_column = self.str_column_opt.as_ref()?; + let mut bytes = Vec::new(); + str_column + .dictionary() + .ord_to_term(term_ord, &mut bytes) + .ok()?; + String::try_from(bytes).ok() + } +} diff --git a/src/collector/sort_key.rs b/src/collector/sort_key/sort_key_computer.rs similarity index 88% rename 
from src/collector/sort_key.rs rename to src/collector/sort_key/sort_key_computer.rs index 9b62a55fa0..9ba12deb27 100644 --- a/src/collector/sort_key.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -315,3 +315,52 @@ where sort_key } } + +/// Helper struct to make it possible to define a sort key computer that does not use +/// the similary score from a simple function. +pub struct NoScoreFn(pub F); + +impl SortKeyComputer for NoScoreFn +where + F: 'static + Send + Sync + Fn(&SegmentReader) -> TNoScoreSortKeyFn, + TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type Child = NoScoreSegmentSortKeyComputer; + + fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { + Ok({ + NoScoreSegmentSortKeyComputer { + sort_key_fn: (self.0)(segment_reader), + } + }) + } + + fn requires_scoring(&self) -> bool { + false + } +} + +pub struct NoScoreSegmentSortKeyComputer { + sort_key_fn: TNoScoreSortKeyFn, +} + +impl SegmentSortKeyComputer + for NoScoreSegmentSortKeyComputer +where + TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; + + fn sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { + (self.sort_key_fn)(doc) + } + + /// Convert a segment level score into the global level score. 
+ fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key + } +} diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 215233eea2..54404435e3 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -6,6 +6,7 @@ use columnar::ColumnValues; use serde::{Deserialize, Serialize}; use super::Collector; +use crate::collector::sort_key::ByStringColumn; use crate::collector::sort_key_top_collector::TopBySortKeyCollector; use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; use crate::collector::{SegmentCollector, SegmentSortKeyComputer, SortKeyComputer}; @@ -80,163 +81,6 @@ where } } -// struct StringConvertCollector { -// pub collector: CustomScoreTopCollector, -// pub field: String, -// order: Order, -// limit: usize, -// offset: usize, -// } - -// impl Collector for StringConvertCollector { -// type Fruit = Vec<(String, DocAddress)>; - -// type Child = StringConvertSegmentCollector; - -// fn for_segment( -// &self, -// segment_local_id: crate::SegmentOrdinal, -// segment: &SegmentReader, -// ) -> crate::Result { -// let schema = segment.schema(); -// let field = schema.get_field(&self.field)?; -// let field_entry = schema.get_field_entry(field); -// if !field_entry.is_fast() { -// return Err(TantivyError::SchemaError(format!( -// "Field {:?} is not a fast field.", -// field_entry.name() -// ))); -// } -// let requested_type = crate::schema::Type::Str; -// let schema_type = field_entry.field_type().value_type(); -// if schema_type != requested_type { -// return Err(TantivyError::SchemaError(format!( -// "Field {:?} is of type {schema_type:?}!={requested_type:?}", -// field_entry.name() -// ))); -// } -// let ff = segment -// .fast_fields() -// .str(&self.field)? 
-// .expect("ff should be a str field"); -// Ok(StringConvertSegmentCollector { -// collector: self.collector.for_segment(segment_local_id, segment)?, -// ff, -// order: self.order.clone(), -// }) -// } - -// fn requires_scoring(&self) -> bool { -// self.collector.requires_scoring() -// } - -// fn merge_fruits( -// &self, -// child_fruits: Vec<::Fruit>, -// ) -> crate::Result { -// if self.limit == 0 { -// return Ok(Vec::new()); -// } -// if self.order.is_desc() { -// let mut top_collector: TopNComputer<_, _, true> = -// TopNComputer::new(self.limit + self.offset); -// for child_fruit in child_fruits { -// for (feature, doc) in child_fruit { -// top_collector.push(feature, doc); -// } -// } -// Ok(top_collector -// .into_sorted_vec() -// .into_iter() -// .skip(self.offset) -// .map(|cdoc| (cdoc.sort_key, cdoc.doc)) -// .collect()) -// } else { -// let mut top_collector: TopNComputer<_, _, false> = -// TopNComputer::new(self.limit + self.offset); -// for child_fruit in child_fruits { -// for (feature, doc) in child_fruit { -// top_collector.push(feature, doc); -// } -// } - -// Ok(top_collector -// .into_sorted_vec() -// .into_iter() -// .skip(self.offset) -// .map(|cdoc| (cdoc.sort_key, cdoc.doc)) -// .collect()) -// } -// } -// } - -// struct StringConvertSegmentCollector { -// pub collector: CustomScoreTopSegmentCollector, -// ff: StrColumn, -// order: Order, -// } - -// impl SegmentCollector for StringConvertSegmentCollector { -// type Fruit = Vec<(String, DocAddress)>; - -// fn collect(&mut self, doc: DocId, score: Score) { -// self.collector.collect(doc, score); -// } - -// fn harvest(self) -> Vec<(String, DocAddress)> { -// let top_ordinals: Vec<(TermOrdinal, DocAddress)> = self.collector.harvest(); - -// // Collect terms. 
-// let mut terms: Vec = Vec::with_capacity(top_ordinals.len()); -// let result = if self.order.is_asc() { -// self.ff.dictionary().sorted_ords_to_term_cb( -// top_ordinals.iter().map(|(term_ord, _)| u64::MAX - term_ord), -// |term| { -// terms.push( -// std::str::from_utf8(term) -// .expect("Failed to decode term as unicode") -// .to_owned(), -// ); -// Ok(()) -// }, -// ) -// } else { -// self.ff.dictionary().sorted_ords_to_term_cb( -// top_ordinals.iter().rev().map(|(term_ord, _)| *term_ord), -// |term| { -// terms.push( -// std::str::from_utf8(term) -// .expect("Failed to decode term as unicode") -// .to_owned(), -// ); -// Ok(()) -// }, -// ) -// }; - -// assert!( -// result.expect("Failed to read terms from term dictionary"), -// "Not all terms were matched in segment." -// ); - -// // Zip them back with their docs. -// if self.order.is_asc() { -// terms -// .into_iter() -// .zip(top_ordinals) -// .map(|(term, (_, doc))| (term, doc)) -// .collect() -// } else { -// terms -// .into_iter() -// .rev() -// .zip(top_ordinals) -// .map(|(term, (_, doc))| (term, doc)) -// .collect() -// } -// } -// } - /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. /// @@ -572,7 +416,7 @@ impl TopDocs { where TFastValue: FastValue, { - let u64_collector = self.order_by_u64_field(fast_field.to_string(), order.clone()); + let u64_collector = self.order_by_u64_field(fast_field.to_string(), order); FastFieldConvertCollector { collector: u64_collector, field: fast_field.to_string(), @@ -582,28 +426,14 @@ impl TopDocs { } /// Like `order_by_fast_field`, but for a `String` fast field. 
- // pub fn order_by_string_fast_field( - // self, - // fast_field: impl ToString, - // order: Order, - // ) -> impl Collector> { - // let limit = self.0.limit; - // let offset = self.0.offset; - // let u64_collector = CustomScoreTopCollector::new( - // ScorerByField { - // field: fast_field.to_string(), - // order: order.clone(), - // }, - // self.0.into_different_sort_key_type(), - // ); - // StringConvertCollector { - // collector: u64_collector, - // field: fast_field.to_string(), - // order, - // limit, - // offset, - // } - // } + pub fn order_by_string_fast_field( + self, + fast_field: impl ToString, + order: Order, + ) -> impl Collector, DocAddress)>> { + let by_string_sort_key_computer = ByStringColumn::with_column_name(fast_field.to_string()); + self.order_by((by_string_sort_key_computer, order)) + } /// Ranks the documents using a sort key. /// @@ -993,6 +823,7 @@ mod tests { use proptest::prelude::*; use super::{TopDocs, TopNComputer}; + use crate::collector::sort_key::NoScoreFn; use crate::collector::top_collector::ComparableDoc; use crate::collector::{Collector, DocSetCollector}; use crate::query::{AllQuery, Query, QueryParser}; @@ -1117,7 +948,7 @@ mod tests { for (feature, doc) in &docs { computer.push(*feature, *doc); } - let mut comparable_docs = docs.into_iter().map(|(feature, doc)| ComparableDoc { sort_key, doc }).collect::>(); + let mut comparable_docs = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::>(); comparable_docs.sort(); comparable_docs.truncate(limit); prop_assert_eq!( @@ -1618,7 +1449,7 @@ mod tests { order: Order, limit: usize, offset: usize, - ) -> crate::Result> { + ) -> crate::Result, DocAddress)>> { let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(limit) .and_offset(offset) @@ -1629,17 +1460,17 @@ mod tests { assert_eq!( &query(&index, Order::Desc, 3, 0)?, &[ - ("tokyo".to_owned(), DocAddress::new(0, 2)), - ("greenville".to_owned(), DocAddress::new(0, 1)), - 
("austin".to_owned(), DocAddress::new(0, 0)), + (Some("tokyo".to_owned()), DocAddress::new(0, 2)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), + (Some("austin".to_owned()), DocAddress::new(0, 0)), ] ); assert_eq!( &query(&index, Order::Desc, 2, 0)?, &[ - ("tokyo".to_owned(), DocAddress::new(0, 2)), - ("greenville".to_owned(), DocAddress::new(0, 1)), + (Some("tokyo".to_owned()), DocAddress::new(0, 2)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), ] ); @@ -1648,33 +1479,33 @@ mod tests { assert_eq!( &query(&index, Order::Desc, 2, 1)?, &[ - ("greenville".to_owned(), DocAddress::new(0, 1)), - ("austin".to_owned(), DocAddress::new(0, 0)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), + (Some("austin".to_owned()), DocAddress::new(0, 0)), ] ); assert_eq!( &query(&index, Order::Asc, 3, 0)?, &[ - ("austin".to_owned(), DocAddress::new(0, 0)), - ("greenville".to_owned(), DocAddress::new(0, 1)), - ("tokyo".to_owned(), DocAddress::new(0, 2)), + (Some("austin".to_owned()), DocAddress::new(0, 0)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), + (Some("tokyo".to_owned()), DocAddress::new(0, 2)), ] ); assert_eq!( &query(&index, Order::Asc, 2, 1)?, &[ - ("greenville".to_owned(), DocAddress::new(0, 1)), - ("tokyo".to_owned(), DocAddress::new(0, 2)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), + (Some("tokyo".to_owned()), DocAddress::new(0, 2)), ] ); assert_eq!( &query(&index, Order::Asc, 2, 0)?, &[ - ("austin".to_owned(), DocAddress::new(0, 0)), - ("greenville".to_owned(), DocAddress::new(0, 1)), + (Some("austin".to_owned()), DocAddress::new(0, 0)), + (Some("greenville".to_owned()), DocAddress::new(0, 1)), ] ); @@ -1725,7 +1556,7 @@ mod tests { let term_ord = column.term_ords(doc_address.doc_id).next().unwrap(); let mut city = Vec::new(); column.dictionary().ord_to_term(term_ord, &mut city).unwrap(); - (String::try_from(city).unwrap(), doc_address) + (Some(String::try_from(city).unwrap()), doc_address) }); // Using the 
TopDocs collector should always be equivalent to sorting, skipping the @@ -1832,9 +1663,9 @@ mod tests { let field = index.schema().get_field("text").unwrap(); let query_parser = QueryParser::for_index(&index, vec![field]); let text_query = query_parser.parse_query("droopy tax").unwrap(); - let collector = TopDocs::with_limit(2) - .and_offset(1) - .custom_score(move |_segment_reader: &SegmentReader| move |doc: DocId| doc); + let collector = TopDocs::with_limit(2).and_offset(1).order_by(NoScoreFn( + move |_segment_reader: &SegmentReader| move |doc: DocId| doc, + )); let score_docs: Vec<(u32, DocAddress)> = index .reader() .unwrap() diff --git a/src/index/index_meta.rs b/src/index/index_meta.rs index 0962bd9bc3..86eaa35d6c 100644 --- a/src/index/index_meta.rs +++ b/src/index/index_meta.rs @@ -276,13 +276,14 @@ impl Default for IndexSettings { } /// The order to sort by -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)] pub enum Order { /// Ascending Order Asc, /// Descending Order Desc, } + impl Order { /// return if the Order is ascending pub fn is_asc(&self) -> bool { From 01bc7f22d222e3061fde3961123ff3f90bb21e30 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 1 Nov 2025 11:51:09 +0100 Subject: [PATCH 09/11] renaming bizarro --- examples/warmer.rs | 3 +- src/collector/sort_key/mod.rs | 102 +++++++++++++++----- src/collector/sort_key/sort_key_computer.rs | 95 +++++++++--------- src/collector/sort_key_top_collector.rs | 39 +++++++- src/collector/top_collector.rs | 39 +++++--- src/collector/top_score_collector.rs | 88 ++++++++++++++--- 6 files changed, 257 insertions(+), 109 deletions(-) diff --git a/examples/warmer.rs b/examples/warmer.rs index 43237c3578..b08c0c502a 100644 --- a/examples/warmer.rs +++ b/examples/warmer.rs @@ -2,7 +2,6 @@ use std::cmp::Reverse; use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock, Weak}; -use 
tantivy::collector::sort_key::NoScoreFn; use tantivy::collector::TopDocs; use tantivy::index::SegmentId; use tantivy::query::QueryParser; @@ -165,7 +164,7 @@ fn main() -> tantivy::Result<()> { move |doc_id: DocId| Reverse(price[doc_id as usize]) }; - let most_expensive_first = TopDocs::with_limit(10).order_by(NoScoreFn(score_by_price)); + let most_expensive_first = TopDocs::with_limit(10).order_by_no_score_fn(score_by_price); let hits = searcher.search(&query, &most_expensive_first)?; assert_eq!( diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index b44b7f13bc..f81faf6b2e 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -1,7 +1,7 @@ mod sort_key_computer; use columnar::StrColumn; -pub use sort_key_computer::{NoScoreFn, SegmentSortKeyComputer, SortKeyComputer}; +pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer}; use crate::termdict::TermOrdinal; use crate::{DocId, Order, Score}; @@ -19,6 +19,10 @@ where self.0.requires_scoring() } + fn order(&self) -> Order { + self.1 + } + fn segment_sort_key_computer( &self, segment_reader: &crate::SegmentReader, @@ -32,51 +36,103 @@ impl SegmentSortKeyComputer for (TSegmentSortKeyComputer, Order) where TSegmentSortKeyComputer: SegmentSortKeyComputer, - TSegmentSortKey: BizarroWorldInvolution + PartialOrd + Clone + 'static + Sync + Send, + TSegmentSortKey: ReverseOrder + + PartialOrd + + Clone + + 'static + + Sync + + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; fn sort_key(&mut self, doc: DocId, score: Score) -> Self::SegmentSortKey { let sort_key = self.0.sort_key(doc, score); - sort_key.involution_if_asc(self.1) + reverse_if_asc(sort_key, self.1) } - fn convert_segment_sort_key(&self, bizarro_sort_key: Self::SegmentSortKey) -> Self::SortKey { - let sort_key = bizarro_sort_key.involution_if_asc(self.1); + fn convert_segment_sort_key(&self, reverse_sort_key: Self::SegmentSortKey) -> Self::SortKey { + 
let sort_key = reverse_if_asc(reverse_sort_key, self.1); self.0.convert_segment_sort_key(sort_key) } } -// BizarroWorldInvolution is a transformation that flips the order of a value. +// ReverseOrder is a trait that flips the order of a value to match the +// expectation of sorting by "ascending order". // -// It is useful when we want to sort things in a way that is given to use dynamically -// (unknown at compile time). +// From some type, it can differ a little from just applying `std::cmp::Reverse`. +// In particular, for `Option`, the reverse order is not that of `std::cmp::Reverse>`, +// but rather `Option>`: +// Users typically still expect items without a value to appear at the end of the list. // -// That way the segment score keeps having the same type regardless of the order, -// and we do not rely on an enum. -trait BizarroWorldInvolution: Copy { - fn involution(&self) -> Self; - fn involution_if_asc(&self, order: Order) -> Self { - match order { - Order::Asc => self.involution(), - Order::Desc => *self, - } +// Also, when trying to apply an order dynamically (e.g. the order was passed by an API) +// we do not necessarily have the luxury to have a specific type for the new key. +// +// We then rely on an ReverseOrder implementation with a ReverseOrderType that maps to Self. 
+pub trait ReverseOrder: Clone { + type ReverseOrderType: PartialOrd + Clone; + + fn to_reverse_type(self) -> Self::ReverseOrderType; + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self; +} + +fn reverse_if_asc(value: T, order: Order) -> T +where T: ReverseOrder { + match order { + Order::Asc => value.to_reverse_type(), + Order::Desc => value, } } -impl BizarroWorldInvolution for u64 { - fn involution(&self) -> Self { +impl ReverseOrder for u64 { + type ReverseOrderType = u64; + + fn to_reverse_type(self) -> Self::ReverseOrderType { u64::MAX - self } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + reverse_value.to_reverse_type() + } +} + +impl ReverseOrder for u32 { + type ReverseOrderType = u32; + + fn to_reverse_type(self) -> Self::ReverseOrderType { + u32::MAX - self + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + reverse_value.to_reverse_type() + } +} + +impl ReverseOrder for f32 { + type ReverseOrderType = f32; + + fn to_reverse_type(self) -> Self::ReverseOrderType { + f32::MAX - self + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + // That's an involution + reverse_value.to_reverse_type() + } } // The point here is that for Option, we do not want None values to come on top // when running a Asc query. 
-impl BizarroWorldInvolution for Option { - #[inline] - fn involution(&self) -> Self { - self.map(|val| val.involution()) +impl ReverseOrder for Option { + type ReverseOrderType = Option; + + fn to_reverse_type(self) -> Self::ReverseOrderType { + self.map(|val| val.to_reverse_type()) + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + reverse_value.map(T::from_reverse_type) } } diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index 9ba12deb27..c23a3522e0 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; -use crate::{DocId, Result, Score, SegmentReader}; +use crate::collector::sort_key::ReverseOrder; +use crate::{DocId, Order, Result, Score, SegmentReader}; /// A `SegmentSortKeyComputer` makes it possible to modify the default score /// for a given document belonging to a specific segment. @@ -77,6 +78,10 @@ pub trait SortKeyComputer: Sync { false } + fn order(&self) -> Order { + Order::Desc + } + /// Builds a child sort key computer for a specific segment. 
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result; } @@ -101,6 +106,45 @@ where } } +impl ReverseOrder for std::cmp::Reverse { + type ReverseOrderType = T; + + fn to_reverse_type(self) -> Self::ReverseOrderType { + self.0 + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + Self(reverse_value) + } +} + +impl ReverseOrder for String { + type ReverseOrderType = std::cmp::Reverse; + + fn to_reverse_type(self) -> Self::ReverseOrderType { + std::cmp::Reverse(self) + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + reverse_value.0 + } +} + +impl ReverseOrder for (Left, Right) { + type ReverseOrderType = (Left::ReverseOrderType, Right::ReverseOrderType); + + fn to_reverse_type(self) -> Self::ReverseOrderType { + (self.0.to_reverse_type(), self.1.to_reverse_type()) + } + + fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + ( + Left::from_reverse_type(reverse_value.0), + Right::from_reverse_type(reverse_value.1), + ) + } +} + impl SegmentSortKeyComputer for (HeadSegmentSortKeyComputer, TailSegmentSortKeyComputer) where @@ -315,52 +359,3 @@ where sort_key } } - -/// Helper struct to make it possible to define a sort key computer that does not use -/// the similary score from a simple function. 
-pub struct NoScoreFn(pub F); - -impl SortKeyComputer for NoScoreFn -where - F: 'static + Send + Sync + Fn(&SegmentReader) -> TNoScoreSortKeyFn, - TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, - TSortKey: 'static + PartialOrd + Clone + Send + Sync, -{ - type SortKey = TSortKey; - type Child = NoScoreSegmentSortKeyComputer; - - fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result { - Ok({ - NoScoreSegmentSortKeyComputer { - sort_key_fn: (self.0)(segment_reader), - } - }) - } - - fn requires_scoring(&self) -> bool { - false - } -} - -pub struct NoScoreSegmentSortKeyComputer { - sort_key_fn: TNoScoreSortKeyFn, -} - -impl SegmentSortKeyComputer - for NoScoreSegmentSortKeyComputer -where - TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, - TSortKey: 'static + PartialOrd + Clone + Send + Sync, -{ - type SortKey = TSortKey; - type SegmentSortKey = TSortKey; - - fn sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { - (self.sort_key_fn)(doc) - } - - /// Convert a segment level score into the global level score. 
- fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { - sort_key - } -} diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index f6a6309e28..eeae423ca6 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -1,7 +1,7 @@ -use crate::collector::sort_key::{SegmentSortKeyComputer, SortKeyComputer}; -use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; +use crate::collector::sort_key::{ReverseOrder, SegmentSortKeyComputer, SortKeyComputer}; +use crate::collector::top_collector::{merge_fruits, TopCollector, TopSegmentCollector}; use crate::collector::{Collector, SegmentCollector}; -use crate::{DocAddress, DocId, Result, Score, SegmentReader}; +use crate::{DocAddress, DocId, Order, Result, Score, SegmentReader}; pub(crate) struct TopBySortKeyCollector { sort_key_computer: TSortKeyComputer, @@ -25,7 +25,7 @@ where TSortKey: Clone + PartialOrd impl Collector for TopBySortKeyCollector where TSortKeyComputer: SortKeyComputer + Send + Sync, - TSortKey: 'static + Send + PartialOrd + Sync + Clone, + TSortKey: 'static + Send + PartialOrd + Sync + Clone + ReverseOrder, { type Fruit = Vec<(TSortKeyComputer::SortKey, DocAddress)>; @@ -51,7 +51,36 @@ where } fn merge_fruits(&self, segment_fruits: Vec) -> Result { - self.collector.merge_fruits(segment_fruits) + let order = self.sort_key_computer.order(); + match order { + Order::Asc => { + let reverse_segment_fruits: Vec< + Vec<( + ::ReverseOrderType, + DocAddress, + )>, + > = segment_fruits + .into_iter() + .map(|vec| { + vec.into_iter() + .map(|(sort_key, doc_addr)| (sort_key.to_reverse_type(), doc_addr)) + .collect() + }) + .collect(); + let merged_reverse_fruits = merge_fruits( + reverse_segment_fruits, + self.collector.limit, + self.collector.offset, + )?; + Ok(merged_reverse_fruits + .into_iter() + .map(|(reverse_sort_key, doc_addr)| { + (TSortKey::from_reverse_type(reverse_sort_key), 
doc_addr) + }) + .collect()) + } + Order::Desc => self.collector.merge_fruits(segment_fruits), + } } } diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index eeeff7473d..0cb72de332 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -104,22 +104,7 @@ where T: PartialOrd + Clone &self, children: Vec>, ) -> crate::Result> { - if self.limit == 0 { - return Ok(Vec::new()); - } - let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset); - for child_fruit in children { - for (feature, doc) in child_fruit { - top_collector.push(feature, doc); - } - } - - Ok(top_collector - .into_sorted_vec() - .into_iter() - .skip(self.offset) - .map(|cdoc| (cdoc.sort_key, cdoc.doc)) - .collect()) + merge_fruits(children, self.limit, self.offset) } pub(crate) fn for_segment( @@ -146,6 +131,28 @@ where T: PartialOrd + Clone } } +pub fn merge_fruits( + children: Vec>, + limit: usize, + offset: usize, +) -> crate::Result> { + if limit == 0 { + return Ok(Vec::new()); + } + let mut top_collector: TopNComputer = TopNComputer::new(limit + offset); + for child_fruit in children { + for (feature, doc) in child_fruit { + top_collector.push(feature, doc); + } + } + Ok(top_collector + .into_sorted_vec() + .into_iter() + .skip(offset) + .map(|cdoc| (cdoc.sort_key, cdoc.doc)) + .collect()) +} + /// The Top Collector keeps track of the K documents /// sorted by type `T`. 
/// diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 54404435e3..8866265641 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -6,7 +6,7 @@ use columnar::ColumnValues; use serde::{Deserialize, Serialize}; use super::Collector; -use crate::collector::sort_key::ByStringColumn; +use crate::collector::sort_key::{ByStringColumn, ReverseOrder}; use crate::collector::sort_key_top_collector::TopBySortKeyCollector; use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; use crate::collector::{SegmentCollector, SegmentSortKeyComputer, SortKeyComputer}; @@ -21,7 +21,6 @@ struct FastFieldConvertCollector< pub collector: TCollector, pub field: String, pub fast_value: std::marker::PhantomData, - order: Order, } impl Collector for FastFieldConvertCollector @@ -70,11 +69,12 @@ where let transformed_result = raw_result .into_iter() .map(|(score, doc_address)| { - if self.order.is_desc() { - (TFastValue::from_u64(score), doc_address) - } else { - (TFastValue::from_u64(u64::MAX - score), doc_address) - } + (TFastValue::from_u64(score), doc_address) + // if self.order.is_desc() { + // (TFastValue::from_u64(score), doc_address) + // } else { + // (TFastValue::from_u64(u64::MAX - score), doc_address) + // } }) .collect::>(); Ok(transformed_result) @@ -421,7 +421,6 @@ impl TopDocs { collector: u64_collector, field: fast_field.to_string(), fast_value: PhantomData, - order, } } @@ -536,10 +535,74 @@ impl TopDocs { sort_key_computer: impl SortKeyComputer + Send + Sync, ) -> impl Collector> where - TSortKey: 'static + Clone + Send + Sync + PartialOrd, + TSortKey: 'static + Clone + Send + Sync + PartialOrd + ReverseOrder, { TopBySortKeyCollector::new(sort_key_computer, self.0.into_different_sort_key_type()) } + + pub fn order_by_no_score_fn( + self, + sort_key_fn: F, + ) -> impl Collector> + where + F: 'static + Send + Sync + Fn(&SegmentReader) -> TNoScoreSortKeyFn, + 
TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync + ReverseOrder, + { + self.order_by(NoScoreFn(sort_key_fn)) + } +} + +/// Helper struct to make it possible to define a sort key computer that does not use +/// the similary score from a simple function. +struct NoScoreFn(pub F); + +impl SortKeyComputer for NoScoreFn +where + F: 'static + Send + Sync + Fn(&SegmentReader) -> TNoScoreSortKeyFn, + TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync + ReverseOrder, +{ + type SortKey = TSortKey; + type Child = NoScoreSegmentSortKeyComputer; + + fn segment_sort_key_computer( + &self, + segment_reader: &SegmentReader, + ) -> crate::Result { + Ok({ + NoScoreSegmentSortKeyComputer { + sort_key_fn: (self.0)(segment_reader), + } + }) + } + + fn requires_scoring(&self) -> bool { + false + } +} + +struct NoScoreSegmentSortKeyComputer { + sort_key_fn: TNoScoreSortKeyFn, +} + +impl SegmentSortKeyComputer + for NoScoreSegmentSortKeyComputer +where + TNoScoreSortKeyFn: 'static + Fn(DocId) -> TSortKey, + TSortKey: 'static + PartialOrd + Clone + Send + Sync, +{ + type SortKey = TSortKey; + type SegmentSortKey = TSortKey; + + fn sort_key(&mut self, doc: DocId, _score: Score) -> TSortKey { + (self.sort_key_fn)(doc) + } + + /// Convert a segment level score into the global level score. 
+ fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { + sort_key + } } impl Collector for TopDocs { @@ -823,7 +886,6 @@ mod tests { use proptest::prelude::*; use super::{TopDocs, TopNComputer}; - use crate::collector::sort_key::NoScoreFn; use crate::collector::top_collector::ComparableDoc; use crate::collector::{Collector, DocSetCollector}; use crate::query::{AllQuery, Query, QueryParser}; @@ -1663,9 +1725,9 @@ mod tests { let field = index.schema().get_field("text").unwrap(); let query_parser = QueryParser::for_index(&index, vec![field]); let text_query = query_parser.parse_query("droopy tax").unwrap(); - let collector = TopDocs::with_limit(2).and_offset(1).order_by(NoScoreFn( - move |_segment_reader: &SegmentReader| move |doc: DocId| doc, - )); + let collector = TopDocs::with_limit(2) + .and_offset(1) + .order_by_no_score_fn(move |_segment_reader: &SegmentReader| move |doc: DocId| doc); let score_docs: Vec<(u32, DocAddress)> = index .reader() .unwrap() From 2fba123de28934ce83c9f936dd2f15936f788b90 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sat, 1 Nov 2025 16:12:50 +0100 Subject: [PATCH 10/11] fixing unit test --- src/collector/sort_key/mod.rs | 47 +++-- src/collector/sort_key/sort_key_computer.rs | 18 +- src/collector/sort_key_top_collector.rs | 2 +- src/collector/top_score_collector.rs | 185 ++++++-------------- 4 files changed, 90 insertions(+), 162 deletions(-) diff --git a/src/collector/sort_key/mod.rs b/src/collector/sort_key/mod.rs index f81faf6b2e..7ef99ba5dc 100644 --- a/src/collector/sort_key/mod.rs +++ b/src/collector/sort_key/mod.rs @@ -3,6 +3,7 @@ mod sort_key_computer; use columnar::StrColumn; pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer}; +use crate::fastfield::FastValue; use crate::termdict::TermOrdinal; use crate::{DocId, Order, Score}; @@ -36,12 +37,8 @@ impl SegmentSortKeyComputer for (TSegmentSortKeyComputer, Order) where TSegmentSortKeyComputer: SegmentSortKeyComputer, - 
TSegmentSortKey: ReverseOrder - + PartialOrd - + Clone - + 'static - + Sync - + Send, + TSegmentSortKey: + ReverseOrder + PartialOrd + Clone + 'static + Sync + Send, { type SortKey = TSegmentSortKeyComputer::SortKey; type SegmentSortKey = TSegmentSortKey; @@ -70,53 +67,55 @@ where // // We then rely on an ReverseOrder implementation with a ReverseOrderType that maps to Self. pub trait ReverseOrder: Clone { - type ReverseOrderType: PartialOrd + Clone; + type ReverseType: PartialOrd + Clone; - fn to_reverse_type(self) -> Self::ReverseOrderType; + fn to_reverse_type(self) -> Self::ReverseType; - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self; + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self; } fn reverse_if_asc(value: T, order: Order) -> T -where T: ReverseOrder { +where T: ReverseOrder { match order { Order::Asc => value.to_reverse_type(), Order::Desc => value, } } -impl ReverseOrder for u64 { - type ReverseOrderType = u64; +impl ReverseOrder for TFastValue { + type ReverseType = TFastValue; - fn to_reverse_type(self) -> Self::ReverseOrderType { - u64::MAX - self + fn to_reverse_type(self) -> Self::ReverseType { + // TODO check that the compiler is good enough to compile that to i64::MAX - self for i64 + // for instance. 
+ TFastValue::from_u64(u64::MAX - self.to_u64()) } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { reverse_value.to_reverse_type() } } impl ReverseOrder for u32 { - type ReverseOrderType = u32; + type ReverseType = u32; - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { u32::MAX - self } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { reverse_value.to_reverse_type() } } impl ReverseOrder for f32 { - type ReverseOrderType = f32; + type ReverseType = f32; - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { f32::MAX - self } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { // That's an involution reverse_value.to_reverse_type() } @@ -125,13 +124,13 @@ impl ReverseOrder for f32 { // The point here is that for Option, we do not want None values to come on top // when running a Asc query. 
impl ReverseOrder for Option { - type ReverseOrderType = Option; + type ReverseType = Option; - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { self.map(|val| val.to_reverse_type()) } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { reverse_value.map(T::from_reverse_type) } } diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index c23a3522e0..d0fc569e8b 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -107,37 +107,37 @@ where } impl ReverseOrder for std::cmp::Reverse { - type ReverseOrderType = T; + type ReverseType = T; - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { self.0 } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { Self(reverse_value) } } impl ReverseOrder for String { - type ReverseOrderType = std::cmp::Reverse; + type ReverseType = std::cmp::Reverse; - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { std::cmp::Reverse(self) } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { reverse_value.0 } } impl ReverseOrder for (Left, Right) { - type ReverseOrderType = (Left::ReverseOrderType, Right::ReverseOrderType); + type ReverseType = (Left::ReverseType, Right::ReverseType); - fn to_reverse_type(self) -> Self::ReverseOrderType { + fn to_reverse_type(self) -> Self::ReverseType { (self.0.to_reverse_type(), self.1.to_reverse_type()) } - fn from_reverse_type(reverse_value: Self::ReverseOrderType) -> Self { + fn from_reverse_type(reverse_value: Self::ReverseType) -> Self { ( Left::from_reverse_type(reverse_value.0), 
Right::from_reverse_type(reverse_value.1), diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index eeae423ca6..6f4b127dfb 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -56,7 +56,7 @@ where Order::Asc => { let reverse_segment_fruits: Vec< Vec<( - ::ReverseOrderType, + ::ReverseType, DocAddress, )>, > = segment_fruits diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 8866265641..afbdfa9224 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -1,8 +1,7 @@ use std::fmt; use std::marker::PhantomData; -use std::sync::Arc; -use columnar::ColumnValues; +use columnar::Column; use serde::{Deserialize, Serialize}; use super::Collector; @@ -12,74 +11,7 @@ use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCol use crate::collector::{SegmentCollector, SegmentSortKeyComputer, SortKeyComputer}; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; use crate::query::Weight; -use crate::{DocAddress, DocId, Order, Score, SegmentOrdinal, SegmentReader, TantivyError}; - -struct FastFieldConvertCollector< - TCollector: Collector>, - TFastValue: FastValue, -> { - pub collector: TCollector, - pub field: String, - pub fast_value: std::marker::PhantomData, -} - -impl Collector for FastFieldConvertCollector -where - TCollector: Collector>, - TFastValue: FastValue, -{ - type Fruit = Vec<(TFastValue, DocAddress)>; - - type Child = TCollector::Child; - - fn for_segment( - &self, - segment_local_id: crate::SegmentOrdinal, - segment: &SegmentReader, - ) -> crate::Result { - let schema = segment.schema(); - let field = schema.get_field(&self.field)?; - let field_entry = schema.get_field_entry(field); - if !field_entry.is_fast() { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is not a fast field.", - field_entry.name() - ))); - } - let schema_type = 
TFastValue::to_type(); - let requested_type = field_entry.field_type().value_type(); - if schema_type != requested_type { - return Err(TantivyError::SchemaError(format!( - "Field {:?} is of type {schema_type:?}!={requested_type:?}", - field_entry.name() - ))); - } - self.collector.for_segment(segment_local_id, segment) - } - - fn requires_scoring(&self) -> bool { - self.collector.requires_scoring() - } - - fn merge_fruits( - &self, - segment_fruits: Vec<::Fruit>, - ) -> crate::Result { - let raw_result = self.collector.merge_fruits(segment_fruits)?; - let transformed_result = raw_result - .into_iter() - .map(|(score, doc_address)| { - (TFastValue::from_u64(score), doc_address) - // if self.order.is_desc() { - // (TFastValue::from_u64(score), doc_address) - // } else { - // (TFastValue::from_u64(u64::MAX - score), doc_address) - // } - }) - .collect::>(); - Ok(transformed_result) - } -} +use crate::{DocAddress, DocId, Order, Score, SegmentOrdinal, SegmentReader}; /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. 
@@ -137,43 +69,34 @@ impl fmt::Debug for TopDocs { } } -struct SortKeyByFastFieldReader { - sort_column: Arc>, - order: Order, +struct SortKeyByFastFieldReader { + sort_column: Column, + typ: PhantomData, } -impl SegmentSortKeyComputer for SortKeyByFastFieldReader { - type SortKey = u64; +impl SegmentSortKeyComputer for SortKeyByFastFieldReader { + type SortKey = Option; - type SegmentSortKey = u64; + type SegmentSortKey = Option; fn sort_key(&mut self, doc: DocId, _score: Score) -> Self::SegmentSortKey { - let val = self.sort_column.get_val(doc); - if self.order == Order::Desc { - u64::MAX - val - } else { - val - } + self.sort_column.first(doc) } fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey { - if self.order == Order::Desc { - u64::MAX - sort_key - } else { - sort_key - } + sort_key.map(T::from_u64) } } -struct ScorerByField { +struct ScorerByField { field: String, - order: Order, + typ: PhantomData, } -impl SortKeyComputer for ScorerByField { - type Child = SortKeyByFastFieldReader; +impl SortKeyComputer for ScorerByField { + type Child = SortKeyByFastFieldReader; - type SortKey = u64; + type SortKey = Option; fn segment_sort_key_computer( &self, @@ -189,13 +112,9 @@ impl SortKeyComputer for ScorerByField { sort_column_opt.ok_or_else(|| FastFieldNotAvailableError { field_name: self.field.clone(), })?; - let mut default_value = 0u64; - if self.order.is_asc() { - default_value = u64::MAX; - } Ok(SortKeyByFastFieldReader { - sort_column: sort_column.first_or_default_col(default_value), - order: self.order.clone(), + sort_column: sort_column, + typ: PhantomData, }) } } @@ -329,12 +248,15 @@ impl TopDocs { self, field: impl ToString, order: Order, - ) -> impl Collector> { + ) -> impl Collector, DocAddress)>> { TopBySortKeyCollector::new( - ScorerByField { - field: field.to_string(), + ( + ScorerByField { + field: field.to_string(), + typ: PhantomData, + }, order, - }, + ), self.0.into_different_sort_key_type(), ) } @@ -412,16 
+334,20 @@ impl TopDocs { self, fast_field: impl ToString, order: Order, - ) -> impl Collector> + ) -> impl Collector, DocAddress)>> where - TFastValue: FastValue, + TFastValue: FastValue + ReverseOrder, { - let u64_collector = self.order_by_u64_field(fast_field.to_string(), order); - FastFieldConvertCollector { - collector: u64_collector, - field: fast_field.to_string(), - fast_value: PhantomData, - } + TopBySortKeyCollector::new( + ( + ScorerByField { + field: fast_field.to_string(), + typ: PhantomData, + }, + order, + ), + self.0.into_different_sort_key_type(), + ) } /// Like `order_by_fast_field`, but for a `String` fast field. @@ -1378,13 +1304,13 @@ mod tests { let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(4).order_by_u64_field(SIZE, Order::Desc); - let top_docs: Vec<(u64, DocAddress)> = searcher.search(&query, &top_collector)?; + let top_docs: Vec<(Option, DocAddress)> = searcher.search(&query, &top_collector)?; assert_eq!( &top_docs[..], &[ - (64, DocAddress::new(0, 1)), - (16, DocAddress::new(0, 2)), - (12, DocAddress::new(0, 0)) + (Some(64), DocAddress::new(0, 1)), + (Some(16), DocAddress::new(0, 2)), + (Some(12), DocAddress::new(0, 0)) ] ); Ok(()) @@ -1417,12 +1343,13 @@ mod tests { index_writer.commit()?; let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(3).order_by_fast_field("birthday", Order::Desc); - let top_docs: Vec<(DateTime, DocAddress)> = searcher.search(&AllQuery, &top_collector)?; + let top_docs: Vec<(Option, DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; assert_eq!( &top_docs[..], &[ - (mr_birthday, DocAddress::new(0, 1)), - (pr_birthday, DocAddress::new(0, 0)), + (Some(mr_birthday), DocAddress::new(0, 1)), + (Some(pr_birthday), DocAddress::new(0, 0)), ] ); Ok(()) @@ -1447,12 +1374,13 @@ mod tests { index_writer.commit()?; let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(3).order_by_fast_field("altitude", Order::Desc); - let 
top_docs: Vec<(i64, DocAddress)> = searcher.search(&AllQuery, &top_collector)?; + let top_docs: Vec<(Option, DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; assert_eq!( &top_docs[..], &[ - (40i64, DocAddress::new(0, 1)), - (-1i64, DocAddress::new(0, 0)), + (Some(40i64), DocAddress::new(0, 1)), + (Some(-1i64), DocAddress::new(0, 0)), ] ); Ok(()) @@ -1477,12 +1405,13 @@ mod tests { index_writer.commit()?; let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(3).order_by_fast_field("altitude", Order::Desc); - let top_docs: Vec<(f64, DocAddress)> = searcher.search(&AllQuery, &top_collector)?; + let top_docs: Vec<(Option, DocAddress)> = + searcher.search(&AllQuery, &top_collector)?; assert_eq!( &top_docs[..], &[ - (40f64, DocAddress::new(0, 1)), - (-1.0f64, DocAddress::new(0, 0)), + (Some(40f64), DocAddress::new(0, 1)), + (Some(-1.0f64), DocAddress::new(0, 0)), ] ); Ok(()) @@ -1789,14 +1718,14 @@ mod tests { let searcher = index.reader()?.searcher(); let top_collector = TopDocs::with_limit(4).order_by_fast_field(SIZE, Order::Asc); - let top_docs: Vec<(u64, DocAddress)> = searcher.search(&query, &top_collector)?; + let top_docs: Vec<(Option, DocAddress)> = searcher.search(&query, &top_collector)?; assert_eq!( &top_docs[..], &[ - (12, DocAddress::new(0, 0)), - (16, DocAddress::new(0, 2)), - (64, DocAddress::new(0, 1)), - (18446744073709551615, DocAddress::new(0, 3)), + (Some(12), DocAddress::new(0, 0)), + (Some(16), DocAddress::new(0, 2)), + (Some(64), DocAddress::new(0, 1)), + (None, DocAddress::new(0, 3)), ] ); Ok(()) From e73c945c67f2076b0839f1eb17de65e775f684c5 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Sat, 1 Nov 2025 15:46:46 -0700 Subject: [PATCH 11/11] Implement `collect_block` for lazy scorers. 
--- src/collector/sort_key/sort_key_computer.rs | 39 +++++++++++- src/collector/sort_key_top_collector.rs | 5 ++ src/collector/top_collector.rs | 9 +++ src/collector/top_score_collector.rs | 67 +++++++++++++++++++-- 4 files changed, 112 insertions(+), 8 deletions(-) diff --git a/src/collector/sort_key/sort_key_computer.rs b/src/collector/sort_key/sort_key_computer.rs index d0fc569e8b..78f26dd9ff 100644 --- a/src/collector/sort_key/sort_key_computer.rs +++ b/src/collector/sort_key/sort_key_computer.rs @@ -1,6 +1,8 @@ use std::cmp::Ordering; +use crate::collector::ComparableDoc; use crate::collector::sort_key::ReverseOrder; +use crate::collector::top_score_collector::push_assuming_capacity; use crate::{DocId, Order, Result, Score, SegmentReader}; /// A `SegmentSortKeyComputer` makes it possible to modify the default score @@ -23,7 +25,8 @@ pub trait SegmentSortKeyComputer: 'static { /// Returns true if the `SegmentSortKeyComputer` is a good candidate for the lazy evaluation /// optimization. See [`SegmentSortKeyComputer::accept_score_lazy`]. fn is_lazy() -> bool { - false + // TODO: Without this, we don't currently have test coverage for laziness. + true } /// Implementing this method makes it possible to avoid computing @@ -43,9 +46,9 @@ pub trait SegmentSortKeyComputer: 'static { threshold: &Self::SegmentSortKey, ) -> Option<(std::cmp::Ordering, Self::SegmentSortKey)> { let excluded_ordering = if REVERSE_ORDER { - Ordering::Greater - } else { Ordering::Less + } else { + Ordering::Greater }; let sort_key = self.sort_key(doc_id, score); let cmp = sort_key.partial_cmp(threshold).unwrap_or(excluded_ordering); @@ -56,6 +59,36 @@ pub trait SegmentSortKeyComputer: 'static { } } + /// Similar to `accept_sort_key_lazy`, but pushes results directly into the given buffer. + /// + /// The buffer must have at least enough capacity for `docs` matches, or this method will + /// panic. 
+ fn accept_sort_key_block_lazy( + &mut self, + docs: &[DocId], + threshold: &Self::SegmentSortKey, + output: &mut Vec>, + ) { + let excluded_ordering = if REVERSE_ORDER { + Ordering::Less + } else { + Ordering::Greater + }; + for &doc in docs { + let sort_key = self.sort_key(doc, 0.0); + let cmp = sort_key.partial_cmp(threshold).unwrap_or(excluded_ordering); + if cmp != excluded_ordering { + push_assuming_capacity( + ComparableDoc { + sort_key, + doc, + }, + output, + ); + } + } + } + /// Convert a segment level sort key into the global sort key. fn convert_segment_sort_key(&self, sort_key: Self::SegmentSortKey) -> Self::SortKey; } diff --git a/src/collector/sort_key_top_collector.rs b/src/collector/sort_key_top_collector.rs index 6f4b127dfb..e2035c505e 100644 --- a/src/collector/sort_key_top_collector.rs +++ b/src/collector/sort_key_top_collector.rs @@ -102,6 +102,11 @@ where TSegmentSortKeyComputer: 'static + SegmentSortKeyComputer .collect_lazy(doc, score, &mut self.segment_sort_key_computer); } + fn collect_block(&mut self, docs: &[DocId]) { + self.segment_collector + .collect_block_lazy(docs, &mut self.segment_sort_key_computer); + } + fn harvest(self) -> Self::Fruit { let segment_hits: Vec<(TSegmentSortKeyComputer::SegmentSortKey, DocAddress)> = self.segment_collector.harvest(); diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index 0cb72de332..9da49792d9 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -202,6 +202,15 @@ impl TopSegmentCollector { self.topn_computer.push(feature, doc); } + #[inline] + pub fn collect_block_lazy( + &mut self, + docs: &[DocId], + segment_scorer: &mut impl SegmentSortKeyComputer, + ) { + self.topn_computer.push_block_lazy(docs, segment_scorer); + } + #[inline] pub fn collect_lazy( &mut self, diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index afbdfa9224..420a0608e3 100644 --- a/src/collector/top_score_collector.rs +++ 
b/src/collector/top_score_collector.rs @@ -693,7 +693,9 @@ where /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `2 * top_n`. pub fn new(top_n: usize) -> Self { - let vec_cap = top_n.max(1) * 2; + // We ensure that there is always enough space to include an entire block in the buffer if + // need be, so that `push_block_lazy` can avoid checking capacity inside its loop. + let vec_cap = (top_n.max(1) * 2) + crate::COLLECT_BLOCK_BUFFER_LEN; TopNComputer { buffer: Vec::with_capacity(vec_cap), top_n, @@ -775,6 +777,12 @@ where TScore: PartialOrd + Clone else { return; }; + + if self.buffer.len() == self.buffer.capacity() { + let median = self.truncate_top_n(); + self.threshold = Some(median); + } + push_assuming_capacity( ComparableDoc { sort_key: feature, @@ -789,13 +797,62 @@ where TScore: PartialOrd + Clone self.push(feature, doc); return; } + + #[inline(always)] + pub(crate) fn push_block_lazy< + TSegmentSortKeyComputer: SegmentSortKeyComputer, + >( + &mut self, + docs: &[DocId], + score_tweaker: &mut TSegmentSortKeyComputer, + ) { + // If the addition of this block might push us over capacity, start by truncating: our + // capacity is larger than 2*n + COLLECT_BLOCK_BUFFER_LEN, so this always makes enough room + // for the entire block (although some of the block might be eliminated). + if self.buffer.len() + docs.len() > self.buffer.capacity() { + let median = self.truncate_top_n(); + self.threshold = Some(median); + } + + if let Some(last_median) = self.threshold.clone() { + if TSegmentSortKeyComputer::is_lazy() { + // We validated at the top of the method that we have capacity. + score_tweaker.accept_sort_key_block_lazy::(docs, &last_median, &mut self.buffer); + return; + } + + // Eagerly push, with a threshold to compare to. 
+ for &doc in docs { + let sort_key = score_tweaker.sort_key(doc, 0.0); + + if !REVERSE_ORDER && sort_key > last_median { + continue; + } + if REVERSE_ORDER && sort_key < last_median { + continue; + } + + // We validated at the top of the method that we have capacity. + let comparable_doc = ComparableDoc { doc, sort_key }; + push_assuming_capacity(comparable_doc, &mut self.buffer); + } + } else { + // Eagerly push, without a threshold to compare to. + for &doc in docs { + let sort_key = score_tweaker.sort_key(doc, 0.0); + // We validated at the top of the method that we have capacity. + let comparable_doc = ComparableDoc { doc, sort_key }; + push_assuming_capacity(comparable_doc, &mut self.buffer); + } + } + } } // Push an element provided there is enough capacity to do so. // // Panics if there is not enough capacity to add an element. #[inline(always)] -fn push_assuming_capacity(el: T, buf: &mut Vec) { +pub fn push_assuming_capacity(el: T, buf: &mut Vec) { let prev_len = buf.len(); assert!(prev_len < buf.capacity()); // This is mimicking the current (non-stabilized) implementation in std. @@ -1509,11 +1566,11 @@ mod tests { #[test] fn test_top_field_collect_string_prop( order in prop_oneof!(Just(Order::Desc), Just(Order::Asc)), - limit in 1..256_usize, - offset in 0..256_usize, + limit in 1..32_usize, + offset in 0..32_usize, segments_terms in proptest::collection::vec( - proptest::collection::vec(0..32_u8, 1..32_usize), + proptest::collection::vec(0..64_u8, 1..256_usize), 0..8_usize, ) ) {