diff --git a/Cargo.lock b/Cargo.lock index 6d18232..4e568ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -916,9 +916,11 @@ name = "onpair" version = "0.0.4" dependencies = [ "arrow-array", + "arrow-buffer", "arrow-schema", "codspeed-divan-compat", "hashbrown 0.16.1", + "memchr", "parquet", "rand", "rstest", diff --git a/Cargo.toml b/Cargo.toml index 487d1bb..152502c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,9 @@ rand = "0.9.0" [dev-dependencies] divan = { package = "codspeed-divan-compat", version = "4.0.4" } rstest = "0.26.1" +memchr = "2" arrow-array = "57.1" +arrow-buffer = "57.1" arrow-schema = "57.1" parquet = "57.1" tpchgen = "2.0.2" @@ -50,3 +52,7 @@ harness = false [[bench]] name = "clickbench" harness = false + +[[bench]] +name = "search" +harness = false diff --git a/benches/search.rs b/benches/search.rs new file mode 100644 index 0000000..6f467cb --- /dev/null +++ b/benches/search.rs @@ -0,0 +1,709 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Compressed-domain search benchmark: `Pattern::Contains` / `Pattern::Prefix` +//! over a real (or synthetic) string column, never decompressing. +#![allow( + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_lossless, + clippy::cast_sign_loss, + clippy::expect_used, + clippy::missing_panics_doc, + clippy::unwrap_used +)] +// +// A pre-pass scans the corpus to bucket needles by selectivity — `rare`, +// `medium`, `common` — for both modes, so the benchmark reports how throughput +// varies with match density (a `common` needle hits the automaton's early-exit +// on most rows; a `rare` one scans almost every token). The selected needles, +// their measured selectivity, and a brute-force cross-check are printed at +// startup. +// +// Corpus resolution mirrors `clickbench.rs`: +// 1. env `ONPAIR_BENCH_PARQUET` (+ optional `ONPAIR_BENCH_COLUMN`) +// 2. `/tmp/userdata1.parquet` +// 3. a synthetic ClickBench-shaped URL corpus. +// Code width is `ONPAIR_SEARCH_BITS` (default 16). +// +// Run with: cargo bench --bench search + +use std::env; +use std::fmt; +use std::fs::File; +use std::path::PathBuf; +use std::sync::OnceLock; + +use arrow_array::Array; +use arrow_array::cast::AsArray; +use divan::Bencher; +use divan::counter::BytesCount; +use divan::counter::ItemsCount; +use onpair::Bits; +use onpair::Column; +use onpair::Config; +use onpair::Pattern; +use onpair::Threshold; +use onpair::compress; +use onpair::decompress; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + +// ───────────────────────────────────────────────────────────────────────────── +// Corpus loading (shared shape with clickbench.rs). +// ───────────────────────────────────────────────────────────────────────────── + +struct Corpus { + source: String, + rows: Vec>, + bytes: Vec, + offsets: Vec, + total_bytes: usize, +} + +fn pack(strings: &[Vec]) -> (Vec, Vec) { + let mut bytes = Vec::with_capacity(strings.iter().map(|s| s.len()).sum()); + let mut offsets = Vec::with_capacity(strings.len() + 1); + offsets.push(0u64); + for s in strings { + bytes.extend_from_slice(s); + offsets.push(bytes.len() as u64); + } + (bytes, offsets) +} + +fn corpus() -> &'static Corpus { + static CORPUS: OnceLock = OnceLock::new(); + CORPUS.get_or_init(|| { + let (source, rows) = load_corpus(); + let (bytes, offsets) = pack(&rows); + let total_bytes = bytes.len(); + let c = Corpus { + source, + rows, + bytes, + offsets, + total_bytes, + }; + eprintln!( + "[onpair search] corpus: {} ({} rows, {:.2} MiB)", + c.source, + c.rows.len(), + c.total_bytes as f64 / (1024.0 * 1024.0) + ); + c + }) +} + +fn load_corpus() -> (String, Vec>) { + if let Ok(path) = env::var("ONPAIR_BENCH_PARQUET") + && let Some(rows) = read_parquet_strings(&PathBuf::from(&path)) + { + return (format!("{path} (env)"), rows); + } + let fallback = PathBuf::from("/tmp/userdata1.parquet"); + if fallback.exists() + && let Some(rows) = read_parquet_strings(&fallback) + { + return (format!("{} (auto-detected)", fallback.display()), rows); + } + let rows = synthetic_clickbench_urls(100_000); + ("synthetic ClickBench-shaped URL corpus".to_string(), rows) +} + +fn read_parquet_strings(path: &PathBuf) -> Option>> { + let file = File::open(path).ok()?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file).ok()?; + let schema = builder.schema().clone(); + + let col_name = env::var("ONPAIR_BENCH_COLUMN").ok(); + let picked = match col_name.as_deref() { + Some(name) => schema.fields().iter().position(|f| f.name() == name)?, + None => schema.fields().iter().position(|f| { + use arrow_schema::DataType::*; + matches!( + f.data_type(), + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView + ) + })?, + }; + let col_field = schema.fields().get(picked)?.clone(); + eprintln!( + "[onpair search] reading column #{picked} `{}` ({})", + col_field.name(), + col_field.data_type() + ); + + let mut rows: Vec> = Vec::new(); + // Optional cap so huge corpora (e.g. FineWeb `text`) fit in memory. + let max_rows = env::var("ONPAIR_BENCH_MAX_ROWS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(usize::MAX); + let reader = builder.build().ok()?; + 'outer: for batch in reader.flatten() { + let arr = batch.column(picked); + use arrow_schema::DataType::*; + match arr.data_type() { + Utf8 => { + for s in arr.as_string::().iter() { + rows.push(s.unwrap_or("").as_bytes().to_vec()); + } + } + LargeUtf8 => { + for s in arr.as_string::().iter() { + rows.push(s.unwrap_or("").as_bytes().to_vec()); + } + } + Utf8View => { + for s in arr.as_string_view().iter() { + rows.push(s.unwrap_or("").as_bytes().to_vec()); + } + } + Binary => { + let a = arr.as_any().downcast_ref::()?; + for b in a.iter() { + rows.push(b.unwrap_or(b"").to_vec()); + } + } + LargeBinary => { + let a = arr + .as_any() + .downcast_ref::()?; + for b in a.iter() { + rows.push(b.unwrap_or(b"").to_vec()); + } + } + BinaryView => { + let a = arr + .as_any() + .downcast_ref::()?; + for b in a.iter() { + rows.push(b.unwrap_or(b"").to_vec()); + } + } + _ => return None, + } + if rows.len() >= max_rows { + rows.truncate(max_rows); + break 'outer; + } + } + Some(rows) +} + +fn synthetic_clickbench_urls(n: usize) -> Vec> { + const HOSTS: &[&str] = &[ + "https://www.yandex.ru", + "https://www.google.com", + "https://news.ycombinator.com", + "https://www.example.com", + "https://docs.example.org", + "https://api.example.net", + "http://m.yandex.ru", + "https://maps.example.com", + "https://shop.example.com", + "ftp://files.example.com", + ]; + const PATHS: &[&str] = &[ + "/", + "/page", + "/news", + "/search?q=", + "/profile", + "/login", + "/api/v1/data", + "/static/asset.png", + "/blog/post-", + "/feed.xml", + "/sitemap.xml", + "/users/", + "/admin/dashboard", + "/categories/electronics", + "/cart/checkout", + ]; + const TAILS: &[&str] = &["", "alpha", "beta", "gamma", "delta", "001", "002", "003"]; + let mut out = Vec::with_capacity(n); + let mut x = 0x9E3779B97F4A7C15u64; + for _ in 0..n { + x = x.wrapping_add(0x9E3779B97F4A7C15); + let h = HOSTS[(x as usize) % HOSTS.len()]; + let p = PATHS[((x >> 16) as usize) % PATHS.len()]; + let t = TAILS[((x >> 32) as usize) % TAILS.len()]; + let m = (x >> 48) as u16; + out.push(format!("{h}{p}{t}{m}").into_bytes()); + } + out +} + +// ───────────────────────────────────────────────────────────────────────────── +// Compressed column (one width, default 16). +// ───────────────────────────────────────────────────────────────────────────── + +fn search_bits() -> u8 { + env::var("ONPAIR_SEARCH_BITS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(16) +} + +fn column() -> &'static Column { + static COL: OnceLock> = OnceLock::new(); + COL.get_or_init(|| { + let c = corpus(); + let cfg = Config { + bits: Bits::new(search_bits()).unwrap(), + threshold: Threshold::new(0.5).unwrap(), + seed: Some(42), + }; + let col = compress(&c.bytes, &c.offsets, cfg).unwrap(); + let dict_b = col.dict_bytes.len() + col.dict_offsets.len() * 4; + let codes_b = col.codes.len() * 2; + let offs_b = col.code_offsets.len() * 8; + let first_b = col.first_codes.as_ref().map_or(0, |f| f.len() * 2); + let core = dict_b + codes_b + offs_b; + eprintln!( + "[onpair search] compressed @ bits={}: {} dict tokens, {} codes", + col.bits, + col.dict_offsets.len() - 1, + col.codes.len(), + ); + eprintln!( + "[onpair search] footprint: dict {:.0} KiB + codes {:.0} KiB + code_offsets {:.0} KiB = {:.0} KiB core; \ + first_codes (search index) {:.0} KiB = +{:.2}% over core, +{:.2}% over input", + dict_b as f64 / 1024.0, + codes_b as f64 / 1024.0, + offs_b as f64 / 1024.0, + core as f64 / 1024.0, + first_b as f64 / 1024.0, + 100.0 * first_b as f64 / core as f64, + 100.0 * first_b as f64 / c.total_bytes as f64, + ); + col + }) +} + +// ───────────────────────────────────────────────────────────────────────────── +// Needle pre-pass: bucket candidates by selectivity. +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Copy, Clone, PartialEq, Eq)] +enum Mode { + Contains, + Prefix, +} + +struct Needle { + bucket: &'static str, + mode: Mode, + bytes: Vec, + selectivity: f64, +} + +impl fmt::Display for Needle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // e.g. common:"example"(58.1%) + write!( + f, + "{}:\"{}\"({:.1}%)", + self.bucket, + self.bytes.escape_ascii(), + self.selectivity * 100.0, + ) + } +} + +/// Buckets as (label, target selectivity, inclusive range). +const BUCKETS: &[(&str, f64, f64, f64)] = &[ + ("rare", 0.002, 0.0003, 0.02), + ("medium", 0.10, 0.03, 0.25), + ("common", 0.55, 0.40, 1.0), +]; + +const CAND_LENS: &[usize] = &[3, 5, 8, 12]; + +/// Count rows in `rows` matching `needle` under `mode`. Brute force. +fn brute_count(rows: &[Vec], needle: &[u8], mode: Mode) -> usize { + if needle.is_empty() { + return rows.len(); + } + match mode { + Mode::Prefix => rows.iter().filter(|r| r.starts_with(needle)).count(), + Mode::Contains => rows + .iter() + .filter(|r| r.len() >= needle.len() && r.windows(needle.len()).any(|w| w == needle)) + .count(), + } +} + +/// Pick one representative needle per (bucket, mode) by sampling candidate +/// substrings/prefixes and estimating their selectivity over a row sample. +fn select_needles() -> &'static [Needle] { + static NEEDLES: OnceLock> = OnceLock::new(); + NEEDLES.get_or_init(|| { + let rows = &corpus().rows; + + // Explicit override: `ONPAIR_NEEDLES="contains:google,prefix:http://"`. + // Each spec is `mode:text` (mode = contains|prefix); the bucket label is + // the literal text so the report and the C++ dump name it. Real + // selectivity is computed over the full corpus. Lets a specific query + // (e.g. the ClickBench `URL LIKE '%google%'`) be benchmarked directly. + if let Ok(spec) = env::var("ONPAIR_NEEDLES") { + let mut out = Vec::new(); + for item in spec.split(',').map(str::trim).filter(|s| !s.is_empty()) { + let (mode, text) = match item.split_once(':') { + Some(("prefix", t)) => (Mode::Prefix, t), + Some(("contains", t)) => (Mode::Contains, t), + _ => panic!( + "ONPAIR_NEEDLES item must be `contains:TEXT` or `prefix:TEXT`, got {item:?}" + ), + }; + let bytes = text.as_bytes().to_vec(); + let sel = brute_count(rows, &bytes, mode) as f64 / rows.len() as f64; + out.push(Needle { + bucket: Box::leak(text.to_string().into_boxed_str()), + mode, + bytes, + selectivity: sel, + }); + } + return out; + } + + // Deterministic sampler shared across phases. + let mut x = 0xD1B54A32D192ED03u64; + let mut next = |bound: usize| -> usize { + x = x + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + ((x >> 33) as usize) % bound.max(1) + }; + + // Row sample used for cheap selectivity estimation. + let est_rows: Vec> = { + let take = rows.len().min(8000); + (0..take).map(|_| rows[next(rows.len())].clone()).collect() + }; + let est_n = est_rows.len() as f64; + + let mut out: Vec = Vec::new(); + for &mode in &[Mode::Contains, Mode::Prefix] { + // Generate candidates from random rows × candidate lengths, dedup. + let mut seen: std::collections::HashSet> = std::collections::HashSet::new(); + let mut cands: Vec> = Vec::new(); + let target = 700usize; + let mut tries = 0usize; + while cands.len() < target && tries < target * 20 { + tries += 1; + let row = &rows[next(rows.len())]; + if row.is_empty() { + continue; + } + let len = CAND_LENS[next(CAND_LENS.len())]; + if row.len() < len { + continue; + } + let start = match mode { + Mode::Prefix => 0, + Mode::Contains => next(row.len() - len + 1), + }; + let cand = row[start..start + len].to_vec(); + if seen.insert(cand.clone()) { + cands.push(cand); + } + } + + // Estimate selectivity for every candidate, then for each bucket + // keep the candidate whose selectivity is closest to the target. + let mut best: Vec)>> = vec![None; BUCKETS.len()]; + for cand in &cands { + let sel = brute_count(&est_rows, cand, mode) as f64 / est_n; + for (bi, &(_, tgt, lo, hi)) in BUCKETS.iter().enumerate() { + if sel < lo || sel > hi { + continue; + } + let dist = (sel - tgt).abs(); + let better = best[bi].as_ref().is_none_or(|(bdist, _)| dist < *bdist); + if better { + best[bi] = Some((dist, cand.clone())); + } + } + } + + for (bi, &(label, ..)) in BUCKETS.iter().enumerate() { + if let Some((_, bytes)) = best[bi].take() { + // Exact selectivity over the full corpus for the report. + let sel = brute_count(rows, &bytes, mode) as f64 / rows.len() as f64; + out.push(Needle { + bucket: label, + mode, + bytes, + selectivity: sel, + }); + } + } + } + out + }) +} + +fn contains_needles() -> Vec<&'static Needle> { + select_needles() + .iter() + .filter(|n| n.mode == Mode::Contains) + .collect() +} + +fn prefix_needles() -> Vec<&'static Needle> { + select_needles() + .iter() + .filter(|n| n.mode == Mode::Prefix) + .collect() +} + +// ───────────────────────────────────────────────────────────────────────────── +// Benches. +// ───────────────────────────────────────────────────────────────────────────── + +fn bench_search(bencher: Bencher, needle: &Needle) { + let parts = column().as_search_parts(); + let c = corpus(); + // Throughput is reported over the bytes the scan must stream and the rows + // it covers. `Contains` walks the whole code stream; `Prefix` (with the + // index) streams only the first-code table (2 B/row) in pass 1. + let bytes_scanned = match needle.mode { + Mode::Contains => parts.codes.len() * 2, + Mode::Prefix => c.rows.len() * 2, + }; + bencher + .counter(BytesCount::new(bytes_scanned)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| { + let pattern = match needle.mode { + Mode::Contains => Pattern::Contains(&needle.bytes), + Mode::Prefix => Pattern::Prefix(&needle.bytes), + }; + // Count via the callback primitive so the timing reflects the scan, + // not the result-mask allocation. + let mut matches = 0usize; + parts.search_callback(pattern, |_| matches += 1); + divan::black_box(matches) + }); +} + +#[divan::bench(args = contains_needles())] +fn contains(bencher: Bencher, needle: &Needle) { + bench_search(bencher, needle); +} + +#[divan::bench(args = prefix_needles())] +fn prefix(bencher: Bencher, needle: &Needle) { + bench_search(bencher, needle); +} + +/// Prefix via `search()` → `RowMask` (the bitmap-merge fast path: pass-1 accept +/// bits are written straight into the mask, no per-row callback). Contrast with +/// `prefix`, which exercises the per-row `search_callback` path. +#[divan::bench(args = prefix_needles())] +fn prefix_mask(bencher: Bencher, needle: &Needle) { + let parts = column().as_search_parts(); + let c = corpus(); + bencher + .counter(BytesCount::new(c.rows.len() * 2)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| { + let mask = parts.search(Pattern::Prefix(&needle.bytes)); + divan::black_box(popcount(mask.as_words())) + }); +} + +/// Count set bits across packed mask words. +fn popcount(words: &[u64]) -> usize { + words.iter().map(|w| w.count_ones() as usize).sum() +} + +/// A/B baseline: identical prefix search but with the first-token index +/// suppressed (`first_codes = None`), forcing the generic per-row scan. The +/// gap to `prefix` is the search index's runtime payoff. +#[divan::bench(args = prefix_needles())] +fn prefix_no_index(bencher: Bencher, needle: &Needle) { + let mut parts = column().as_search_parts(); + parts.first_codes = None; + let c = corpus(); + // Same denominator as `prefix` (rows × 2) so the two are directly + // comparable, though the no-index path actually streams the code stream. + bencher + .counter(BytesCount::new(c.rows.len() * 2)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| { + let mut matches = 0usize; + parts.search_callback(Pattern::Prefix(&needle.bytes), |_| matches += 1); + divan::black_box(matches) + }); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Arrow-like baselines: evaluate the predicate over decompressed bytes the way +// an Arrow `StringArray` LIKE kernel does — a (values, offsets) buffer pair with +// `starts_with` (prefix) or `memchr::memmem` (contains, the finder Arrow's +// `contains`/`like` kernels use) — and pack the per-row verdict into a +// `BooleanBuffer` via `collect_bool`, the same 64-bits-per-word packer arrow-rs +// uses to build the `BooleanArray` result. This makes the baseline produce a +// packed bitmask comparable to onpair's `RowMask`, not just a counter. +// `*_arrow` assumes the data is already decompressed in memory; the +// `*_decompress_arrow` pair pays the onpair decompress first, so it is the true +// "decode then scan" alternative to compressed-domain search. +// ───────────────────────────────────────────────────────────────────────────── + +/// Evaluate `needle` over a decompressed `(bytes, offsets)` buffer +/// Arrow-`StringArray`-style, packing the verdicts into a `BooleanBuffer` with +/// `collect_bool` (the arrow-rs bitmask builder), and return its set-bit count. +fn arrow_mask(bytes: &[u8], offsets: &[u64], needle: &Needle) -> usize { + let n = offsets.len() - 1; + let mask = match needle.mode { + Mode::Prefix => arrow_buffer::BooleanBuffer::collect_bool(n, |r| { + bytes[offsets[r] as usize..offsets[r + 1] as usize].starts_with(&needle.bytes) + }), + Mode::Contains => { + let finder = memchr::memmem::Finder::new(&needle.bytes); + arrow_buffer::BooleanBuffer::collect_bool(n, |r| { + finder + .find(&bytes[offsets[r] as usize..offsets[r + 1] as usize]) + .is_some() + }) + } + }; + mask.count_set_bits() +} + +#[divan::bench(args = prefix_needles())] +fn prefix_arrow(bencher: Bencher, needle: &Needle) { + let c = corpus(); + bencher + .counter(BytesCount::new(c.total_bytes)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| divan::black_box(arrow_mask(&c.bytes, &c.offsets, needle))); +} + +#[divan::bench(args = contains_needles())] +fn contains_arrow(bencher: Bencher, needle: &Needle) { + let c = corpus(); + bencher + .counter(BytesCount::new(c.total_bytes)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| divan::black_box(arrow_mask(&c.bytes, &c.offsets, needle))); +} + +#[divan::bench(args = prefix_needles())] +fn prefix_decompress_arrow(bencher: Bencher, needle: &Needle) { + let col = column(); + let c = corpus(); + bencher + .counter(BytesCount::new(c.total_bytes)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| { + let bytes = decompress(col.as_parts()); + divan::black_box(arrow_mask(&bytes, &c.offsets, needle)) + }); +} + +#[divan::bench(args = contains_needles())] +fn contains_decompress_arrow(bencher: Bencher, needle: &Needle) { + let col = column(); + let c = corpus(); + bencher + .counter(BytesCount::new(c.total_bytes)) + .counter(ItemsCount::new(c.rows.len())) + .bench_local(|| { + let bytes = decompress(col.as_parts()); + divan::black_box(arrow_mask(&bytes, &c.offsets, needle)) + }); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Roofline baselines. +// ───────────────────────────────────────────────────────────────────────────── +// All report throughput against the same logical corpus bytes as the search +// benches, so the GB/s column is directly comparable. +// +// copy_all_codes — read + write the whole codes stream (a memcpy of the +// compressed payload). The "decode would at least cost +// this" reference, and what prefix must beat to win. +// scan_all_codes — read every code once (no early exit). The hard floor +// for `contains`: it must look at every token of a +// non-matching row. +// first_code_per_row — read code_offsets + the first code of each row. The +// floor for `prefix`, which dies after ~one token. + +#[divan::bench] +fn copy_all_codes(bencher: Bencher) { + let codes = &column().codes; + let mut dst = vec![0u16; codes.len()]; + bencher + .counter(BytesCount::new(corpus().total_bytes)) + .bench_local(|| { + dst.copy_from_slice(codes); + divan::black_box(&dst); + }); +} + +#[divan::bench] +fn scan_all_codes(bencher: Bencher) { + let codes = &column().codes; + bencher + .counter(BytesCount::new(corpus().total_bytes)) + .bench_local(|| { + let mut acc = 0u64; + for &c in codes { + acc = acc.wrapping_add(c as u64); + } + divan::black_box(acc) + }); +} + +#[divan::bench] +fn first_code_per_row(bencher: Bencher) { + let col = column(); + bencher + .counter(BytesCount::new(corpus().total_bytes)) + .counter(ItemsCount::new(corpus().rows.len())) + .bench_local(|| { + let mut acc = 0u64; + for w in col.code_offsets.windows(2) { + if w[1] > w[0] { + acc ^= col.codes[w[0] as usize] as u64; + } + } + divan::black_box(acc) + }); +} + +fn main() { + // Touch corpus, column, and needles so the report prints before divan runs, + // and cross-check the compressed-domain count against brute force. + let _ = column(); + let rows = &corpus().rows; + eprintln!("[onpair search] selected needles (compressed-domain vs brute-force):"); + for n in select_needles() { + let mode = match n.mode { + Mode::Contains => "contains", + Mode::Prefix => "prefix", + }; + let mask = column().as_search_parts().search(match n.mode { + Mode::Contains => Pattern::Contains(&n.bytes), + Mode::Prefix => Pattern::Prefix(&n.bytes), + }); + let cd = popcount(mask.as_words()); + let bf = brute_count(rows, &n.bytes, n.mode); + let ok = if cd == bf { "ok" } else { "MISMATCH" }; + eprintln!( + " [{ok}] {mode:>8} {:>6} \"{}\" sel={:.3}% cd={cd} bf={bf}", + n.bucket, + n.bytes.escape_ascii(), + n.selectivity * 100.0, + ); + assert_eq!( + cd, bf, + "compressed-domain search disagrees with brute force" + ); + } + divan::main(); +} diff --git a/src/column.rs b/src/column.rs index 72a7274..8a34776 100644 --- a/src/column.rs +++ b/src/column.rs @@ -31,6 +31,16 @@ pub struct Column { /// emits these because a token may span a row boundary, so the row /// structure cannot be recovered from the codes alone. pub code_offsets: Vec, + /// Per-row first-token side-table (`R` entries when present): + /// `first_codes[r]` is the first code of row `r`, or [`u16::MAX`] for an + /// empty row. A contiguous child array that lets prefix search prefilter + /// rows with a single linear scan instead of a scattered + /// `codes[code_offsets[r]]` gather per row — see + /// [`crate::SearchParts::search`]. [`Parser::parse`](crate::Parser::parse) + /// always populates it (it costs 2 bytes per row); the [`Option`] is for + /// columns rehydrated from storage that did not persist it, in which case + /// prefix search falls back to the generic per-row scan. + pub first_codes: Option>, } /// Borrowed view of the data the decoder needs, consumed by diff --git a/src/lib.rs b/src/lib.rs index 75a600c..2c35618 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,6 +46,7 @@ mod hash; mod lpm; mod offset; mod parser; +mod search; mod trainer; mod types; @@ -67,6 +68,9 @@ pub use decompress::decompressed_len; pub use dict::Dictionary; pub use offset::Offset; pub use parser::Parser; +pub use search::Pattern; +pub use search::RowMask; +pub use search::SearchParts; pub use types::MAX_TOKEN_SIZE; /// Compress `bytes` / `offsets` end-to-end. Equivalent to diff --git a/src/offset.rs b/src/offset.rs index ae6a3f8..585e1cc 100644 --- a/src/offset.rs +++ b/src/offset.rs @@ -14,6 +14,13 @@ pub trait Offset: sealed::Sealed + Copy + Clone + Default + std::fmt::Debug + 's fn to_usize(self) -> Option; /// Construct from a `usize`, truncating if it does not fit. fn from_usize(n: usize) -> Self; + /// Convert to `usize` by truncation — the exact inverse of + /// [`from_usize`](Self::from_usize). For offsets that were validated at + /// construction (fit in `usize`, ≤ buffer length) this is lossless, and + /// unlike [`to_usize`](Self::to_usize) it is branchless: no fallible check, + /// no panic path. Use it on hot per-row paths over already-validated + /// offsets. + fn as_usize(self) -> usize; } impl Offset for u32 { @@ -25,6 +32,10 @@ impl Offset for u32 { fn from_usize(n: usize) -> Self { n as u32 } + #[inline] + fn as_usize(self) -> usize { + self as usize + } } impl Offset for u64 { @@ -36,4 +47,8 @@ impl Offset for u64 { fn from_usize(n: usize) -> Self { n as u64 } + #[inline] + fn as_usize(self) -> usize { + self as usize + } } diff --git a/src/parser.rs b/src/parser.rs index ddb40d8..7df828e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -69,16 +69,32 @@ impl Parser { // 1-byte final token needs MAX_TOKEN_SIZE - 1 trailing bytes). See // `Parts::validate_dictionary`. dict_bytes.resize(dict_bytes.len() + (MAX_TOKEN_SIZE - 1), 0); + let first_codes = Some(first_codes(&codes, &code_offsets)); Column { dict_bytes, dict_offsets: self.dict.offsets.clone(), bits: self.dict.bits, codes, code_offsets, + first_codes, } } } +/// Build the per-row first-token side-table: `first_codes[r]` is the first +/// code of row `r`, or `u16::MAX` for an empty row (a sentinel that never +/// equals a real token id when the dictionary is not fully saturated). +pub(crate) fn first_codes(codes: &[u16], code_offsets: &[O]) -> Vec { + let n = code_offsets.len() - 1; + let mut out = Vec::with_capacity(n); + for r in 0..n { + let s = code_offsets[r].as_usize(); + let e = code_offsets[r + 1].as_usize(); + out.push(if s < e { codes[s] } else { u16::MAX }); + } + out +} + /// Encode every string into a flat `Vec` of codes plus per-row /// `code_offsets`. Offset `[i]..[i + 1]` indexes the codes for row `i`. The /// offsets are compressor metadata — a token may span a row boundary, so the diff --git a/src/search/kmp.rs b/src/search/kmp.rs new file mode 100644 index 0000000..05b81f3 --- /dev/null +++ b/src/search/kmp.rs @@ -0,0 +1,508 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +// +// Port of `include/onpair/search/automata/kmp_automaton.h`. + +use super::{DictView, RowMatcher, TokenRange}; +use crate::types::Token; + +/// KMP state. A byte-level KMP over a pattern of length `m` has states +/// `0..=m`; `m` is the absorbing match state. Mirrors the C++ `uint8_t` so the +/// per-token `base` table stays one byte wide (it dominates cache footprint at +/// up to 64K tokens). Patterns are therefore capped at 255 bytes. +type State = u8; + +/// Tokens in `[range.begin, range.last]` transition the KMP from a given entry +/// state to `target` (overriding the entry-state-0 base transition). +#[derive(Copy, Clone)] +struct SparseTransition { + range: TokenRange, + target: State, +} + +/// Token-level KMP automaton for substring search (`col LIKE '%pattern%'`). +/// +/// Each token id transitions the KMP as if its bytes were fed one by one. The +/// transition table is stored in two tiers: +/// * `base[t]` — the exit state when entering token `t` from state 0 (the +/// common case once the automaton has not yet partially matched); +/// * `sparse` — for each non-zero entry state, the few token ranges whose +/// exit state differs from `base[t]`, grouped by entry state via `offsets`. +/// +/// [`matches`](RowMatcher::matches) splits the scan into a fast path (KMP +/// state 0, the common case) whose `base[]` loads carry no state between +/// iterations and so pipeline, and a slow partial-match path that consults the +/// sparse table. It stops the row the moment the match state is reached. +pub(crate) struct KmpAutomaton { + match_state: State, + /// `base[token]` = KMP exit state after consuming the token from state 0. + base: Vec, + /// Flattened sparse transitions grouped by entry state: the transitions for + /// entry state `s` live at `sparse[offsets[s]..offsets[s + 1]]`. + sparse: Vec, + offsets: Vec, +} + +/// Consume `data` from KMP state `s`, absorbing once the match state `m` is +/// reached. Direct port of the C++ `step_bytes` lambda. +#[inline] +fn step_bytes(p: &[u8], fail: &[State], m: usize, mut s: State, data: &[u8]) -> State { + for &b in data { + if s as usize == m { + return m as State; + } + while s > 0 && p[s as usize] != b { + s = fail[s as usize - 1]; + } + if p[s as usize] == b { + s += 1; + } + } + s +} + +impl KmpAutomaton { + pub(crate) fn new(pattern: &[u8], dict: DictView<'_>) -> Self { + let m = pattern.len(); + assert!( + m <= State::MAX as usize, + "onpair: contains needle exceeds 255 bytes" + ); + let num_tokens = dict.num_tokens(); + let match_state = m as State; + + if m == 0 { + return Self { + match_state: 0, + base: vec![0; num_tokens], + sparse: Vec::new(), + offsets: vec![0, 0], + }; + } + + let p = pattern; + + // ── 1. KMP failure table ──────────────────────────────────────────── + let mut fail = vec![0 as State; m]; + { + let mut i = 1usize; + let mut len = 0 as State; + while i < m { + if p[i] == p[len as usize] { + len += 1; + fail[i] = len; + i += 1; + } else if len > 0 { + len = fail[len as usize - 1]; + } else { + fail[i] = 0; + i += 1; + } + } + } + + // ── 2. Base pass ──────────────────────────────────────────────────── + let mut base = vec![0 as State; num_tokens]; + let p0 = p[0]; + for t in 0..num_tokens { + let tok = dict.data(t as Token); + base[t] = if tok.contains(&p0) { + step_bytes(p, &fail, m, 0, tok) + } else { + 0 + }; + } + + // ── 3. Sparse pass — dual-KMP trie traversal ──────────────────────── + let mut offsets = vec![0u32; m + 1]; + let mut pass = SparsePass { + dict, + p, + fail: &fail, + base: &base, + m, + sparse: Vec::new(), + range_start: 0, + }; + + let mut relevant: Vec = Vec::with_capacity(m); + for j in 1..m { + pass.range_start = pass.sparse.len(); + offsets[j] = pass.range_start as u32; + + // Only the bytes p[s] along the failure chain j → fail[j-1] → … → 0 + // can make state j diverge from state 0; gather and dedup them. + relevant.clear(); + let mut s = j as State; + while s > 0 { + relevant.push(p[s as usize]); + s = fail[s as usize - 1]; + } + relevant.sort_unstable(); + relevant.dedup(); + + for &byte in &relevant { + let range = dict.prefix_range(&[byte]); + if range.empty() { + continue; + } + let kmp_j = step_bytes(p, &fail, m, j as State, &[byte]); + let kmp_0 = step_bytes(p, &fail, m, 0, &[byte]); + pass.traverse(range, 1, kmp_j, kmp_0); + } + } + offsets[m] = pass.sparse.len() as u32; + // Move the sparse table out, ending the `&base` borrow held by `pass` + // so `base` itself can be moved into the returned automaton. + let sparse = pass.sparse; + + Self { + match_state, + base, + sparse, + offsets, + } + } + + /// Full KMP transition from `state` (in `1..match_state`) on token `t`: + /// consult the sparse exceptions for `state`, falling back to `base[t]`. + #[inline] + fn next_state(&self, state: State, t: Token) -> State { + let lo = self.offsets[state as usize] as usize; + let hi = self.offsets[state as usize + 1] as usize; + for tr in &self.sparse[lo..hi] { + if t < tr.range.begin { + break; + } + if t <= tr.range.last { + return tr.target; + } + } + self.base[t as usize] + } + + /// Per-token class table for the prefilter, one entry per token id: + /// Per-token table for the Teddy-style 2-code chain prefilter. Each entry is + /// an OR of three independent bit flags: + /// * [`CHAIN_DEFINITE`] — the token contains the whole needle + /// (`base == match_state`); a row holding it matches outright. + /// * [`CHAIN_OPEN`] — the token opens a partial match from state 0 + /// (`base != 0`); it can be the first token of a boundary-spanning match. + /// * [`CHAIN_CONT`] — feeding the token from *some* positive entry state + /// can leave the KMP positive (not dead); it can be the second token of a + /// boundary-spanning pair. Computed as a sound superset: `base != 0`, or + /// any sparse transition for any entry state has target `!= 0` and covers + /// this token id. + /// + /// The row scan then accepts a row iff it has a DEFINITE token, or a + /// consecutive pair `(open token, continue token)`. Soundness: in any + /// matching row with no DEFINITE token, walk the KMP state sequence back from + /// the match to the opener `j` (`s_{j-1}=0 < s_j`); token `j` is OPEN and the + /// next token `j+1` (which exists, since no token does `0 -> match` alone) + /// has a positive entry state staying positive, hence CONT. So every match + /// exhibits a DEFINITE token or an OPEN→CONT pair — no false negatives. + pub(crate) fn chain_table(&self) -> Vec { + let m = self.match_state; + let mut t: Vec = self + .base + .iter() + .map(|&b| { + if b == m { + CHAIN_DEFINITE | CHAIN_OPEN | CHAIN_CONT + } else if b != 0 { + CHAIN_OPEN | CHAIN_CONT + } else { + 0 + } + }) + .collect(); + // Any sparse transition with a non-dead target means the covered tokens + // can continue a partial match from that entry state: mark CONT. + for tr in &self.sparse { + if tr.target != 0 { + let lo = tr.range.begin as usize; + let hi = tr.range.last as usize; + for cell in &mut t[lo..=hi] { + *cell |= CHAIN_CONT; + } + } + } + t + } + + /// Contiguous token-id ranges of INNER tokens: those that complete or + /// continue a partial match — `base[t] == match_state` (DEFINITE) or covered + /// by a sparse transition with a non-dead target. Merged and sorted. + /// + /// INNER is a *sound necessary* contains filter: the token that completes any + /// match enters from state 0 (then `base == match_state`, DEFINITE) or from a + /// positive state via a sparse transition (INNER), so every matching row + /// holds an INNER token. Unlike the scattered open-set, these tokens cluster + /// into few contiguous ranges (the dictionary sorts by leading byte and a + /// continuation needs a specific next byte), so the filter is SIMD + /// range-testable. Returns `None` if there are more ranges than `max` (not + /// worth a per-code multi-range test). + pub(crate) fn inner_ranges(&self, max: usize) -> Option> { + let m = self.match_state; + let mut raw: Vec<(Token, Token)> = Vec::new(); + // DEFINITE runs in base. + let mut i = 0u32; + let n = self.base.len() as u32; + while i < n { + if self.base[i as usize] == m { + let mut j = i + 1; + while j < n && self.base[j as usize] == m { + j += 1; + } + raw.push((i as Token, (j - 1) as Token)); + i = j; + } else { + i += 1; + } + } + // Completing sparse transitions only (target == match state), and only + // from a boundary-reachable entry state. Two sound tightenings: + // 1. A row matches iff some boundary reaches `m`; the token completing + // that step enters from state 0 (DEFINITE, above) or via a sparse + // transition with target `m`. Partial→partial transitions can never + // be the completing token, so dropping them adds no false negative. + // 2. A completing transition from entry state `s` can only fire if a + // boundary ever lands on `s`. `reachable_states` over-approximates + // the reachable boundary states from the dictionary alone, so + // skipping transitions from unreachable `s` drops no true match. + // Both only ever remove false positives — KMP still confirms survivors. + let reach = self.reachable_states(); + for s in 1..m as usize { + if !reach[s] { + continue; + } + let lo = self.offsets[s] as usize; + let hi = self.offsets[s + 1] as usize; + for tr in &self.sparse[lo..hi] { + if tr.target == m { + raw.push((tr.range.begin, tr.range.last)); + } + } + } + if raw.is_empty() { + return Some(Vec::new()); + } + // Merge overlapping/adjacent ranges. + raw.sort_unstable(); + let mut merged: Vec<(Token, Token)> = Vec::with_capacity(raw.len()); + for (lo, hi) in raw { + if let Some(last) = merged.last_mut() + && lo <= last.1.saturating_add(1) + { + last.1 = last.1.max(hi); + continue; + } + merged.push((lo, hi)); + } + if merged.len() > max { + None + } else { + Some(merged) + } + } + + /// The set of DFA states that can occur at a token boundary, as a sound + /// over-approximation computed from the dictionary alone (no row data). + /// + /// Fixpoint over the real per-token transition function: state 0 (row start + /// / KMP death) is always reachable; state `s'` is reachable if some + /// boundary-reachable state `s` has a token `t` with `step(s, t) == s'`. + /// + /// Soundness: any boundary an actual row reaches is the image of a + /// (previous boundary state, token) pair, so it is included. This does NOT + /// model LPM (greedy tokenisation never taking a shorter token when a longer + /// one fits), so it can mark a state reachable that LPM forbids in practice + /// — that only adds false positives to a prefilter built from it, never a + /// false negative. + fn reachable_states(&self) -> Vec { + let m = self.match_state as usize; + let mut reach = vec![false; m + 1]; + reach[0] = true; + let nt = self.base.len(); + let mut changed = true; + while changed { + changed = false; + for s in 0..m { + if !reach[s] { + continue; + } + for t in 0..nt { + let ns = if s == 0 { + self.base[t] as usize + } else { + self.next_state(s as State, t as Token) as usize + }; + if !reach[ns] { + reach[ns] = true; + changed = true; + } + } + } + } + reach + } + + /// Whether the needle is empty (matches every row); the prefilter is skipped + /// for it. + #[inline] + pub(crate) fn is_empty_needle(&self) -> bool { + self.match_state == 0 + } +} + +/// [`chain_table`](KmpAutomaton::chain_table) flags. A token containing the +/// whole needle (definite match for any row holding it). +pub(crate) const CHAIN_DEFINITE: u8 = 4; +/// The token opens a partial match from state 0 (can start a spanning match). +pub(crate) const CHAIN_OPEN: u8 = 1; +/// The token can continue a partial match (can be the second of a spanning pair). +pub(crate) const CHAIN_CONT: u8 = 2; + +/// Scratch state for the sparse-transition trie traversal. Kept in a struct so +/// the recursion (bounded by `MAX_TOKEN_SIZE` depth) can be a method. +struct SparsePass<'a> { + dict: DictView<'a>, + p: &'a [u8], + fail: &'a [State], + base: &'a [State], + m: usize, + sparse: Vec, + range_start: usize, +} + +impl SparsePass<'_> { + /// Extend the last transition of the current group or push a new one. + /// Tokens are visited in ascending order, so adjacent same-target ranges + /// merge on the fly. + fn emit(&mut self, range: TokenRange, target: State) { + if self.sparse.len() > self.range_start { + let last = self.sparse.last_mut().expect("len checked above"); + if last.target == target && last.range.last as u32 + 1 == range.begin as u32 { + last.range.last = range.last; + return; + } + } + self.sparse.push(SparseTransition { range, target }); + } + + /// Traverse the implicit trie of the sorted dictionary over `tr`, tracking + /// the KMP state evolved from entry state `kmp_j` and from state 0 + /// (`kmp_0`) in parallel. Where they agree the subtree yields nothing and + /// is pruned. Direct port of the recursive C++ `traverse` lambda. + fn traverse(&mut self, tr: TokenRange, depth: usize, kmp_j: State, kmp_0: State) { + if kmp_j == kmp_0 || tr.empty() { + return; + } + + // Full match: override tokens whose base exit differs from m. + if kmp_j as usize == self.m { + let exit = self.m as State; + let last = tr.last as usize; + let mut i = tr.begin as usize; + while i <= last { + if self.base[i] != exit { + let start = i; + while i <= last && self.base[i] != exit { + i += 1; + } + self.emit( + TokenRange { + begin: start as Token, + last: (i - 1) as Token, + }, + exit, + ); + } else { + i += 1; + } + } + return; + } + + // Leaf tokens (length == depth) are fully consumed and share exit kmp_j. + let last = tr.last as usize; + let mut cur = tr.begin as usize; + while cur <= last && self.dict.token_size(cur as Token) == depth { + cur += 1; + } + if cur > tr.begin as usize { + self.emit( + TokenRange { + begin: tr.begin, + last: (cur - 1) as Token, + }, + kmp_j, + ); + } + if cur > last { + return; + } + + // Recurse into subtrees partitioned by the byte at `depth`. + while cur <= last { + let c = self.dict.data(cur as Token)[depth]; + let mut sub_hi = cur; + while sub_hi < last && self.dict.data((sub_hi + 1) as Token)[depth] == c { + sub_hi += 1; + } + let nj = step_bytes(self.p, self.fail, self.m, kmp_j, &[c]); + let n0 = step_bytes(self.p, self.fail, self.m, kmp_0, &[c]); + self.traverse( + TokenRange { + begin: cur as Token, + last: sub_hi as Token, + }, + depth + 1, + nj, + n0, + ); + cur = sub_hi + 1; + } + } +} + +impl RowMatcher for KmpAutomaton { + #[inline] + fn matches(&self, codes: &[Token]) -> bool { + // Empty needle matches every row. + if self.match_state == 0 { + return true; + } + let base = self.base.as_slice(); + let match_state = self.match_state; + let n = codes.len(); + let mut i = 0usize; + while i < n { + // Fast path: KMP state 0. `base[code]` loads are independent across + // iterations (no state carried between them), so the CPU pipelines + // them — no `is_dead`/`state > 0` branch, no sparse lookup. + let s = base[codes[i] as usize]; + i += 1; + if s != 0 { + if s == match_state { + return true; + } + // Slow path: a partial match is open. Step carefully (sparse + // exceptions + base) until it completes, dies back to state 0, + // or the row ends. + let mut state = s; + while i < n { + state = self.next_state(state, codes[i]); + i += 1; + if state == match_state { + return true; + } + if state == 0 { + break; + } + } + } + } + false + } +} diff --git a/src/search/mod.rs b/src/search/mod.rs new file mode 100644 index 0000000..8611a6d --- /dev/null +++ b/src/search/mod.rs @@ -0,0 +1,1190 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compressed-domain prefix / substring search. +//! +//! Rust port of the token-level search automata in the reference C++ +//! implementation (`include/onpair/search/automata/*`). The central idea: a +//! column's bytes are encoded as a stream of dictionary token ids, so instead +//! of decompressing each row and running a byte matcher, we run a small +//! deterministic automaton **directly over the token ids**. Every input byte +//! becomes part of one token, so a `T`-token row costs `T` automaton steps +//! regardless of how many bytes it decodes to — and matches early-exit. +//! +//! Two predicates are supported, expressed as [`Pattern`]: +//! * [`Pattern::Prefix`] — `col LIKE 'needle%'`, via [`prefix::PrefixAutomaton`]. +//! * [`Pattern::Contains`] — `col LIKE '%needle%'`, via [`kmp::KmpAutomaton`]. +//! +//! Both automata are built once per query against the (sorted) dictionary and +//! then driven over every row. Construction relies on two dictionary +//! properties guaranteed by [`crate::Parser::train`]: the token ids are in +//! lexicographic order, and the 256 single-byte tokens are always present. + +mod kmp; +mod prefix; +mod tokenize; + +use crate::column::Column; +use crate::offset::Offset; +use crate::types::{MAX_TOKEN_SIZE, Token}; + +use kmp::{CHAIN_CONT, CHAIN_DEFINITE, CHAIN_OPEN, KmpAutomaton}; +use prefix::PrefixAutomaton; + +/// A search predicate evaluated against every row of a compressed column, +/// without decompressing it. Borrows the needle bytes for the duration of the +/// search. +#[derive(Copy, Clone, Debug)] +pub enum Pattern<'a> { + /// Matches rows whose decoded bytes begin with the needle + /// (SQL `col LIKE 'needle%'`). + Prefix(&'a [u8]), + /// Matches rows whose decoded bytes contain the needle anywhere + /// (SQL `col LIKE '%needle%'`). + Contains(&'a [u8]), +} + +// ───────────────────────────────────────────────────────────────────────────── +// TokenRange — closed range of token ids [begin, last]; begin > last is empty. +// ───────────────────────────────────────────────────────────────────────────── + +/// Closed range of token ids `[begin, last]`. The default-constructed +/// `{ begin: 1, last: 0 }` is the canonical empty range. +#[derive(Copy, Clone, Debug)] +pub(crate) struct TokenRange { + pub(crate) begin: Token, + pub(crate) last: Token, +} + +impl TokenRange { + /// Canonical empty range (`begin > last`). + pub(crate) const EMPTY: Self = Self { begin: 1, last: 0 }; + + #[inline] + pub(crate) fn empty(self) -> bool { + self.begin > self.last + } + + #[inline] + pub(crate) fn contains(self, t: Token) -> bool { + t >= self.begin && t <= self.last + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// DictView — borrowed, read-only view over a column's sorted dictionary. +// ───────────────────────────────────────────────────────────────────────────── + +/// Borrowed view over the `(bytes, offsets)` of a sorted dictionary. Mirrors +/// the C++ `DictionaryView`: O(1) token access plus O(log n) prefix-range +/// lookups via binary search over the sorted token ids. +#[derive(Copy, Clone)] +pub(crate) struct DictView<'a> { + bytes: &'a [u8], + offsets: &'a [u32], +} + +impl<'a> DictView<'a> { + #[inline] + fn num_tokens(self) -> usize { + self.offsets.len() - 1 + } + + #[inline] + fn token_size(self, id: Token) -> usize { + (self.offsets[id as usize + 1] - self.offsets[id as usize]) as usize + } + + #[inline] + fn data(self, id: Token) -> &'a [u8] { + let s = self.offsets[id as usize] as usize; + let e = self.offsets[id as usize + 1] as usize; + &self.bytes[s..e] + } + + /// First token id in `[start, num_tokens)` whose bytes are `>= target` + /// under the dictionary's sort order (shorter token sorts before a longer + /// one sharing its prefix). Direct port of the C++ `lower_bound` lambda. + fn lower_bound(self, target: &[u8], start: u32) -> u32 { + let n = self.num_tokens() as u32; + let (mut lo, mut hi) = (start, n); + while lo < hi { + let mid = lo + ((hi - lo) >> 1); + let tok = self.data(mid as Token); + let mlen = tok.len(); + let clen = mlen.min(target.len()); + let cmp = tok[..clen].cmp(&target[..clen]); + // token[mid] < target iff cmp < 0, or equal-prefix and token shorter. + if cmp.is_lt() || (cmp.is_eq() && mlen < target.len()) { + lo = mid + 1; + } else { + hi = mid; + } + } + lo + } + + /// `[lo, hi]` token-id range whose byte sequences share `prefix`, or the + /// empty range if none do. Port of `DictionaryView::prefix_range`. + fn prefix_range(self, prefix: &[u8]) -> TokenRange { + // A prefix longer than any token can never match. + if prefix.len() > MAX_TOKEN_SIZE { + return TokenRange::EMPTY; + } + let n = self.num_tokens() as u32; + + let lo = self.lower_bound(prefix, 0); + + // Next lexicographic prefix: increment the last non-0xFF byte after + // trimming trailing 0xFF bytes. If all bytes are 0xFF the prefix has no + // successor, so the range runs to the end of the dictionary. + let mut buf = [0u8; MAX_TOKEN_SIZE]; + let mut ulen = prefix.len(); + let mut overflow = true; + while ulen > 0 { + if prefix[ulen - 1] < 0xFF { + buf[..ulen].copy_from_slice(&prefix[..ulen]); + buf[ulen - 1] += 1; + overflow = false; + break; + } + ulen -= 1; + } + + // hi >= lo always, so the second search starts from lo, not 0. + let hi = if overflow { + n + } else { + self.lower_bound(&buf[..ulen], lo) + }; + + if lo < hi { + TokenRange { + begin: lo as Token, + last: (hi - 1) as Token, + } + } else { + TokenRange::EMPTY + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Row matcher. +// ───────────────────────────────────────────────────────────────────────────── + +/// A compiled query that decides whether one row's token sequence matches. +/// +/// Stateless across rows: all per-row state lives in [`matches`](Self::matches) +/// locals, so one matcher is built per query and reused for every row (no +/// reset between rows, and it can be shared by reference). +pub(crate) trait RowMatcher { + /// Whether the row whose codes are `codes` matches. + fn matches(&self, codes: &[Token]) -> bool; +} + +/// Run `matcher` over every row delimited by `code_offsets`, invoking +/// `on_match` with the index of each matching row. +#[inline] +fn scan( + matcher: &impl RowMatcher, + codes: &[Token], + code_offsets: &[O], + mut on_match: impl FnMut(usize), +) { + for r in 0..code_offsets.len() - 1 { + let s = code_offsets[r].as_usize(); + let e = code_offsets[r + 1].as_usize(); + if matcher.matches(&codes[s..e]) { + on_match(r); + } + } +} + +/// Whether the AVX2 pass-1 kernels should be used: the CPU supports AVX2 and +/// the `ONPAIR_NO_SIMD` benchmarking escape hatch is unset. Resolved once. +#[cfg(target_arch = "x86_64")] +fn avx2_enabled() -> bool { + use std::sync::atomic::{AtomicU8, Ordering}; + static STATE: AtomicU8 = AtomicU8::new(u8::MAX); // MAX = not yet resolved + let cached = STATE.load(Ordering::Relaxed); + if cached != u8::MAX { + return cached == 1; + } + let on = std::is_x86_feature_detected!("avx2") && std::env::var_os("ONPAIR_NO_SIMD").is_none(); + STATE.store(on as u8, Ordering::Relaxed); + on +} + +/// Whether the AVX-512BW prefix kernel should be used: the CPU supports +/// AVX-512BW and SIMD is not disabled. Measured ~1.2× faster than the AVX2 +/// prefix kernel (32 `u16`/vector + mask-register output, no pack/movemask). +/// `ONPAIR_NO_SIMD` disables it; `ONPAIR_NO_AVX512` forces the AVX2 path for A/B. +/// Resolved once. +#[cfg(target_arch = "x86_64")] +fn avx512_enabled() -> bool { + use std::sync::atomic::{AtomicU8, Ordering}; + static STATE: AtomicU8 = AtomicU8::new(u8::MAX); + let cached = STATE.load(Ordering::Relaxed); + if cached != u8::MAX { + return cached == 1; + } + let on = std::is_x86_feature_detected!("avx512bw") + && std::env::var_os("ONPAIR_NO_SIMD").is_none() + && std::env::var_os("ONPAIR_NO_AVX512").is_none(); + STATE.store(on as u8, Ordering::Relaxed); + on +} + +/// Verdict of the Teddy-style 2-code chain row filter; see [`row_chain`]. +enum RowChain { + /// A token contains the whole needle — the row matches outright. + Definite, + /// A consecutive open→continue token pair exists — confirm with exact KMP. + Candidate, + /// Neither — the row cannot contain the needle. + Reject, +} + +/// Teddy-style 2-code chain filter over one row's codes, using the per-token +/// [`chain_table`](kmp::KmpAutomaton::chain_table). Carries the previous token's +/// `CHAIN_OPEN` bit and looks for a consecutive pair `(open, continue)` — far +/// more selective than "any opener present", since a boundary-spanning match +/// needs an opener token *immediately followed* by a continuation token. +/// +/// Returns [`RowChain::Definite`] on the first DEFINITE token (the row matches +/// with no KMP needed), [`RowChain::Candidate`] if an open→continue pair occurs +/// (a spanning match is possible; the exact KMP confirms), else +/// [`RowChain::Reject`]. The early `return` on a definite token keeps LLVM from +/// the slower auto-vectorized gather (verified in the asm); the loop body is a +/// single scattered `chain[code]` load plus a register-carried `prev_open`. +#[inline] +fn row_chain(chain: &[u8], codes: &[Token]) -> RowChain { + let mut prev_open = false; + let mut candidate = false; + for &c in codes { + let f = chain[c as usize]; + if f & CHAIN_DEFINITE != 0 { + return RowChain::Definite; + } + // Open token immediately followed by a continue token = possible + // boundary-spanning match. + candidate |= prev_open && (f & CHAIN_CONT != 0); + prev_open = f & CHAIN_OPEN != 0; + } + if candidate { + RowChain::Candidate + } else { + RowChain::Reject + } +} + +/// Pass-1 accept filter: set bit `r` of `acc` iff `first_codes[r]` lies in the +/// inclusive accept range `[alo, alo + awidth]` (unsigned). Branchless; +/// dispatches to AVX2 when available. Precondition for the SIMD path: the range +/// is non-empty (`alo <= u16::MAX`), which holds for every single-token query. +#[inline] +fn prefilter_accept(first_codes: &[u16], alo: u32, awidth: u32, acc: &mut [u64]) { + #[cfg(target_arch = "x86_64")] + if alo <= u16::MAX as u32 { + if avx512_enabled() { + // SAFETY: avx512bw just confirmed present. + unsafe { prefilter_accept_avx512(first_codes, alo as u16, awidth as u16, acc) }; + return; + } + if avx2_enabled() { + // SAFETY: avx2 just confirmed present. + unsafe { prefilter_accept_avx2(first_codes, alo as u16, awidth as u16, acc) }; + return; + } + } + prefilter_accept_scalar(first_codes, alo, awidth, acc); +} + +/// Scalar fully-branchless accept filter: `(fc - alo) <= awidth` lowers to a +/// `sub` + unsigned compare with no branch, accumulated into one bitset word +/// per 64 rows. +#[inline] +fn prefilter_accept_scalar(first_codes: &[u16], alo: u32, awidth: u32, acc: &mut [u64]) { + for (word, chunk) in acc.iter_mut().zip(first_codes.chunks(64)) { + let mut w = 0u64; + for (i, &fc) in chunk.iter().enumerate() { + w |= u64::from((fc as u32).wrapping_sub(alo) <= awidth) << i; + } + *word = w; + } +} + +/// Pass-1 accept + verify filter: as [`prefilter_accept`], but also sets bit +/// `r` of `ver` iff `first_codes[r] == vpoint`. The two predicates are disjoint +/// (`vpoint < alo`), so no row lands in both. Branchless; dispatches to AVX2. +#[inline] +fn prefilter_accept_verify( + first_codes: &[u16], + alo: u32, + awidth: u32, + vpoint: u32, + acc: &mut [u64], + ver: &mut [u64], +) { + #[cfg(target_arch = "x86_64")] + if avx2_enabled() { + // An empty accept range (alo > u16::MAX) is encoded by disabling the + // accept compare; vpoint is always a real `u16` here (multi-token q0). + let (alo16, awidth16, aenable) = if alo <= u16::MAX as u32 { + (alo as u16, awidth as u16, 0xFFFFu16) + } else { + (0, 0, 0) + }; + // SAFETY: avx2 just confirmed present. + unsafe { + prefilter_accept_verify_avx2( + first_codes, + alo16, + awidth16, + aenable, + vpoint as u16, + acc, + ver, + ) + }; + return; + } + prefilter_accept_verify_scalar(first_codes, alo, awidth, vpoint, acc, ver); +} + +/// Scalar fully-branchless accept + verify filter. +#[inline] +fn prefilter_accept_verify_scalar( + first_codes: &[u16], + alo: u32, + awidth: u32, + vpoint: u32, + acc: &mut [u64], + ver: &mut [u64], +) { + for ((accw, verw), chunk) in acc + .iter_mut() + .zip(ver.iter_mut()) + .zip(first_codes.chunks(64)) + { + let mut a = 0u64; + let mut v = 0u64; + for (i, &fc) in chunk.iter().enumerate() { + let fc = fc as u32; + a |= u64::from(fc.wrapping_sub(alo) <= awidth) << i; + v |= u64::from(fc == vpoint) << i; + } + *accw = a; + *verw = v; + } +} + +/// Maximum INNER range count for which the SIMD multi-range contains pass-1 is +/// attempted; above this the per-code range chain outweighs the scalar gather. +const INNER_RANGE_BUDGET: usize = 16; + +/// Set bit `i` of `bits` iff `codes[i]` lies in any of the (sorted, merged) +/// INNER `ranges`. Dispatches to AVX2 when available. +fn classify_inner(codes: &[Token], ranges: &[(Token, Token)], bits: &mut [u64]) { + #[cfg(target_arch = "x86_64")] + if avx2_enabled() { + // SAFETY: avx2 confirmed present. + unsafe { classify_inner_avx2(codes, ranges, bits) }; + return; + } + classify_inner_scalar(codes, ranges, bits); +} + +/// Scalar reference for [`classify_inner`]. +fn classify_inner_scalar(codes: &[Token], ranges: &[(Token, Token)], bits: &mut [u64]) { + for (i, &c) in codes.iter().enumerate() { + if ranges.iter().any(|&(lo, hi)| c >= lo && c <= hi) { + bits[i >> 6] |= 1u64 << (i & 63); + } + } +} + +/// Whether any bit in `bits[lo..hi]` (bit indices) is set. +#[inline] +fn any_bit_in_range(bits: &[u64], lo: usize, hi: usize) -> bool { + if lo >= hi { + return false; + } + let (wlo, whi) = (lo >> 6, (hi - 1) >> 6); + if wlo == whi { + let mask = (!0u64 << (lo & 63)) & (!0u64 >> (63 - ((hi - 1) & 63))); + return bits[wlo] & mask != 0; + } + if bits[wlo] & (!0u64 << (lo & 63)) != 0 { + return true; + } + if bits[wlo + 1..whi].iter().any(|&w| w != 0) { + return true; + } + bits[whi] & (!0u64 >> (63 - ((hi - 1) & 63))) != 0 +} + +/// Invoke `f` with the index of every set bit in `words`, in ascending order. +#[inline] +fn for_each_set_bit(words: &[u64], mut f: impl FnMut(usize)) { + for (w, &word) in words.iter().enumerate() { + let mut bits = word; + let base = w * 64; + while bits != 0 { + f(base + bits.trailing_zeros() as usize); + bits &= bits - 1; + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// AVX2 pass-1 kernels. The range filter over the contiguous first-token table +// is a pure SIMD shape: one `sub` + unsigned compare per lane, 16 u16 rows per +// vector, packed straight into the candidate bitset words. +// ───────────────────────────────────────────────────────────────────────────── + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +/// Reduce 16 `u16` lanes that are each `0xFFFF` (true) or `0x0000` (false) to a +/// 16-bit mask, bit `i` from lane `i`. +#[cfg(target_arch = "x86_64")] +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn movemask_epu16(v: __m256i) -> u32 { + // Saturating pack i16->i8 maps 0xFFFF (-1) -> 0xFF and 0 -> 0, preserving + // lane order across the two 128-bit halves, then one byte movemask. + let lo = _mm256_castsi256_si128(v); + let hi = _mm256_extracti128_si256::<1>(v); + _mm_movemask_epi8(_mm_packs_epi16(lo, hi)) as u32 +} + +/// Lanewise `(fc - alo) <= awidth`, unsigned, as a `0xFFFF`/`0` mask vector. +#[cfg(target_arch = "x86_64")] +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn in_range_epu16(v: __m256i, valo: __m256i, vawidth: __m256i) -> __m256i { + let sub = _mm256_sub_epi16(v, valo); + // Unsigned `sub <= awidth` == `min_epu16(sub, awidth) == sub`. + _mm256_cmpeq_epi16(_mm256_min_epu16(sub, vawidth), sub) +} + +/// AVX2 accept filter; see [`prefilter_accept`]. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn prefilter_accept_avx2(first_codes: &[u16], alo: u16, awidth: u16, acc: &mut [u64]) { + let valo = _mm256_set1_epi16(alo as i16); + let vawidth = _mm256_set1_epi16(awidth as i16); + let n = first_codes.len(); + let ptr = first_codes.as_ptr(); + let mut r = 0usize; + let mut wi = 0usize; + while r + 64 <= n { + let mut word = 0u64; + for k in 0..4 { + // SAFETY: r + k*16 + 16 <= r + 64 <= n, in bounds; both helpers are + // avx2, confirmed present. + let v = unsafe { _mm256_loadu_si256(ptr.add(r + k * 16) as *const __m256i) }; + let m = unsafe { movemask_epu16(in_range_epu16(v, valo, vawidth)) }; + word |= (m as u64) << (k * 16); + } + acc[wi] = word; + wi += 1; + r += 64; + } + if r < n { + prefilter_accept_scalar(&first_codes[r..], alo as u32, awidth as u32, &mut acc[wi..]); + } +} + +/// AVX-512BW accept filter (experiment #3): 32 `u16` codes per vector, one +/// `vpsubw` + `vpcmpuw` (`cmple_epu16`) yielding a `__mmask32` directly — no +/// pack/movemask reduction. Two masks compose one 64-bit bitset word. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512bw,avx512f")] +unsafe fn prefilter_accept_avx512(first_codes: &[u16], alo: u16, awidth: u16, acc: &mut [u64]) { + let valo = _mm512_set1_epi16(alo as i16); + let vawidth = _mm512_set1_epi16(awidth as i16); + let n = first_codes.len(); + let ptr = first_codes.as_ptr(); + let mut r = 0usize; + let mut wi = 0usize; + while r + 64 <= n { + // SAFETY: r + 32 and r + 64 are <= n. + let v0 = unsafe { _mm512_loadu_si512(ptr.add(r) as *const __m512i) }; + let v1 = unsafe { _mm512_loadu_si512(ptr.add(r + 32) as *const __m512i) }; + // Unsigned (fc - alo) <= awidth, directly to a 32-bit mask register. + let m0 = _mm512_cmple_epu16_mask(_mm512_sub_epi16(v0, valo), vawidth); + let m1 = _mm512_cmple_epu16_mask(_mm512_sub_epi16(v1, valo), vawidth); + acc[wi] = (m0 as u64) | ((m1 as u64) << 32); + wi += 1; + r += 64; + } + if r < n { + prefilter_accept_scalar(&first_codes[r..], alo as u32, awidth as u32, &mut acc[wi..]); + } +} + +/// AVX2 multi-range INNER classifier; see [`classify_inner`]. For each 16-code +/// vector, OR together one `in_range_epu16` per INNER range, pack to a 16-bit +/// mask, and accumulate into the per-code bitset words (64 codes per word). +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn classify_inner_avx2(codes: &[Token], ranges: &[(Token, Token)], bits: &mut [u64]) { + // Preload (lo, width) vectors for each range. + let zero = _mm256_setzero_si256(); + let mut vlo = [zero; INNER_RANGE_BUDGET]; + let mut vw = [zero; INNER_RANGE_BUDGET]; + for (i, &(lo, hi)) in ranges.iter().enumerate() { + vlo[i] = _mm256_set1_epi16(lo as i16); + vw[i] = _mm256_set1_epi16((hi - lo) as i16); + } + let nr = ranges.len(); + let n = codes.len(); + let ptr = codes.as_ptr(); + let (mut r, mut wi) = (0usize, 0usize); + while r + 64 <= n { + let mut word = 0u64; + for k in 0..4 { + // SAFETY: r + k*16 + 16 <= r + 64 <= n. + let v = unsafe { _mm256_loadu_si256(ptr.add(r + k * 16) as *const __m256i) }; + let mut hit = zero; + for vrange in vlo.iter().zip(vw.iter()).take(nr) { + // SAFETY: both helpers are avx2, enabled for this fn. + hit = _mm256_or_si256(hit, unsafe { in_range_epu16(v, *vrange.0, *vrange.1) }); + } + // SAFETY: avx2 enabled for this fn. + word |= (unsafe { movemask_epu16(hit) } as u64) << (k * 16); + } + bits[wi] = word; + wi += 1; + r += 64; + } + if r < n { + classify_inner_scalar(&codes[r..], ranges, &mut bits[wi..]); + } +} + +/// AVX2 accept + verify filter; see [`prefilter_accept_verify`]. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn prefilter_accept_verify_avx2( + first_codes: &[u16], + alo: u16, + awidth: u16, + aenable: u16, + vpoint: u16, + acc: &mut [u64], + ver: &mut [u64], +) { + let valo = _mm256_set1_epi16(alo as i16); + let vawidth = _mm256_set1_epi16(awidth as i16); + let vaenable = _mm256_set1_epi16(aenable as i16); + let vvpoint = _mm256_set1_epi16(vpoint as i16); + let n = first_codes.len(); + let ptr = first_codes.as_ptr(); + let mut r = 0usize; + let mut wi = 0usize; + while r + 64 <= n { + let mut accword = 0u64; + let mut verword = 0u64; + for k in 0..4 { + // SAFETY: r + k*16 + 16 <= r + 64 <= n, in bounds; helpers are avx2. + let v = unsafe { _mm256_loadu_si256(ptr.add(r + k * 16) as *const __m256i) }; + // Accept, masked off when the range is empty (aenable == 0). + let accl = _mm256_and_si256(unsafe { in_range_epu16(v, valo, vawidth) }, vaenable); + let verl = _mm256_cmpeq_epi16(v, vvpoint); + accword |= (unsafe { movemask_epu16(accl) } as u64) << (k * 16); + verword |= (unsafe { movemask_epu16(verl) } as u64) << (k * 16); + } + acc[wi] = accword; + ver[wi] = verword; + wi += 1; + r += 64; + } + if r < n { + // Reproduce the empty-range encoding for the scalar tail: alo = u32::MAX + // makes `(fc - alo) <= 0` false for every real first code. + let (talo, tawidth) = if aenable != 0 { + (alo as u32, awidth as u32) + } else { + (u32::MAX, 0) + }; + prefilter_accept_verify_scalar( + &first_codes[r..], + talo, + tawidth, + vpoint as u32, + &mut acc[wi..], + &mut ver[wi..], + ); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// RowMask — packed result bitset. +// ───────────────────────────────────────────────────────────────────────────── + +/// Result of a [`search`](SearchParts::search): a packed bitmap over the +/// column's rows, one bit per row. Bit `i` is set iff row `i` matched. +/// +/// The packed `u64` representation composes directly with a query engine's +/// own selection vectors (AND/OR of masks is word-wise), and is compact even +/// when most rows match. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct RowMask { + words: Vec, + rows: usize, +} + +impl RowMask { + /// All-zero mask sized for `rows` rows. + fn zeros(rows: usize) -> Self { + Self { + words: vec![0; rows.div_ceil(64)], + rows, + } + } + + #[inline] + fn set(&mut self, i: usize) { + self.words[i >> 6] |= 1u64 << (i & 63); + } + + /// Number of rows the mask covers (set or not). The bitmap has + /// `len().div_ceil(64)` words; bits at indices `>= len()` in the final word + /// are zero. + #[inline] + pub fn len(&self) -> usize { + self.rows + } + + /// Whether the mask covers zero rows. + #[inline] + pub fn is_empty(&self) -> bool { + self.rows == 0 + } + + /// Borrow the packed bitmap words (bit `i` = row `i`, LSB-first within each + /// word). Compose directly with a query engine's own selection vectors via + /// word-wise AND/OR. Length is `len().div_ceil(64)`. + #[inline] + pub fn as_words(&self) -> &[u64] { + &self.words + } + + /// Consume the mask into its owned `(words, len)` parts: the packed bitmap + /// and the row count it covers. Inverse shape of the borrowed + /// [`as_words`](Self::as_words) + [`len`](Self::len). + #[inline] + pub fn into_parts(self) -> (Vec, usize) { + (self.words, self.rows) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// SearchParts — borrowed view of the data search needs. +// ───────────────────────────────────────────────────────────────────────────── + +/// Borrowed view of everything compressed-domain search needs: the sorted +/// dictionary plus the per-row code stream. Mirrors [`crate::Parts`] (the +/// decode view) but additionally carries `code_offsets`, the row delimiters a +/// row-wise scan requires. +/// +/// Build one cheaply from an owned column with +/// [`Column::as_search_parts`], or by struct literal from data deserialized out +/// of storage. Like [`crate::Parts`], the fields are public and unchecked: the +/// search methods index `codes` by `code_offsets` without revalidating, so a +/// hand-built view must keep `code_offsets` monotonic and in bounds (a view +/// from `as_search_parts` always is). +#[derive(Copy, Clone, Debug)] +pub struct SearchParts<'a, O: Offset> { + /// Dictionary bytes (sorted token order). Mirrors [`Column::dict_bytes`]. + pub dict_bytes: &'a [u8], + /// Token byte ranges into `dict_bytes`. Mirrors [`Column::dict_offsets`]. + pub dict_offsets: &'a [u32], + /// Encoded tokens, row-concatenated. Mirrors [`Column::codes`]. + pub codes: &'a [u16], + /// `R + 1` offsets into `codes` delimiting the `R` rows: row `r`'s codes + /// are `codes[code_offsets[r]..code_offsets[r + 1]]`. Mirrors + /// [`Column::code_offsets`]. + pub code_offsets: &'a [O], + /// Optional per-row first token id (`R` entries); mirrors + /// [`Column::first_codes`]. When present, it is used as a contiguous + /// prefilter for [`Pattern::Prefix`] searches; when `None`, prefix search + /// falls back to the generic per-row scan. + pub first_codes: Option<&'a [u16]>, +} + +impl SearchParts<'_, O> { + #[inline] + fn dict(&self) -> DictView<'_> { + DictView { + bytes: self.dict_bytes, + offsets: self.dict_offsets, + } + } + + /// Number of rows in the view. + #[inline] + fn num_rows(&self) -> usize { + self.code_offsets.len().saturating_sub(1) + } + + /// Codes of row `r`: `codes[code_offsets[r]..code_offsets[r + 1]]`. The + /// offsets are a caller-upheld invariant (monotonic, in bounds — see the + /// type docs), so the conversion is the branchless [`Offset::as_usize`]. + #[inline] + fn row_codes(&self, r: usize) -> &[Token] { + let s = self.code_offsets[r].as_usize(); + let e = self.code_offsets[r + 1].as_usize(); + &self.codes[s..e] + } + + /// Evaluate `pattern` against every row, invoking `on_match` with the + /// 0-based index of each matching row, in order. The low-level primitive + /// [`search`](Self::search) builds its [`RowMask`] on top of. + pub fn search_callback(&self, pattern: Pattern<'_>, on_match: impl FnMut(usize)) { + let dict = self.dict(); + match pattern { + Pattern::Contains(needle) => { + let aut = KmpAutomaton::new(needle, dict); + self.scan_contains(&aut, on_match); + } + Pattern::Prefix(needle) => { + let aut = PrefixAutomaton::new(needle, dict); + self.scan_prefix(&aut, dict.num_tokens(), on_match); + } + } + } + + /// Contains scan in two passes over the whole code stream. + /// + /// Unlike prefix (which need only inspect each row's first token), a + /// substring can begin at any token, so pass 1 must stream every code. Using + /// the KMP [`class_table`](KmpAutomaton::class_table), each row is reduced to + /// one of three verdicts by OR-ing its tokens' classes: + /// * a [`CLASS_DEFINITE`] token present → the row matches outright (a token + /// contains the whole needle); emit without a row check; + /// * else a [`CLASS_OPENER`] token present → the row is a candidate; the + /// exact KMP confirms it in pass 2; + /// * else (all classes zero) → reject, never touching the KMP. + /// + /// The dependent-load + branch chain of the KMP fast path is thus paid only + /// on candidate rows, not on the (dominant at low/medium selectivity) + /// reject majority. Falls back to the generic per-row scan only for the empty + /// needle (which matches every row). + fn scan_contains(&self, aut: &KmpAutomaton, mut on_match: impl FnMut(usize)) { + let n = self.code_offsets.len() - 1; + if aut.is_empty_needle() { + scan(aut, self.codes, self.code_offsets, on_match); + return; + } + + // Optional SIMD INNER pass-1: when the INNER token set collapses into a + // small number of contiguous id ranges, classify the whole code stream + // with AVX2 range tests (any-INNER per code), reduce to candidate rows, + // and confirm with the exact KMP. Sound (every match holds an INNER + // token). Gated behind a small range budget — above it the per-code + // range chain is longer than the scalar gather it replaces. + if std::env::var_os("ONPAIR_INNER_SIMD").is_some() + && let Some(ranges) = aut.inner_ranges(INNER_RANGE_BUDGET) + { + if ranges.is_empty() { + return; // no INNER token ⇒ no match possible. + } + self.scan_contains_inner(aut, &ranges, on_match); + return; + } + + // Optional 3-layer funnel: a cheap SIMD INNER reject (layer 1) over the + // whole stream, then the precise scalar adjacency chain (layer 2) only on + // layer-1 survivors, then exact KMP (layer 3) on chain candidates. Both + // INNER-presence and the open→cont chain are independently necessary for + // a match, so ANDing them drops no true match. The point: replace + // row_chain over ALL codes with classify_inner over all codes (SIMD) + + // row_chain over only the survivors (13–38%). + if std::env::var_os("ONPAIR_FUNNEL").is_some() + && let Some(ranges) = aut.inner_ranges(INNER_RANGE_BUDGET) + { + if ranges.is_empty() { + return; + } + self.scan_contains_funnel(aut, &ranges, on_match); + return; + } + + let chain = aut.chain_table(); + for r in 0..n { + let codes = self.row_codes(r); + match row_chain(&chain, codes) { + // A DEFINITE token: the row matches outright. + RowChain::Definite => on_match(r), + // An open→continue pair exists: a boundary-spanning match is + // possible; confirm with the exact KMP. + RowChain::Candidate => { + if aut.matches(codes) { + on_match(r); + } + } + // Neither: the row cannot contain the needle. + RowChain::Reject => {} + } + } + } + + /// SIMD INNER contains scan. Pass 1 marks each code that lies in any INNER + /// range (AVX2 multi-range test over the whole stream) into a per-code + /// bitset; pass 2 visits rows holding a marked code and confirms with the + /// exact KMP. INNER presence is a sound necessary condition for a match. + fn scan_contains_inner( + &self, + aut: &KmpAutomaton, + ranges: &[(Token, Token)], + mut on_match: impl FnMut(usize), + ) { + let m = self.codes.len(); + let words = m.div_ceil(64); + let mut inner_bits = vec![0u64; words]; + classify_inner(self.codes, ranges, &mut inner_bits); + for r in 0..self.code_offsets.len() - 1 { + let s = self.code_offsets[r].as_usize(); + let e = self.code_offsets[r + 1].as_usize(); + if any_bit_in_range(&inner_bits, s, e) && aut.matches(&self.codes[s..e]) { + on_match(r); + } + } + } + + /// 3-layer funnel contains scan. Layer 1: SIMD INNER classify over the whole + /// code stream into a per-code bitset (cheap, but only ~13–38% selective). + /// Layer 2: for rows surviving layer 1, the precise scalar adjacency chain + /// (`row_chain`) — run only on survivors, not all rows. Layer 3: exact KMP on + /// chain candidates. Both INNER-presence and the open→cont chain are + /// necessary for a match, so ANDing the layers drops no true match. + fn scan_contains_funnel( + &self, + aut: &KmpAutomaton, + ranges: &[(Token, Token)], + mut on_match: impl FnMut(usize), + ) { + let words = self.codes.len().div_ceil(64); + let mut inner_bits = vec![0u64; words]; + classify_inner(self.codes, ranges, &mut inner_bits); + let chain = aut.chain_table(); + for r in 0..self.code_offsets.len() - 1 { + let s = self.code_offsets[r].as_usize(); + let e = self.code_offsets[r + 1].as_usize(); + // Layer 1: SIMD INNER reject — skip the scalar chain entirely if no + // INNER code is present. + if !any_bit_in_range(&inner_bits, s, e) { + continue; + } + let codes = &self.codes[s..e]; + // Layer 2+3: precise chain, then exact KMP. + match row_chain(&chain, codes) { + RowChain::Definite => on_match(r), + RowChain::Candidate => { + if aut.matches(codes) { + on_match(r); + } + } + RowChain::Reject => {} + } + } + } + + /// Prefix scan in two passes over the contiguous first-token table. + /// + /// Pass 1 is a fully branchless range filter: a row is a candidate iff its + /// first token lies in the sound superset range `[lo, hi]` returned by + /// [`PrefixAutomaton::prefilter_range`]. It touches one code per row (the + /// linear `first_codes`, never the scattered code stream), so it is cheap + /// even at low selectivity, and is the part that vectorises. + /// + /// Pass 2 only visits candidates. For a single-token query the range is + /// exact, so candidates are emitted directly; otherwise each is confirmed + /// with a full row check — the only place the scattered codes are read. + /// + /// Falls back to the generic per-row scan for the empty query, or when the + /// dictionary is fully saturated (`num_tokens == 65536`) and the empty-row + /// sentinel `u16::MAX` could collide with a real token id. + fn scan_prefix( + &self, + aut: &PrefixAutomaton, + num_tokens: usize, + mut on_match: impl FnMut(usize), + ) { + let n = self.code_offsets.len() - 1; + // Use the prefilter only with a same-length first-token table and an + // unsaturated dictionary (so the u16::MAX empty-row sentinel cannot + // collide with a real id); otherwise scan generically. + let first_codes = match self.first_codes { + Some(fc) if fc.len() == n && num_tokens <= u16::MAX as usize => fc, + _ => { + scan(aut, self.codes, self.code_offsets, on_match); + return; + } + }; + if aut.is_empty_query() { + scan(aut, self.codes, self.code_offsets, on_match); + return; + } + let pf = aut.prefilter(); + let words = n.div_ceil(64); + + if !pf.needs_verify() { + // Single-token query: the accept range is exact. One branchless + // pass, emit directly — no row ever touches the scattered codes. + let mut acc = vec![0u64; words]; + prefilter_accept(first_codes, pf.alo, pf.awidth, &mut acc); + for_each_set_bit(&acc, on_match); + return; + } + + // Multi-token query. Pass 1 splits rows into definite accepts (first + // token begins with the whole needle) and verify candidates (first + // token equals the query head). Both predicates are branchless. + let mut acc = vec![0u64; words]; + let mut ver = vec![0u64; words]; + prefilter_accept_verify( + first_codes, + pf.alo, + pf.awidth, + pf.vpoint, + &mut acc, + &mut ver, + ); + + // Definite accepts: emit directly. + for_each_set_bit(&acc, &mut on_match); + // Pass 2: confirm only the (usually few) verify candidates — the one + // place the scattered code stream is read. + for_each_set_bit(&ver, |r| { + if aut.matches(self.row_codes(r)) { + on_match(r); + } + }); + } + + /// Prefix scan that writes its result directly as a [`RowMask`] bitset, + /// skipping the per-row callback. Pass 1's accept predicate already produces + /// the matching-rows bitmap, so it is written straight into the mask words + /// (a contiguous SIMD store) instead of being walked bit-by-bit; only the + /// verify candidates are confirmed and OR'd in individually. This is the + /// fast path behind [`search`](Self::search) for prefix queries — at high + /// selectivity it avoids emitting hundreds of thousands of bits one call at + /// a time. + /// + /// Returns `None` when the first-token prefilter is not applicable (empty + /// query, missing/short index, or saturated dictionary), so the caller can + /// fall back to the generic callback scan. + fn prefix_mask(&self, aut: &PrefixAutomaton, num_tokens: usize) -> Option { + let n = self.code_offsets.len() - 1; + let first_codes = match self.first_codes { + Some(fc) if fc.len() == n && num_tokens <= u16::MAX as usize => fc, + _ => return None, + }; + if aut.is_empty_query() { + return None; + } + let pf = aut.prefilter(); + let words = n.div_ceil(64); + let mut acc = vec![0u64; words]; + + if pf.needs_verify() { + // Multi-token: accepts go straight into `acc`; verify candidates are + // confirmed and OR'd in (they are disjoint from the accept range). + let mut ver = vec![0u64; words]; + prefilter_accept_verify( + first_codes, + pf.alo, + pf.awidth, + pf.vpoint, + &mut acc, + &mut ver, + ); + for_each_set_bit(&ver, |r| { + if aut.matches(self.row_codes(r)) { + acc[r >> 6] |= 1u64 << (r & 63); + } + }); + } else { + // Single-token: the accept range is exact — pass 1 is the answer. + prefilter_accept(first_codes, pf.alo, pf.awidth, &mut acc); + } + Some(RowMask { + words: acc, + rows: n, + }) + } + + /// Evaluate `pattern` against every row, returning a [`RowMask`] whose set + /// bits are the matching row indices. The match is computed in the + /// compressed domain — rows are never decompressed. + pub fn search(&self, pattern: Pattern<'_>) -> RowMask { + // Prefix queries take the bitmap-merge fast path: the prefilter writes + // the result bits straight into the mask instead of via a per-row + // callback. Falls through to the generic callback build otherwise. + if let Pattern::Prefix(needle) = pattern { + let dict = self.dict(); + let aut = PrefixAutomaton::new(needle, dict); + if let Some(mask) = self.prefix_mask(&aut, dict.num_tokens()) { + return mask; + } + } + let mut mask = RowMask::zeros(self.num_rows()); + self.search_callback(pattern, |r| mask.set(r)); + mask + } +} + +impl Column { + /// Zero-copy [`SearchParts`] view over this column, for + /// [`SearchParts::search`]. Parallels [`as_parts`](Column::as_parts), but + /// includes `code_offsets` (the row delimiters search needs). + #[inline] + pub fn as_search_parts(&self) -> SearchParts<'_, O> { + SearchParts { + dict_bytes: &self.dict_bytes, + dict_offsets: &self.dict_offsets, + codes: &self.codes, + code_offsets: &self.code_offsets, + first_codes: self.first_codes.as_deref(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Bits, Config, Threshold, compress}; + + /// Pack rows into the Arrow `(bytes, offsets)` pair `compress` expects. + fn pack(rows: &[&[u8]]) -> (Vec, Vec) { + let mut bytes = Vec::new(); + let mut offsets = vec![0u32]; + for r in rows { + bytes.extend_from_slice(r); + offsets.push(bytes.len() as u32); + } + (bytes, offsets) + } + + fn cfg() -> Config { + Config { + bits: Bits::new(12).unwrap(), + threshold: Threshold::new(0.5).unwrap(), + seed: Some(42), + } + } + + fn naive_contains(row: &[u8], needle: &[u8]) -> bool { + needle.is_empty() || row.windows(needle.len()).any(|w| w == needle) + } + + /// Materialise the set-row indices of a mask from its packed words. + fn mask_ones(mask: &RowMask) -> Vec { + let mut out = Vec::new(); + for (w, &word) in mask.as_words().iter().enumerate() { + let mut bits = word; + while bits != 0 { + out.push(w * 64 + bits.trailing_zeros() as usize); + bits &= bits - 1; + } + } + out + } + + fn assert_matches(rows: &[&[u8]], pattern: Pattern<'_>, expect: impl Fn(&[u8]) -> bool) { + let (bytes, offsets) = pack(rows); + let col = compress(&bytes, &offsets, cfg()).unwrap(); + let mask = col.as_search_parts().search(pattern); + let got = mask_ones(&mask); + let want: Vec = rows + .iter() + .enumerate() + .filter_map(|(i, r)| expect(r).then_some(i)) + .collect(); + assert_eq!(got, want, "pattern {pattern:?}"); + assert_eq!(mask.len(), rows.len()); + } + + /// A corpus with heavy prefix sharing and repeated substrings so the + /// trainer emits multi-byte tokens (exercising the sparse KMP transitions + /// and prefix-divergence intervals rather than only single-byte tokens). + fn url_corpus() -> Vec> { + let hosts = [ + "https://www.example.com", + "https://api.example.org", + "ftp://x.example.net", + ]; + let paths = ["/index.html", "/search?q=onpair", "/a/b/c", "", "/login"]; + let mut out = Vec::new(); + let mut x = 0x1234_5678u64; + for _ in 0..2000 { + x = x.wrapping_mul(6364136223846793005).wrapping_add(1); + let h = hosts[(x >> 33) as usize % hosts.len()]; + let p = paths[(x >> 17) as usize % paths.len()]; + out.push(format!("{h}{p}{}", x % 100).into_bytes()); + } + out + } + + #[test] + fn contains_matches_naive_across_needles() { + let owned = url_corpus(); + let rows: Vec<&[u8]> = owned.iter().map(|v| v.as_slice()).collect(); + for needle in [ + b"example".as_slice(), + b"https://".as_slice(), + b"search?q=onpair".as_slice(), + b"/a/b/c".as_slice(), + b"zzz-not-present".as_slice(), + b"e".as_slice(), + b"".as_slice(), + ] { + assert_matches(&rows, Pattern::Contains(needle), |r| { + naive_contains(r, needle) + }); + } + } + + #[test] + fn prefix_matches_naive_across_needles() { + let owned = url_corpus(); + let rows: Vec<&[u8]> = owned.iter().map(|v| v.as_slice()).collect(); + for needle in [ + b"https://".as_slice(), + b"https://www.example.com".as_slice(), + b"ftp://".as_slice(), + b"https://api.example.org/login".as_slice(), + b"nope".as_slice(), + b"".as_slice(), + ] { + assert_matches(&rows, Pattern::Prefix(needle), |r| r.starts_with(needle)); + } + } + + #[test] + fn single_byte_needles() { + let rows: &[&[u8]] = &[b"abc", b"xyz", b"a", b"", b"cba"]; + for b in [b"a".as_slice(), b"z".as_slice(), b"q".as_slice()] { + assert_matches(rows, Pattern::Contains(b), |r| naive_contains(r, b)); + assert_matches(rows, Pattern::Prefix(b), |r| r.starts_with(b)); + } + } + + #[test] + fn needle_longer_than_any_token() { + // A 20-byte needle exceeds MAX_TOKEN_SIZE; prefix_range short-circuits. + let rows: &[&[u8]] = &[b"this is a fairly long row of text", b"short"]; + let needle = b"fairly long row of t"; // 20 bytes + assert_matches(rows, Pattern::Contains(needle), |r| { + naive_contains(r, needle) + }); + let pneedle = b"this is a fairly lon"; // 20 bytes + assert_matches(rows, Pattern::Prefix(pneedle), |r| r.starts_with(pneedle)); + } +} diff --git a/src/search/prefix.rs b/src/search/prefix.rs new file mode 100644 index 0000000..4136734 --- /dev/null +++ b/src/search/prefix.rs @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +// +// Port of `include/onpair/search/automata/prefix_automaton.h`. + +use super::tokenize::tokenize; +use super::{DictView, RowMatcher, TokenRange}; +use crate::types::Token; + +/// Token-level matcher for prefix search (`col LIKE 'prefix%'`). +/// +/// The needle is tokenised once. Walking a row's tokens against the query +/// sequence: +/// * exact match → advance; +/// * mismatch at position `i` → match iff the token falls inside the +/// precomputed valid-divergence interval for `i` (the row's token still +/// has the remaining needle bytes as a prefix), else no match; +/// * all query tokens consumed → match (the rest of the row is irrelevant). +/// +/// The decision is final at the first non-advancing token, so most rows are +/// settled in one step. +pub(crate) struct PrefixAutomaton { + query_tokens: Vec, + intervals: Vec, +} + +impl PrefixAutomaton { + pub(crate) fn new(prefix: &[u8], dv: DictView<'_>) -> Self { + let query_tokens = tokenize(prefix, dv); + let q_len = query_tokens.len(); + let mut intervals = vec![TokenRange::EMPTY; q_len]; + + // For each query position, the divergence interval is the set of tokens + // that begin with the not-yet-consumed needle suffix. + let mut current_pos = 0usize; + for i in 0..q_len { + intervals[i] = dv.prefix_range(&prefix[current_pos..]); + current_pos += dv.token_size(query_tokens[i]); + } + + Self { + query_tokens, + intervals, + } + } + + /// Whether the query tokenised to nothing (the empty prefix, which matches + /// every row). The prefilter path is skipped for it. + #[inline] + pub(crate) fn is_empty_query(&self) -> bool { + self.query_tokens.is_empty() + } + + /// First-token prefilter parameters. Precondition: the query is non-empty. + /// + /// A row matches only if its first token either begins with the whole + /// needle (`first_code ∈ intervals[0] = [begin, last]`) or equals the query + /// head `q0` (with the remaining query tokens still to be checked). Because + /// the dictionary is lexicographically sorted and `q0` is a prefix of the + /// needle, `q0 <= begin`, so the two id sets are disjoint for a multi-token + /// query — the [`Prefilter`] reports them separately: + /// + /// * the **accept** range `[begin, last]` — a single unsigned range check + /// `(fc - alo) <= awidth` — is a definite match (the first token alone + /// begins with the needle), so it needs no row check; + /// * the **verify** point `q0` flags the rare case where the needle is + /// split at `q0`, which a full row check then settles. + /// + /// A single-token query *is* the whole needle, so `q0 == begin` and the + /// accept range is necessary and sufficient; [`Prefilter::vpoint`] is then + /// disabled. The `u16::MAX` empty-row sentinel exceeds `last` (when the + /// dictionary is not saturated) and equals neither, so empties drop out. + #[inline] + pub(crate) fn prefilter(&self) -> Prefilter { + let q0 = self.query_tokens[0]; + let iv = self.intervals[0]; + // Empty accept range → match nothing: `alo` above any u16 makes + // `(fc - alo)` wrap past `awidth = 0` for every real first code. + let (alo, awidth) = if iv.empty() { + (u32::MAX, 0) + } else { + (iv.begin as u32, (iv.last - iv.begin) as u32) + }; + // Single-token query is exact; disable the verify point (no u16 first + // code can equal u32::MAX). + let vpoint = if self.query_tokens.len() == 1 { + u32::MAX + } else { + q0 as u32 + }; + Prefilter { + alo, + awidth, + vpoint, + } + } +} + +/// First-token prefilter parameters; see [`PrefixAutomaton::prefilter`]. +pub(crate) struct Prefilter { + /// Accept range lower bound. `u32::MAX` makes the range match nothing. + pub alo: u32, + /// Accept range width: `first_code` accepts iff `(first_code - alo) <= awidth`. + pub awidth: u32, + /// Verify point: `first_code == vpoint` needs a full row check. `u32::MAX` + /// (a value no `u16` first code can take) disables verification. + pub vpoint: u32, +} + +impl Prefilter { + /// Whether any first code can route to a full row check. + #[inline] + pub(crate) fn needs_verify(&self) -> bool { + self.vpoint != u32::MAX + } +} + +impl RowMatcher for PrefixAutomaton { + #[inline] + fn matches(&self, codes: &[Token]) -> bool { + // Empty prefix matches every row. + if self.query_tokens.is_empty() { + return true; + } + let mut pos = 0usize; + for &t in codes { + if t != self.query_tokens[pos] { + // First divergence: matches iff the token still carries the + // remaining needle bytes as a prefix. + return self.intervals[pos].contains(t); + } + pos += 1; + if pos == self.query_tokens.len() { + return true; + } + } + // Row ended with every token matched but the prefix not exhausted. + false + } +} diff --git a/src/search/tokenize.rs b/src/search/tokenize.rs new file mode 100644 index 0000000..48e279d --- /dev/null +++ b/src/search/tokenize.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +// +// Port of `include/onpair/search/detail/tokenize.h`. + +use super::DictView; +use crate::types::{MAX_TOKEN_SIZE, Token}; + +/// Greedy longest-match tokenisation of `text` against the sorted dictionary, +/// matching the encoder's own segmentation. Used to turn a query needle into +/// the token sequence the automata reason about. +/// +/// Precondition: the dictionary is sorted and contains the 256 single-byte +/// base tokens (guaranteed after [`crate::Parser::train`]). +pub(crate) fn tokenize(text: &[u8], dv: DictView<'_>) -> Vec { + let mut tokens = Vec::with_capacity(text.len()); + + let num_tokens = dv.num_tokens(); + let tlen = |t: Token| -> usize { dv.token_size(t) }; + let byte_at = |t: Token, k: usize| -> u8 { dv.data(t)[k] }; + + let mut pos = 0usize; + while pos < text.len() { + let remaining = text.len() - pos; + let max_len = remaining.min(MAX_TOKEN_SIZE); + + let mut best: Token = 0; + let mut range = (0u32, (num_tokens - 1) as u32); // [begin, last] + + for k in 0..max_len { + let target = text[pos + k]; + + // Lower bound: first token in range with byte[k] >= target. + // Tokens shorter than k+1 sort before any that has the byte. + let (mut lo, mut hi) = (range.0, range.1); + while lo < hi { + let mid = lo + ((hi - lo) >> 1); + if tlen(mid as Token) <= k || byte_at(mid as Token, k) < target { + lo = mid + 1; + } else { + hi = mid; + } + } + if tlen(lo as Token) <= k || byte_at(lo as Token, k) != target { + break; + } + + let first = lo; + + // Upper bound: first token with byte[k] > target, stepped back to + // the last with byte[k] == target. `lo` already holds `first`. + hi = range.1; + while lo < hi { + let mid = lo + ((hi - lo) >> 1); + if tlen(mid as Token) <= k || byte_at(mid as Token, k) <= target { + lo = mid + 1; + } else { + hi = mid; + } + } + let last = if tlen(lo as Token) > k && byte_at(lo as Token, k) > target { + lo - 1 + } else { + lo + }; + + // The shortest token in range sorts first; if its length is exactly + // k+1 it is an exact match of the consumed bytes. + if tlen(first as Token) == k + 1 { + best = first as Token; + } + + range = (first, last); + } + + tokens.push(best); + pos += tlen(best); + } + + tokens +}