diff --git a/src/lib.rs b/src/lib.rs
index 3d930655..fc97503d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,28 +2,31 @@
 #![allow(clippy::borrow_deref_ref)]
 
 use std::collections::HashSet;
+use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
 use pyo3::exceptions;
 use pyo3::prelude::*;
-use pyo3::types::{PyBytes, PyList, PyTuple};
 use pyo3::PyResult;
+use pyo3::types::{PyBytes, PyList, PyTuple};
 use rustc_hash::FxHashMap as HashMap;
 
+type Rank = u32;
+
 fn _byte_pair_merge<T>(
     piece: &[u8],
-    ranks: &HashMap<Vec<u8>, usize>,
+    ranks: &HashMap<Vec<u8>, Rank>,
     f: impl Fn(std::ops::Range<usize>) -> T,
 ) -> Vec<T> {
     // This is a vector of (start, rank).
     // The rank is of the byte pair starting at position start.
     // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
+    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| {
+        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
             if (start_idx + skip + 2) < parts.len() {
                 ranks
                     .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
@@ -39,8 +42,8 @@ fn _byte_pair_merge<T>(
     for i in 0..parts.len() - 2 {
         match get_rank(&parts, i, 0) {
             Some(rank) => {
-                // usize::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != usize::MAX);
+                // Rank::MAX is a sentinel value and cannot be a valid rank
+                debug_assert!(rank != Rank::MAX);
                 parts[i].1 = rank;
             }
             None => {
@@ -63,16 +66,16 @@ fn _byte_pair_merge<T>(
             break;
         }
 
-        // usize::MAX is a sentinel rank value allowing us to
+        // Rank::MAX is a sentinel rank value allowing us to
         // take the min more quickly
-        let mut min_rank: (usize, usize) = (usize::MAX, 0);
+        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
 
-        if min_rank.0 != usize::MAX {
+        if min_rank.0 != Rank::MAX {
             let i = min_rank.1;
 
             // NOTE: We are about to remove parts[i + 1]. We do not do it
@@ -80,9 +83,9 @@ fn _byte_pair_merge<T>(
             // parts[i] and parts[i-1] before removing, which could thrash
             // the cache. Thus, we update the rank calculation by skipping over
             // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
+            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
             if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
+                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
             }
 
             parts.remove(i + 1);
@@ -97,14 +100,14 @@ fn _byte_pair_merge<T>(
     out
 }
 
-pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
+pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
     if piece.len() == 1 {
         return vec![ranks[piece]];
    }
     _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
 }
 
-pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
+pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
     if piece.len() == 1 {
         return vec![piece];
     }
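For reviewers who want to convince themselves the narrowing changes nothing observable, here is a minimal sanity check of the two public helpers above, using a made-up rank table (the table contents and the `main` wrapper are illustrative only; real encoders map every single byte plus the learned merges):

```rust
use rustc_hash::FxHashMap as HashMap;

fn main() {
    // Toy vocabulary: the three single bytes we use, plus one learned merge.
    let mut ranks: HashMap<Vec<u8>, u32> = HashMap::default();
    ranks.insert(b"a".to_vec(), 0);
    ranks.insert(b"b".to_vec(), 1);
    ranks.insert(b"c".to_vec(), 2);
    ranks.insert(b"ab".to_vec(), 3);

    // "ab" is the only adjacent pair with a rank, so it merges first; the
    // merged candidate "abc" has no rank, so the loop sees only Rank::MAX
    // sentinels and stops, leaving the pieces ["ab", "c"].
    assert_eq!(byte_pair_encode(b"abc", &ranks), vec![3, 2]);
    assert_eq!(byte_pair_split(b"abc", &ranks), vec![&b"ab"[..], &b"c"[..]]);
}
```

The same inputs produce the same token sequence before and after the patch; only the integer width of the ranks differs.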
@@ -152,7 +155,6 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) ->
 // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
 // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
 
-use std::num::NonZeroU64;
 pub struct FakeThreadId(NonZeroU64);
 
 fn hash_current_thread() -> usize {
@@ -169,12 +171,13 @@ fn hash_current_thread() -> usize {
 }
 
 const MAX_NUM_THREADS: usize = 128;
+
 #[pyclass]
 struct CoreBPE {
-    encoder: HashMap<Vec<u8>, usize>,
-    special_tokens_encoder: HashMap<String, usize>,
-    decoder: HashMap<usize, Vec<u8>>,
-    special_tokens_decoder: HashMap<usize, Vec<u8>>,
+    encoder: HashMap<Vec<u8>, Rank>,
+    special_tokens_encoder: HashMap<String, Rank>,
+    decoder: HashMap<Rank, Vec<u8>>,
+    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
     regex_tls: Vec<Regex>,
     special_regex_tls: Vec<Regex>,
     sorted_token_bytes: Vec<Vec<u8>>,
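The `regex_tls` and `special_regex_tls` fields exist because matching can contend on mutable scratch space inside the regex engine: the constructor clones the compiled regex once per thread slot, and the `_get_tl_regex` accessors in the next hunk pick a slot by hashing the current thread id. A standalone sketch of the same pattern follows; the `RegexPool` name and its `get` method are illustrative, not this crate's API:

```rust
use fancy_regex::Regex;

const MAX_NUM_THREADS: usize = 128;

/// Illustrative stand-in for the regex_tls fields: MAX_NUM_THREADS clones of
/// one compiled regex, so concurrent callers rarely touch the same instance.
struct RegexPool {
    slots: Vec<Regex>,
}

impl RegexPool {
    fn new(pattern: &str) -> Result<Self, fancy_regex::Error> {
        let regex = Regex::new(pattern)?;
        Ok(RegexPool {
            // Compile once, clone per slot; cloning is cheap next to compiling.
            slots: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
        })
    }

    fn get(&self, thread_hash: usize) -> &Regex {
        // Collisions just mean two threads share a clone; that is safe here
        // because matching takes &self and the pool is never mutated.
        &self.slots[thread_hash % MAX_NUM_THREADS]
    }
}
```

A lookup is then `pool.get(hash_current_thread())`: an index computation instead of a lock, at the cost of a fixed number of up-front clones.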
@@ -192,7 +195,7 @@ impl CoreBPE {
         &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
     }
 
-    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
+    fn _decode_native(&self, tokens: &[Rank]) -> Vec<u8> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for token in tokens {
             let token_bytes = self
@@ -204,7 +207,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
+    fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
         let regex = self._get_tl_regex();
@@ -220,7 +223,7 @@ impl CoreBPE {
         ret
     }
 
-    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
+    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
         let special_regex = self._get_tl_special_regex();
         let regex = self._get_tl_regex();
         let mut ret = vec![];
@@ -278,9 +281,9 @@ impl CoreBPE {
 
     fn _increase_last_piece_token_len(
         &self,
-        tokens: Vec<usize>,
+        tokens: Vec<Rank>,
         mut last_piece_token_len: usize,
-    ) -> (Vec<usize>, usize) {
+    ) -> (Vec<Rank>, usize) {
         // Unfortunately, the locations where our regex splits can be unstable.
         // For the purposes of determining unstable tokens, unstable regex splitting
         // is only a problem if a split that was present disappears, since this can
@@ -319,7 +322,7 @@ impl CoreBPE {
         &self,
         text: &str,
         allowed_special: &HashSet<&str>,
-    ) -> (Vec<usize>, HashSet<Vec<u8>>) {
+    ) -> (Vec<Rank>, HashSet<Vec<u8>>) {
         let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
         if last_piece_token_len == 0 {
             // If last_piece_token_len is zero, the last token was a special token and we have
@@ -436,8 +439,8 @@ impl CoreBPE {
 impl CoreBPE {
     #[new]
     fn new(
-        encoder: HashMap<Vec<u8>, usize>,
-        special_tokens_encoder: HashMap<String, usize>,
+        encoder: HashMap<Vec<u8>, Rank>,
+        special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> PyResult<Self> {
         let regex = Regex::new(pattern)
@@ -452,7 +455,7 @@ impl CoreBPE {
             .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
         };
 
-        let decoder: HashMap<usize, Vec<u8>> =
+        let decoder: HashMap<Rank, Vec<u8>> =
             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
 
         assert!(
@@ -460,7 +463,7 @@ impl CoreBPE {
             "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
         );
 
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
+        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
             .iter()
             .map(|(k, v)| (*v, k.as_bytes().to_vec()))
             .collect();
@@ -486,15 +489,15 @@ impl CoreBPE {
     // Encoding
     // ====================
 
-    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
+    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
         py.allow_threads(|| self._encode_ordinary_native(text))
     }
 
-    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
+    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<Rank> {
         py.allow_threads(|| self._encode_native(text, &allowed_special).0)
     }
 
-    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
+    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
         py.allow_threads(|| {
             match std::str::from_utf8(bytes) {
                 Ok(text) => self._encode_ordinary_native(text),
@@ -534,7 +537,7 @@ impl CoreBPE {
         (tokens, py_completions).into_py(py)
     }
 
-    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
+    fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
         if let Some(token) = self.encoder.get(piece).copied() {
             return Ok(token);
         }
@@ -546,7 +549,7 @@ impl CoreBPE {
         Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
     }
 
-    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
+    fn encode_single_piece(&self, piece: &[u8]) -> Vec<Rank> {
         if let Some(token) = self.encoder.get(piece) {
             return vec![*token];
         }
@@ -557,12 +560,12 @@ impl CoreBPE {
     // Decoding
     // ====================
 
-    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
+    fn decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Py<PyBytes> {
         let bytes = py.allow_threads(|| self._decode_native(&tokens));
         PyBytes::new(py, &bytes).into()
     }
 
-    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
+    fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
         if let Some(bytes) = self.decoder.get(&token) {
             return Ok(PyBytes::new(py, bytes).into());
         }
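On the choice of `u32` for `Rank`: current vocabularies are on the order of 10^5 entries, so 32 bits leave ample headroom, `Rank::MAX` stays available as the merge loop's sentinel, and token buffers halve in size. A small standalone check of those claims (the sizes assume a 64-bit target, and the vocabulary figure is cl100k_base's, quoted for illustration):

```rust
type Rank = u32;

fn main() {
    // cl100k_base's n_vocab sits far below the sentinel Rank::MAX = 2^32 - 1,
    // so no real rank can ever collide with the sentinel.
    let n_vocab: u64 = 100_277;
    assert!(n_vocab < Rank::MAX as u64);

    // Token ids drop from 8 bytes to 4 on a 64-bit target, so Vec<Rank>
    // outputs and the encoder/decoder map values shrink accordingly.
    assert_eq!(std::mem::size_of::<Rank>(), 4);
    assert_eq!(std::mem::size_of::<usize>(), 8);
}
```

Note that the `(usize, Rank)` pairs inside `_byte_pair_merge` keep their 16-byte stride on typical 64-bit targets because of alignment padding, so the savings come from the hash maps and the returned token vectors rather than the scratch buffer.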