diff --git a/Cargo.toml b/Cargo.toml
index ab88a77..b0d4e92 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,3 +24,9 @@ fancy-regex = "0.13.0"
 regex = "1.10.3"
 rustc-hash = "2"
 bstr = "1.5.0"
+unicode-properties = { version = "0.1.4", default-features = false, features = ["general-category"] }
+rand = { version = "0.8", features = ["std"] }
+base64 = "0.21"
+
+[dev-dependencies]
+proptest = "1.4"
diff --git a/src/bin/cl100k_fuzz.rs b/src/bin/cl100k_fuzz.rs
new file mode 100644
index 0000000..bccd0f5
--- /dev/null
+++ b/src/bin/cl100k_fuzz.rs
@@ -0,0 +1,537 @@
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+use std::env;
+use std::time::{Duration, Instant};
+
+use base64::engine::general_purpose::STANDARD as BASE64;
+use base64::Engine;
+
+use rand::distributions::{Distribution, Uniform};
+use rand::{rngs::StdRng, Rng, SeedableRng};
+use tiktoken::cl100k::{Cl100kMatchKind, Cl100kParser, CL100K_PATTERN};
+use tiktoken::{CoreBPE, PatternBackendChoice, Rank};
+
+const DEFAULT_STEPS: usize = 5_000;
+const DEFAULT_MAX_LEN: usize = 512;
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum Mode {
+    Split,
+    Bpe,
+    File(String),
+}
+
+const INTERESTING_CHARS: &[char] = &[
+    '\0', '\u{0001}', '\u{0002}', '\u{0003}', '\u{0004}', '\u{0005}', '\u{0006}', '\u{0007}',
+    '\u{0008}', '\t', '\n', '\u{000B}', '\u{000C}', '\r', '\u{000E}', '\u{000F}', '\u{0010}',
+    '\u{0011}', '\u{0012}', '\u{0013}', '\u{0014}', '\u{0015}', '\u{0016}', '\u{0017}', '\u{0018}',
+    '\u{0019}', '\u{001A}', '\u{001B}', '\u{001C}', '\u{001D}', '\u{001E}', '\u{001F}', ' ',
+    '\u{0085}', '\u{00A0}', '\u{1680}', '\u{180E}', '\u{2000}', '\u{2001}', '\u{2002}', '\u{2003}',
+    '\u{2004}', '\u{2005}', '\u{2006}', '\u{2007}', '\u{2008}', '\u{2009}', '\u{200A}', '\u{2028}',
+    '\u{2029}', '\u{202F}', '\u{205F}', '\u{3000}', '\u{FEFF}',
+];
+
+const ASCII_WHITESPACE: &[char] = &[' ', '\t', '\n', '\r', '\u{000B}', '\u{000C}'];
+const ASCII_ALNUM: &[char] = &[
+    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S',
+    'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
+    'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4',
+    '5', '6', '7', '8', '9',
+];
+const ASCII_PUNCT: &[char] = &[
+    '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ' ', ',', '-', '.', '/', ':', ';', '<',
+    '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',
+];
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct CustomSpan {
+    start: usize,
+    end: usize,
+    kind: Cl100kMatchKind,
+}
+
+fn main() {
+    let (mode, steps, max_len) = parse_args();
+    println!("Starting cl100k fuzz: mode={:?}, steps={steps}, max_len={max_len}", mode);
+
+    match mode {
+        Mode::Split => run_split_fuzz(steps, max_len),
+        Mode::Bpe => run_bpe_fuzz(steps, max_len),
+        Mode::File(path) => run_file_bench(&path, steps),
+    }
+}
+
+fn parse_args() -> (Mode, usize, usize) {
+    let mut args = env::args().skip(1);
+    let mode = match args.next().as_deref() {
+        Some("split") | None => Mode::Split,
+        Some("bpe") => Mode::Bpe,
+        Some("file") => {
+            let path = args.next().unwrap_or_else(|| {
+                eprintln!("Usage: cl100k_fuzz file <path>");
+                std::process::exit(2);
+            });
+            Mode::File(path)
+        }
+        Some(other) => {
+            eprintln!("Unknown mode '{other}'. Use 'split', 'bpe', or 'file'.");
+            std::process::exit(2);
+        }
+    };
+    let steps = args
+        .next()
+        .as_deref()
+        .map(parse_usize)
+        .unwrap_or(DEFAULT_STEPS);
+    let max_len = args
+        .next()
+        .as_deref()
+        .map(parse_usize)
+        .unwrap_or(DEFAULT_MAX_LEN);
+    if steps == 0 {
+        eprintln!("Step count must be greater than zero");
+        std::process::exit(2);
+    }
+    (mode, steps, max_len)
+}
+
+fn parse_usize(arg: &str) -> usize {
+    arg.parse().unwrap_or_else(|_| {
+        eprintln!("Invalid integer argument: {arg}");
+        std::process::exit(2);
+    })
+}
+
+fn generate_input(rng: &mut StdRng, max_len: usize) -> String {
+    let len_range = Uniform::new_inclusive(0, max_len);
+    let len = len_range.sample(rng);
+    let mut s = String::with_capacity(len);
+    for _ in 0..len {
+        s.push(sample_char(rng));
+    }
+    s
+}
+
+fn run_split_fuzz(steps: usize, max_len: usize) {
+    use std::io::Write;
+
+    let parser = Cl100kParser::new();
+    let mut rng = StdRng::from_entropy();
+
+    let mut fancy_total = Duration::ZERO;
+    let mut custom_total = Duration::ZERO;
+
+    for iter in 1..=steps {
+        let input = generate_input(&mut rng, max_len);
+
+        let fancy_start = Instant::now();
+        let fancy = collect_fancy_spans(&input);
+        fancy_total += fancy_start.elapsed();
+
+        let custom_start = Instant::now();
+        let custom = collect_custom_spans(&input, &parser);
+        custom_total += custom_start.elapsed();
+
+        if let Err(err) = compare_spans(&input, &fancy, &custom) {
+            eprintln!("\nMismatch detected on iteration {iter}");
+            eprintln!("Input ({} bytes):", input.len());
+            dump_bytes(&input);
+            eprintln!("{err}");
+            std::process::exit(1);
+        }
+
+        if iter % 100 == 0 {
+            print!("\rCompleted {iter}/{steps} cases");
+            let _ = std::io::stdout().flush();
+        }
+    }
+
+    println!("\nFinished {steps} cases with no mismatches.");
+    println!("Total fancy_regex split time: {:?}", fancy_total);
+    println!("Total custom parser split time: {:?}", custom_total);
+    let fancy_avg = fancy_total.as_secs_f64() / steps as f64;
+    let custom_avg = custom_total.as_secs_f64() / steps as f64;
+    println!(
+        "Average per case: fancy_regex={:.6}s, custom_parser={:.6}s",
+        fancy_avg, custom_avg
+    );
+    if fancy_total.is_zero() || custom_total.is_zero() {
+        println!("Custom speedup vs Fancy: n/a");
+    } else {
+        println!(
+            "Custom speedup vs Fancy: {:.3}x",
+            fancy_total.as_secs_f64() / custom_total.as_secs_f64()
+        );
+    }
+}
+
+fn run_bpe_fuzz(steps: usize, max_len: usize) {
+    use std::io::Write;
+
+    let (encoder, specials) = load_bpe_from_file("../cl100k_base.tiktoken")
+        .unwrap_or_else(|e| {
+            eprintln!("Failed to load cl100k_base.tiktoken: {e}");
+            std::process::exit(2);
+        });
+    let decoder = make_decoder_map(&encoder);
+    let bpe_custom = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
+        encoder.clone(),
+        specials.clone(),
+        CL100K_PATTERN,
+        PatternBackendChoice::Cl100kParser,
+    )
+    .expect("failed to construct custom CoreBPE");
+    let bpe_fancy = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
+        encoder,
+        specials,
+        CL100K_PATTERN,
+        PatternBackendChoice::FancyRegex,
+    )
+    .expect("failed to construct fancy CoreBPE");
+
+    let mut rng = StdRng::from_entropy();
+
+    let mut fancy_total = Duration::ZERO;
+    let mut custom_total = Duration::ZERO;
+
+    for iter in 1..=steps {
+        let input = generate_input(&mut rng, max_len);
+
+        let t0 = Instant::now();
+        let fancy_tokens = bpe_fancy.encode_ordinary(&input);
+        fancy_total += t0.elapsed();
+
+        let t1 = Instant::now();
+        let custom_tokens = bpe_custom.encode_ordinary(&input);
+        custom_total += t1.elapsed();
+
+        if fancy_tokens != custom_tokens {
+            eprintln!("\nToken mismatch on iteration {iter}");
+            eprintln!("Input ({} bytes):", input.len());
+            dump_bytes(&input);
+            eprintln!("fancy tokens (len={}):", fancy_tokens.len());
+            dump_tokens(&fancy_tokens);
+            eprintln!("custom tokens (len={}):", custom_tokens.len());
+            dump_tokens(&custom_tokens);
+            std::process::exit(1);
+        }
+        // if iter % 500 == 0 {
+        //     let fancy_dec = decode_tokens(&fancy_tokens, &decoder);
+        //     let custom_dec = decode_tokens(&custom_tokens, &decoder);
+        //     let fancy_text = String::from_utf8_lossy(&fancy_dec);
+        //     let custom_text = String::from_utf8_lossy(&custom_dec);
+        //     println!("\n[iter {iter}] sample input: {}", preview(&input, 160));
+        //     println!("[iter {iter}] fancy decode: {}", preview(&fancy_text, 160));
+        //     println!("[iter {iter}] custom decode: {}", preview(&custom_text, 160));
+        // }
+
+        if iter % 100 == 0 {
+            print!("\rCompleted {iter}/{steps} cases");
+            let _ = std::io::stdout().flush();
+        }
+    }
+
+    println!("\nFinished {steps} cases with no mismatches.");
+    println!("Total fancy_regex encode time: {:?}", fancy_total);
+    println!("Total custom parser encode time: {:?}", custom_total);
+    let fancy_avg = fancy_total.as_secs_f64() / steps as f64;
+    let custom_avg = custom_total.as_secs_f64() / steps as f64;
+    println!(
+        "Average per case: fancy_regex={:.6}s, custom_parser={:.6}s",
+        fancy_avg, custom_avg
+    );
+    if fancy_total.is_zero() || custom_total.is_zero() {
+        println!("Custom speedup vs Fancy: n/a");
+    } else {
+        println!(
+            "Custom speedup vs Fancy: {:.3}x",
+            fancy_total.as_secs_f64() / custom_total.as_secs_f64()
+        );
+    }
+}
+
+fn run_file_bench(path: &str, iterations: usize) {
+    let (encoder, specials) = load_bpe_from_file("../cl100k_base.tiktoken").unwrap_or_else(|e| {
+        eprintln!("Failed to load cl100k_base.tiktoken: {e}");
+        std::process::exit(2);
+    });
+    let decoder = make_decoder_map(&encoder);
+
+    let bpe_custom = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
+        encoder.clone(),
+        specials.clone(),
+        CL100K_PATTERN,
+        PatternBackendChoice::Cl100kParser,
+    )
+    .expect("failed to construct custom CoreBPE");
+    let bpe_fancy = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
+        encoder,
+        specials,
+        CL100K_PATTERN,
+        PatternBackendChoice::FancyRegex,
+    )
+    .expect("failed to construct fancy CoreBPE");
+
+    let input = std::fs::read_to_string(path).unwrap_or_else(|e| {
+        eprintln!("Failed to read input file '{path}': {e}");
+        std::process::exit(2);
+    });
+
+    // First pass: verify equality once and time single-run (warmup)
+    let t0 = Instant::now();
+    let fancy_tokens = bpe_fancy.encode_ordinary(&input);
+    let fancy_first = t0.elapsed();
+
+    let t1 = Instant::now();
+    let custom_tokens = bpe_custom.encode_ordinary(&input);
+    let custom_first = t1.elapsed();
+
+    if fancy_tokens != custom_tokens {
+        eprintln!("Token mismatch between fancy and custom on file '{}':", path);
+        eprintln!("fancy tokens (len={}):", fancy_tokens.len());
+        dump_tokens(&fancy_tokens);
+        eprintln!("custom tokens (len={}):", custom_tokens.len());
+        dump_tokens(&custom_tokens);
+        std::process::exit(1);
+    }
+
+    let fancy_dec = decode_tokens(&fancy_tokens, &decoder);
+    let custom_dec = decode_tokens(&custom_tokens, &decoder);
+    if fancy_dec != custom_dec {
+        eprintln!("Decoded byte mismatch between fancy and custom on file '{}'.", path);
+        std::process::exit(1);
+    }
+
+    // Optional: also check roundtrip
+    if fancy_dec != input.as_bytes() {
+        eprintln!("Warning: decoded bytes differ from input file bytes (possibly due to encoding/normalization). Continuing.");
+    }
+
+    // Iterated benchmark over the dataset
+    let mut fancy_total = fancy_first; // include first pass
+    let mut custom_total = custom_first; // include first pass
+    for _ in 1..iterations {
+        // already did one
+        let t0 = Instant::now();
+        let _f = bpe_fancy.encode_ordinary(&input);
+        fancy_total += t0.elapsed();
+
+        let t1 = Instant::now();
+        let _c = bpe_custom.encode_ordinary(&input);
+        custom_total += t1.elapsed();
+    }
+
+    println!("\nFile bench for '{path}' (iterations={}):", iterations);
+    println!("Fancy total: {:?}", fancy_total);
+    println!("Custom total: {:?}", custom_total);
+    let fancy_avg = fancy_total.as_secs_f64() / iterations as f64;
+    let custom_avg = custom_total.as_secs_f64() / iterations as f64;
+    println!("Average/iter: fancy={:.6}s custom={:.6}s", fancy_avg, custom_avg);
+    if fancy_total.is_zero() || custom_total.is_zero() {
+        println!("Custom speedup vs Fancy: n/a");
+    } else {
+        println!(
+            "Custom speedup vs Fancy: {:.3}x",
+            fancy_total.as_secs_f64() / custom_total.as_secs_f64()
+        );
+    }
+}
+
+fn load_bpe_from_file(
+    path: &str,
+) -> Result<(HashMap<Vec<u8>, Rank>, HashMap<String, Rank>), String> {
+    let file = File::open(path).map_err(|e| e.to_string())?;
+    let reader = BufReader::new(file);
+    let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::with_capacity(100_000);
+    for (lineno, line_res) in reader.lines().enumerate() {
+        let line = line_res.map_err(|e| e.to_string())?;
+        if line.is_empty() {
+            continue;
+        }
+        let mut parts = line.splitn(2, ' ');
+        let tok_b64 = parts.next().ok_or_else(|| format!("Malformed line {}", lineno + 1))?;
+        let rank_str = parts.next().ok_or_else(|| format!("Malformed line {}", lineno + 1))?;
+        let bytes = BASE64
+            .decode(tok_b64.as_bytes())
+            .map_err(|e| format!("base64 decode error at line {}: {}", lineno + 1, e))?;
+        let rank: Rank = rank_str
+            .parse::<Rank>()
+            .map_err(|e| format!("rank parse error at line {}: {}", lineno + 1, e))?;
+        encoder.insert(bytes, rank);
+    }
+    let specials: HashMap<String, Rank> = HashMap::new();
+    Ok((encoder, specials))
+}
+
+fn make_decoder_map(encoder: &HashMap<Vec<u8>, Rank>) -> HashMap<Rank, Vec<u8>> {
+    let mut decoder: HashMap<Rank, Vec<u8>> = HashMap::with_capacity(encoder.len());
+    for (bytes, &rank) in encoder.iter() {
+        decoder.insert(rank, bytes.clone());
+    }
+    decoder
+}
+
+fn decode_tokens(tokens: &[Rank], decoder: &HashMap<Rank, Vec<u8>>) -> Vec<u8> {
+    let mut out = Vec::with_capacity(tokens.len() * 2);
+    for &t in tokens {
+        if let Some(b) = decoder.get(&t) {
+            out.extend_from_slice(b);
+        }
+    }
+    out
+}
+
+fn preview<S: AsRef<str>>(s: S, max: usize) -> String {
+    let s = s.as_ref();
+    if s.chars().count() <= max {
+        s.to_string()
+    } else {
+        let mut out = String::new();
+        for (i, ch) in s.chars().enumerate() {
+            if i >= max {
+                break;
+            }
+            out.push(ch);
+        }
+        out.push('…');
+        out
+    }
+}
+
+fn sample_char(rng: &mut StdRng) -> char {
+    let bucket: u32 = rng.gen_range(0..14);
+    match bucket {
+        0..=5 => *INTERESTING_CHARS.choose(rng).unwrap(),
+        6..=8 => *ASCII_WHITESPACE.choose(rng).unwrap(),
+        9..=10 => *ASCII_ALNUM.choose(rng).unwrap(),
+        11..=12 => *ASCII_PUNCT.choose(rng).unwrap(),
+        _ => random_unicode(rng),
+    }
+}
+
+fn random_unicode(rng: &mut StdRng) -> char {
+    loop {
+        let value: u32 = rng.gen_range(0..=0x10FFFF);
+        if let Some(ch) = char::from_u32(value) {
+            return ch;
+        }
+    }
+}
+
+fn collect_fancy_spans(text: &str) -> Vec<(usize, usize)> {
+    use fancy_regex::Regex;
+    use std::sync::OnceLock;
+
+    static INSTANCE: OnceLock<Regex> = OnceLock::new();
+    let regex =
+        INSTANCE.get_or_init(|| Regex::new(tiktoken::cl100k::CL100K_PATTERN).unwrap());
+    regex
+        .find_iter(text)
+        .map(|res| {
+            let m = res.expect("fancy-regex error while tokenizing");
+            (m.start(), m.end())
+        })
+        .collect()
+}
+
+fn collect_custom_spans(text: &str, parser: &Cl100kParser) -> Vec<CustomSpan> {
+    parser
+        .find_iter(text)
+        .map(|m| CustomSpan {
+            start: m.start(),
+            end: m.end(),
+            kind: m.kind(),
+        })
+        .collect()
+}
+
+fn compare_spans(
+    text: &str,
+    fancy: &[(usize, usize)],
+    custom: &[CustomSpan],
+) -> Result<(), String> {
+    let custom_pairs: Vec<_> = custom.iter().map(|span| (span.start, span.end)).collect();
+    if custom_pairs == fancy {
+        return Ok(());
+    }
+
+    if fancy.len() != custom.len() {
+        let matching_prefix = fancy
+            .iter()
+            .zip(custom_pairs.iter())
+            .take_while(|(a, b)| a == b)
+            .count();
+        return Err(format!(
+            "Span count mismatch: fancy={} custom={} (matching prefix spans={})",
+            fancy.len(),
+            custom.len(),
+            matching_prefix
+        ));
+    }
+
+    for (idx, ((f_start, f_end), span)) in fancy.iter().zip(custom.iter()).enumerate() {
+        if *f_start != span.start || *f_end != span.end {
+            let fancy_slice = &text[*f_start..*f_end];
+            let custom_slice = &text[span.start..span.end];
+            let fancy_snippet = escape_snippet(fancy_slice, 32);
+            let custom_snippet = escape_snippet(custom_slice, 32);
+            return Err(format!(
+                "Mismatch at span {idx}: fancy [{f_start},{f_end}) \"{fancy_snippet}\" vs custom [{c_start},{c_end}) \"{custom_snippet}\" ({:?})",
+                span.kind,
+                c_start = span.start,
+                c_end = span.end
+            ));
+        }
+    }
+
+    Err("Spans differ but mismatch could not be isolated".to_string())
+}
+
+fn escape_snippet(slice: &str, limit: usize) -> String {
+    if slice.is_empty() {
+        return "<empty>".to_string();
+    }
+    let mut out = String::new();
+    let mut count = 0;
+    for ch in slice.chars() {
+        if count >= limit {
+            out.push('…');
+            break;
+        }
+        out.extend(ch.escape_default());
+        count += 1;
+    }
+    out
+}
+
+fn dump_bytes(text: &str) {
+    let bytes = text.as_bytes();
+    const PER_ROW: usize = 16;
+    for (row, chunk) in bytes.chunks(PER_ROW).enumerate() {
+        print!("{:04X}:", row * PER_ROW);
+        for byte in chunk {
+            print!(" {:02X}", byte);
+        }
+        println!();
+    }
+}
+
+fn dump_tokens(tokens: &[Rank]) {
+    const PER_ROW: usize = 32;
+    for (row, chunk) in tokens.chunks(PER_ROW).enumerate() {
+        print!("{:04X}:", row * PER_ROW);
+        for t in chunk {
+            print!(" {}", t);
+        }
+        println!();
+    }
+}
+
+trait ChooseExt<T> {
+    fn choose<'a>(&'a self, rng: &mut StdRng) -> Option<&'a T>;
+}
+
+impl<T> ChooseExt<T> for [T] {
+    fn choose<'a>(&'a self, rng: &mut StdRng) -> Option<&'a T> {
+        if self.is_empty() {
+            None
+        } else {
+            let idx = rng.gen_range(0..self.len());
+            Some(&self[idx])
+        }
+    }
+}
diff --git a/src/cl100k.rs b/src/cl100k.rs
new file mode 100644
index 0000000..59662c7
--- /dev/null
+++ b/src/cl100k.rs
@@ -0,0 +1,463 @@
+use std::cmp::min;
+
+use unicode_properties::{GeneralCategory, UnicodeGeneralCategory};
+
+pub const CL100K_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Cl100kMatchKind {
+    Contraction,
+    LetterWithPrefix,
+    Number,
+    Punctuation,
+    WhitespaceToEof,
+    WhitespaceThenLinebreak,
+    TrailingWhitespace,
+    SingleWhitespace,
+    Fallback,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub struct Cl100kParser;
+
+impl Cl100kParser {
+    pub fn new() -> Self {
+        Self
+    }
+
+    pub fn find_iter<'a>(&self, text: &'a str) -> Cl100kMatches<'a> {
+        Cl100kMatches { text, offset: 0 }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Cl100kMatch<'a> {
+    haystack: &'a str,
+    start: usize,
+    end: usize,
+    kind: Cl100kMatchKind,
+}
+
+impl<'a> Cl100kMatch<'a> {
+    pub fn as_str(&self) -> &'a str {
+        &self.haystack[self.start..self.end]
+    }
+
+    pub fn start(&self) -> usize {
+        self.start
+    }
+
+    pub fn end(&self) -> usize {
+        self.end
+    }
+
+    pub fn len(&self) -> usize {
+        self.end - self.start
+    }
+
+    pub fn kind(&self) -> Cl100kMatchKind {
+        self.kind
+    }
+}
+
+pub struct Cl100kMatches<'a> {
+    text: &'a str,
+    offset: usize,
+}
+
+impl<'a> Iterator for Cl100kMatches<'a> {
+    type Item = Cl100kMatch<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.offset >= self.text.len() {
+            return None;
+        }
+
+        let start = self.offset;
+        let slice = self.text;
+
+        let (advance, kind) = match_branch(slice, start).unwrap_or_else(|| {
+            let next = char_at(slice, start)
+                .map(|(_, next)| next)
+                .unwrap_or_else(|| min(start + 1, slice.len()));
+            (next - start, Cl100kMatchKind::Fallback)
+        });
+
+        let end = start + advance;
+        self.offset = end;
+
+        Some(Cl100kMatch {
+            haystack: self.text,
+            start,
+            end,
+            kind,
+        })
+    }
+}
+
+fn match_branch(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    match_contraction(text, idx)
+        .or_else(|| match_word_with_optional_prefix(text, idx))
+        .or_else(|| match_short_number(text, idx))
+        .or_else(|| match_punct_run(text, idx))
+        .or_else(|| match_whitespace_to_eof(text, idx))
+        .or_else(|| match_ws_then_linebreak(text, idx))
+        .or_else(|| match_trailing_ws(text, idx))
+        .or_else(|| match_single_ws(text, idx))
+}
+
+// Regex branch: `'(?i:[sdmt]|ll|ve|re)`
+fn match_contraction(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (first, next) = char_at(text, idx)?;
+    if first != '\'' {
+        return None;
+    }
+
+    let (a, after_a) = char_at(text, next)?;
+    let lower_a = ascii_lower(a);
+
+    if matches!(lower_a, 's' | 'd' | 'm' | 't') {
+        return Some((after_a - idx, Cl100kMatchKind::Contraction));
+    }
+
+    let (b, after_b) = char_at(text, after_a)?;
+    let lower_b = ascii_lower(b);
+
+    if (lower_a == 'l' && lower_b == 'l')
+        || (lower_a == 'v' && lower_b == 'e')
+        || (lower_a == 'r' && lower_b == 'e')
+    {
+        return Some((after_b - idx, Cl100kMatchKind::Contraction));
+    }
+
+    None
+}
+
+// Regex branch: `[^\r\n\p{L}\p{N}]?+\p{L}++`
+fn match_word_with_optional_prefix(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (cp0, next0) = char_at(text, idx)?;
+
+    if !is_cr_or_lf(cp0) && !is_alnum(cp0) {
+        let (cp1, _) = char_at(text, next0)?;
+        if !is_letter(cp1) {
+            return None;
+        }
+        let end = consume_letters(text, next0)?;
+        return Some((end - idx, Cl100kMatchKind::LetterWithPrefix));
+    }
+
+    if !is_letter(cp0) {
+        return None;
+    }
+
+    let end = consume_letters(text, idx)?;
+    Some((end - idx, Cl100kMatchKind::LetterWithPrefix))
+}
+
+fn consume_letters(text: &str, start: usize) -> Option<usize> {
+    let mut end = start;
+    let mut count = 0usize;
+    while let Some((ch, next)) = char_at(text, end) {
+        if !is_letter(ch) {
+            break;
+        }
+        end = next;
+        count += 1;
+    }
+    if count == 0 { None } else { Some(end) }
+}
+
+// Regex branch: `\p{N}{1,3}+`
+fn match_short_number(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (cp1, mut end) = char_at(text, idx)?;
+    if !is_number(cp1) {
+        return None;
+    }
+
+    let mut count = 1usize;
+    while count < 3 {
+        if let Some((cp, next)) = char_at(text, end) {
+            if is_number(cp) {
+                end = next;
+                count += 1;
+            } else {
+                break;
+            }
+        } else {
+            break;
+        }
+    }
+
+    Some((end - idx, Cl100kMatchKind::Number))
+}
+
+// Regex branch: ` ?[^\s\p{L}\p{N}]++[\r\n]*+`
+fn match_punct_run(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let mut cursor = idx;
+
+    if let Some((ch, next)) = char_at(text, cursor) {
+        if ch == ' ' {
+            let (after_space, _) = char_at(text, next)?;
+            if is_space(after_space) || is_alnum(after_space) {
+                return None;
+            }
+            cursor = next;
+        }
+    }
+
+    let mut end = cursor;
+    let mut took = false;
+    while let Some((ch, next)) = char_at(text, end) {
+        if is_space(ch) || is_alnum(ch) {
+            break;
+        }
+        end = next;
+        took = true;
+    }
+
+    if !took {
+        return None;
+    }
+
+    while let Some((ch, next)) = char_at(text, end) {
+        if !is_cr_or_lf(ch) {
+            break;
+        }
+        end = next;
+    }
+
+    Some((end - idx, Cl100kMatchKind::Punctuation))
+}
+
+// Regex branch: `\s++$`
+fn match_whitespace_to_eof(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (first, mut end) = char_at(text, idx)?;
+    if !is_space(first) {
+        return None;
+    }
+
+    while let Some((ch, next)) = char_at(text, end) {
+        if !is_space(ch) {
+            break;
+        }
+        end = next;
+    }
+
+    if end == text.len() {
+        Some((end - idx, Cl100kMatchKind::WhitespaceToEof))
+    } else {
+        None
+    }
+}
+
+// Regex branch: `\s*[\r\n]`
+fn match_ws_then_linebreak(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let mut pos = idx;
+    let mut best: Option<usize> = None;
+
+    if let Some((ch, _)) = char_at(text, pos) {
+        if is_cr_or_lf(ch) {
+            best = Some(pos);
+        }
+    }
+
+    while let Some((ch, next)) = char_at(text, pos) {
+        if !is_space(ch) {
+            break;
+        }
+        pos = next;
+        if let Some((next_ch, _)) = char_at(text, pos) {
+            if is_cr_or_lf(next_ch) {
+                best = Some(pos);
+            }
+        }
+    }
+
+    let newline_pos = best?;
+    let (_, newline_end) = char_at(text, newline_pos)?;
+    Some((newline_end - idx, Cl100kMatchKind::WhitespaceThenLinebreak))
+}
+
+// Regex branch: `\s+(?!\S)`
+fn match_trailing_ws(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (first, mut run_end) = char_at(text, idx)?;
+    if !is_space(first) {
+        return None;
+    }
+
+    while let Some((ch, next)) = char_at(text, run_end) {
+        if !is_space(ch) {
+            break;
+        }
+        run_end = next;
+    }
+
+    if run_end == text.len() {
+        return Some((run_end - idx, Cl100kMatchKind::TrailingWhitespace));
+    }
+
+    let prev_start = prev_char_start(text, run_end, idx)?;
+    if prev_start == idx {
+        return None;
+    }
+
+    Some((prev_start - idx, Cl100kMatchKind::TrailingWhitespace))
+}
+
+// Regex branch: `\s`
+fn match_single_ws(text: &str, idx: usize) -> Option<(usize, Cl100kMatchKind)> {
+    let (ch, next) = char_at(text, idx)?;
+    if is_space(ch) {
+        Some((next - idx, Cl100kMatchKind::SingleWhitespace))
+    } else {
+        None
+    }
+}
+
+fn char_at(text: &str, idx: usize) -> Option<(char, usize)> {
+    if idx >= text.len() {
+        return None;
+    }
+    let mut iter = text[idx..].char_indices();
+    let (offset, ch) = iter.next()?;
+    let next = idx + offset + ch.len_utf8();
+    Some((ch, next))
+}
+
+fn prev_char_start(text: &str, idx: usize, floor: usize) -> Option<usize> {
+    if idx <= floor {
+        return None;
+    }
+    let slice = &text[floor..idx];
+    slice
+        .char_indices()
+        .last()
+        .map(|(offset, _)| floor + offset)
+}
+
+fn ascii_lower(c: char) -> char {
+    if c.is_ascii_uppercase() {
+        c.to_ascii_lowercase()
+    } else {
+        c
+    }
+}
+
+fn is_cr_or_lf(ch: char) -> bool {
+    matches!(ch, '\r' | '\n')
+}
+
+fn is_letter(ch: char) -> bool {
+    matches!(
+        ch.general_category(),
+        GeneralCategory::UppercaseLetter
+            | GeneralCategory::LowercaseLetter
+            | GeneralCategory::TitlecaseLetter
+            | GeneralCategory::ModifierLetter
+            | GeneralCategory::OtherLetter
+    ) && ch.is_alphabetic()
+}
+
+fn is_number(ch: char) -> bool {
+    matches!(
+        ch.general_category(),
+        GeneralCategory::DecimalNumber
+            | GeneralCategory::LetterNumber
+            | GeneralCategory::OtherNumber
+    ) && ch.is_numeric()
+}
+
+fn is_space(ch: char) -> bool {
+    matches!(
+        ch.general_category(),
+        GeneralCategory::SpaceSeparator
+            | GeneralCategory::LineSeparator
+            | GeneralCategory::ParagraphSeparator
+    ) || ch.is_whitespace()
+}
+
+fn is_alnum(ch: char) -> bool {
+    is_letter(ch) || is_number(ch)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn contraction_matches() {
+        let parser = Cl100kParser::new();
+        let matches = parser
+            .find_iter("'re")
+            .map(|m| m.as_str().to_string())
+            .collect::<Vec<_>>();
+        assert_eq!(matches, vec!["'re".to_string()]);
+    }
+
+    #[test]
+    fn optional_prefix_word() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter("!Hello world");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), "!Hello");
+        assert_eq!(first.kind(), Cl100kMatchKind::LetterWithPrefix);
+    }
+
+    #[test]
+    fn numeric_span_limits_to_three_digits() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter("1234");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), "123");
+        assert_eq!(first.kind(), Cl100kMatchKind::Number);
+        let second = iter.next().unwrap();
+        assert_eq!(second.as_str(), "4");
+        assert_eq!(second.kind(), Cl100kMatchKind::Number);
+    }
+
+    #[test]
+    fn punctuation_run_consumes_trailing_newlines() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter(" !?\nfoo");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), " !?\n");
+        assert_eq!(first.kind(), Cl100kMatchKind::Punctuation);
+    }
+
+    #[test]
+    fn whitespace_to_eof_branch() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter("foo ");
+        let _ = iter.next();
+        let spatial = iter.next().unwrap();
+        assert_eq!(spatial.as_str(), " ");
+        assert_eq!(spatial.kind(), Cl100kMatchKind::WhitespaceToEof);
+    }
+
+    #[test]
+    fn whitespace_then_linebreak_branch() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter(" \nabc");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), " \n");
+        assert_eq!(first.kind(), Cl100kMatchKind::WhitespaceThenLinebreak);
+    }
+
+    #[test]
+    fn trailing_whitespace_branch() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter("  X");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), " ");
+        assert_eq!(first.kind(), Cl100kMatchKind::TrailingWhitespace);
+    }
+
+    #[test]
+    fn single_whitespace_branch() {
+        let parser = Cl100kParser::new();
+        let mut iter = parser.find_iter("\t!");
+        let first = iter.next().unwrap();
+        assert_eq!(first.as_str(), "\t");
+        assert_eq!(first.kind(), Cl100kMatchKind::SingleWhitespace);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5e89bfd..0658370 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,3 @@
-use std::borrow::Borrow;
-use std::borrow::Cow;
 use std::collections::HashSet;
 use std::num::NonZeroU64;
 use std::thread;
@@ -9,9 +7,23 @@ use fancy_regex::Regex;
 use pyo3::prelude::*;
 use rustc_hash::FxHashMap as HashMap;
 
+pub mod cl100k;
+
 #[cfg(feature = "python")]
 mod py;
 
+#[derive(Clone)]
+enum PatternBackend {
+    FancyRegex(Vec<Regex>),
+    Cl100k(cl100k::Cl100kParser),
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum PatternBackendChoice {
+    FancyRegex,
+    Cl100kParser,
+}
+
 pub type Rank = u32;
 
 fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
@@ -194,17 +206,23 @@ pub struct CoreBPE {
     special_tokens_encoder: HashMap<String, Rank>,
     decoder: HashMap<Rank, Vec<u8>>,
     special_tokens_decoder: HashMap<Rank, Vec<u8>>,
-    regex_tls: Vec<Regex>,
+    pattern_backend: PatternBackend,
     special_regex_tls: Vec<Regex>,
     sorted_token_bytes: Vec<Vec<u8>>,
 }
 
 impl CoreBPE {
-    fn _get_tl_regex(&self) -> &Regex {
-        // See performance notes above for what this is about
-        // It's also a little janky, please make a better version of it!
-        // However, it's nice that this doesn't leak memory to short-lived threads
-        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
+    #[inline(always)]
+    fn encode_piece_into(&self, piece: &[u8], out: &mut Vec<Rank>) -> usize {
+        if let Some(token) = self.encoder.get(piece) {
+            out.push(*token);
+            1
+        } else {
+            let tokens = byte_pair_encode(piece, &self.encoder);
+            let len = tokens.len();
+            out.extend(&tokens);
+            len
+        }
     }
 
     fn _get_tl_special_regex(&self) -> &Regex {
@@ -214,7 +232,7 @@
     /// Decodes tokens into a list of bytes.
     ///
     /// The bytes are not guaranteed to be a valid utf-8 string.
-    fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
+    pub fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for &token in tokens {
             let token_bytes = match self.decoder.get(&token) {
@@ -232,13 +250,22 @@ impl CoreBPE {
     pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
         // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
-        let regex = self._get_tl_regex();
-        let mut ret = vec![];
-        for mat in regex.find_iter(text) {
-            let piece = mat.unwrap().as_str().as_bytes();
-            match self.encoder.get(piece) {
-                Some(token) => ret.push(*token),
-                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
+        let mut ret = Vec::new();
+        match &self.pattern_backend {
+            PatternBackend::Cl100k(parser) => {
+                for mat in parser.find_iter(text) {
+                    self.encode_piece_into(mat.as_str().as_bytes(), &mut ret);
+                }
+            }
+            PatternBackend::FancyRegex(regex_tls) => {
+                let regex = &regex_tls[hash_current_thread() % MAX_NUM_THREADS];
+                for mat in regex.find_iter(text) {
+                    let piece = mat
+                        .expect("fancy-regex error while tokenizing")
+                        .as_str()
+                        .as_bytes();
+                    self.encode_piece_into(piece, &mut ret);
+                }
             }
         }
         ret
@@ -250,8 +277,19 @@
         allowed_special: &HashSet<&str>,
     ) -> Result<(Vec<Rank>, usize), EncodeError> {
         let special_regex = self._get_tl_special_regex();
-        let regex = self._get_tl_regex();
-        let mut ret = vec![];
+        enum PatternRunner<'a> {
+            Cl100k(&'a cl100k::Cl100kParser),
+            Fancy(&'a Regex),
+        }
+
+        let pattern_runner = match &self.pattern_backend {
+            PatternBackend::Cl100k(parser) => PatternRunner::Cl100k(parser),
+            PatternBackend::FancyRegex(regex_tls) => {
+                PatternRunner::Fancy(&regex_tls[hash_current_thread() % MAX_NUM_THREADS])
+            }
+        };
+
+        let mut ret = Vec::new();
 
         let mut start = 0;
         let mut last_piece_token_len = 0;
@@ -274,25 +312,28 @@
             let end = next_special.map_or(text.len(), |m| m.start());
 
             // Okay, here we go, compare this logic to encode_ordinary
-            for mat_res in regex.find_iter(&text[start..end]) {
-                let mat = match mat_res {
-                    Ok(m) => m,
-                    Err(e) => {
-                        return Err(EncodeError {
-                            message: format!("Regex error while tokenizing: {e}"),
-                        });
-                    }
-                };
-
-                let piece = mat.as_str().as_bytes();
-                if let Some(token) = self.encoder.get(piece) {
-                    last_piece_token_len = 1;
-                    ret.push(*token);
-                    continue;
-                }
-                let tokens = byte_pair_encode(piece, &self.encoder);
-                last_piece_token_len = tokens.len();
-                ret.extend(&tokens);
-            }
+            let segment = &text[start..end];
+            match &pattern_runner {
+                PatternRunner::Cl100k(parser) => {
+                    for mat in parser.find_iter(segment) {
+                        last_piece_token_len =
+                            self.encode_piece_into(mat.as_str().as_bytes(), &mut ret);
+                    }
+                }
+                PatternRunner::Fancy(regex) => {
+                    for mat_res in regex.find_iter(segment) {
+                        let mat = match mat_res {
+                            Ok(m) => m,
+                            Err(e) => {
+                                return Err(EncodeError {
+                                    message: format!("Regex error while tokenizing: {e}"),
+                                });
+                            }
+                        };
+                        last_piece_token_len =
+                            self.encode_piece_into(mat.as_str().as_bytes(), &mut ret);
+                    }
+                }
+            }
 
             match next_special {
@@ -487,12 +528,46 @@
         )
     }
 
+    pub fn new_with_backend<E, SE, NSE>(
+        encoder: E,
+        special_tokens_encoder: SE,
+        pattern: &str,
+        backend_choice: PatternBackendChoice,
+    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
+    where
+        E: IntoIterator<Item = (Vec<u8>, Rank)>,
+        SE: IntoIterator<Item = (String, Rank)>,
+        NSE: IntoIterator<Item = (String, (Rank, Rank))>,
+    {
+        Self::new_internal_with_backend(
+            HashMap::from_iter(encoder),
+            HashMap::from_iter(special_tokens_encoder),
+            pattern,
+            backend_choice,
+        )
+    }
+
     fn new_internal(
         encoder: HashMap<Vec<u8>, Rank>,
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
-        let regex = Regex::new(pattern)?;
+        let default_choice = default_backend_for_pattern(pattern);
+        Self::new_internal_with_backend(
+            encoder,
+            special_tokens_encoder,
+            pattern,
+            default_choice,
+        )
+    }
+
+    fn new_internal_with_backend(
+        encoder: HashMap<Vec<u8>, Rank>,
+        special_tokens_encoder: HashMap<String, Rank>,
+        pattern: &str,
+        backend_choice: PatternBackendChoice,
+    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
+        let pattern_backend = build_pattern_backend(pattern, backend_choice)?;
 
         let special_regex = {
             let parts = special_tokens_encoder
@@ -526,7 +601,7 @@
             special_tokens_encoder,
             decoder,
             special_tokens_decoder,
-            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
+            pattern_backend,
             special_regex_tls: (0..MAX_NUM_THREADS)
                 .map(|_| special_regex.clone())
                 .collect(),
@@ -547,6 +622,43 @@
     }
 }
 
+fn build_pattern_backend(
+    pattern: &str,
+    backend_choice: PatternBackendChoice,
+) -> Result<PatternBackend, Box<dyn std::error::Error + Send + Sync>> {
+    match backend_choice {
+        PatternBackendChoice::Cl100kParser => {
+            if pattern == cl100k::CL100K_PATTERN {
+                Ok(PatternBackend::Cl100k(cl100k::Cl100kParser::new()))
+            } else {
+                // Error if Cl100kParser requested but pattern is not CL100K_PATTERN
+                return Err(std::io::Error::new(
+                    std::io::ErrorKind::InvalidInput,
+                    format!(
+                        "Cannot use PatternBackendChoice::Cl100kParser with a pattern other than cl100k::CL100K_PATTERN.\nGot pattern: '{}'\nExpected: '{}'",
+                        pattern, cl100k::CL100K_PATTERN
+                    ),
+                )
+                .into());
+            }
+        }
+        PatternBackendChoice::FancyRegex => {
+            let regex = Regex::new(pattern)?;
+            Ok(PatternBackend::FancyRegex(
+                (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
+            ))
+        }
+    }
+}
+
+fn default_backend_for_pattern(pattern: &str) -> PatternBackendChoice {
+    if pattern == cl100k::CL100K_PATTERN {
+        PatternBackendChoice::Cl100kParser
+    } else {
+        PatternBackendChoice::FancyRegex
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use fancy_regex::Regex;
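
Reviewer note: the sketch below shows the calling convention for the new `new_with_backend` entry point, mirroring how the fuzz binary constructs its two engines for differential testing. It is a minimal sketch, not part of the diff: the `build_both` helper is hypothetical, and `encoder`/`specials` stand in for maps loaded the same way as `load_bpe_from_file` above. Both constructions should yield identical `encode_ordinary` output, which is exactly the invariant `cl100k_fuzz bpe` checks.

```rust
use std::collections::HashMap;

use tiktoken::cl100k::CL100K_PATTERN;
use tiktoken::{CoreBPE, PatternBackendChoice, Rank};

// Build the same vocabulary with both pattern backends (hypothetical helper).
fn build_both(
    encoder: HashMap<Vec<u8>, Rank>,
    specials: HashMap<String, Rank>,
) -> (CoreBPE, CoreBPE) {
    // Explicitly request the hand-written cl100k parser; construction fails if
    // the pattern is anything other than cl100k::CL100K_PATTERN.
    let custom = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
        encoder.clone(),
        specials.clone(),
        CL100K_PATTERN,
        PatternBackendChoice::Cl100kParser,
    )
    .expect("custom backend");
    // The fancy-regex backend stays available for any pattern and serves as
    // the reference implementation in the differential fuzzer.
    let fancy = CoreBPE::new_with_backend::<_, _, std::iter::Empty<(String, (Rank, Rank))>>(
        encoder,
        specials,
        CL100K_PATTERN,
        PatternBackendChoice::FancyRegex,
    )
    .expect("fancy backend");
    (custom, fancy)
}
```

Existing `CoreBPE` callers are unaffected: `new_internal` now routes through `default_backend_for_pattern`, so the custom parser is selected automatically whenever the pattern string equals `CL100K_PATTERN`, and every other pattern falls back to fancy-regex.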