1- use std:: borrow:: Borrow ;
2- use std:: borrow:: Cow ;
31use std:: collections:: HashSet ;
42use std:: num:: NonZeroU64 ;
53use std:: thread;
@@ -14,6 +12,131 @@ mod py;
1412
1513pub type Rank = u32 ;
1614
15+ use std:: collections:: BinaryHeap ;
16+
17+ #[ derive( Eq , PartialEq , Clone , Copy ) ]
18+ struct Merge {
19+ start : usize ,
20+ rank : Rank ,
21+ }
22+
23+ impl Ord for Merge {
24+ #[ inline]
25+ fn cmp ( & self , other : & Self ) -> std:: cmp:: Ordering {
26+ other
27+ . rank
28+ . cmp ( & self . rank )
29+ . then_with ( || other. start . cmp ( & self . start ) )
30+ }
31+ }
32+
33+ impl PartialOrd for Merge {
34+ fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
35+ Some ( self . cmp ( other) )
36+ }
37+ }
38+
39+ struct State {
40+ prev : usize ,
41+ end : usize ,
42+ next_end : usize ,
43+ next_rank : Rank ,
44+ cur_rank : Rank ,
45+ }
46+
47+ fn _byte_pair_merge_large ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < Rank > {
48+ let mut state = Vec :: with_capacity ( piece. len ( ) ) ;
49+ state. push ( State {
50+ prev : usize:: MAX ,
51+ end : 1 ,
52+ next_end : 2 ,
53+ next_rank : Rank :: MAX ,
54+ cur_rank : Rank :: MAX ,
55+ } ) ;
56+
57+ let mut heap = BinaryHeap :: with_capacity ( piece. len ( ) ) ;
58+ for i in 0 ..piece. len ( ) - 1 {
59+ if let Some ( & rank) = ranks. get ( & piece[ i..i + 2 ] ) {
60+ heap. push ( Merge { start : i, rank } ) ;
61+ state[ i] . next_rank = rank;
62+ }
63+ // note this is happening offset by 1
64+ state. push ( State {
65+ prev : i,
66+ end : i + 2 ,
67+ next_end : i + 3 ,
68+ next_rank : Rank :: MAX ,
69+ cur_rank : Rank :: MAX ,
70+ } ) ;
71+ }
72+
73+ // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
74+ // starts at `start` and ends at `state[start].end` with the (right) token that starts at
75+ // `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
76+ // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
77+ // new potential merges to the heap.
78+
79+ let potential_merge = {
80+ #[ inline( always) ]
81+ |state : & mut Vec < State > ,
82+ heap : & mut BinaryHeap < Merge > ,
83+ start : usize ,
84+ next_end_item : usize | {
85+ state[ start] . next_end = next_end_item;
86+ state[ start] . next_rank = Rank :: MAX ; // Always invalidate the old merge
87+ if next_end_item <= piece. len ( ) {
88+ if let Some ( & rank) = ranks. get ( & piece[ start..next_end_item] ) {
89+ // We have a valid potential merge!
90+ heap. push ( Merge { start, rank } ) ;
91+ state[ start] . next_rank = rank;
92+ }
93+ }
94+ }
95+ } ;
96+
97+ while let Some ( left) = heap. pop ( ) {
98+ if left. rank == Rank :: MAX {
99+ break ;
100+ }
101+ if left. rank != state[ left. start ] . next_rank {
102+ continue ; // This merge was invalidated, ignore it
103+ }
104+
105+ let left_start = left. start ;
106+ let right_start = state[ left_start] . end ;
107+ let right_end = state[ left_start] . next_end ;
108+ debug_assert ! ( right_end == state[ right_start] . end) ;
109+ let right_next_end = state[ right_start] . next_end ;
110+
111+ // Merge left and right into a single token
112+ state[ left_start] . cur_rank = state[ left_start] . next_rank ;
113+ state[ left_start] . end = right_end;
114+ potential_merge ( & mut state, & mut heap, left_start, right_next_end) ;
115+ if right_end < state. len ( ) {
116+ state[ right_end] . prev = left_start;
117+ }
118+ // Update the merge that ends at left_start
119+ if left_start > 0 {
120+ let prev_start = state[ left_start] . prev ;
121+ potential_merge ( & mut state, & mut heap, prev_start, right_end) ;
122+ }
123+ // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
124+ state[ right_start] . next_rank = Rank :: MAX ;
125+ }
126+
127+ let mut result = Vec :: new ( ) ;
128+ let mut i = 0 ;
129+ while i < state. len ( ) {
130+ if state[ i] . cur_rank != Rank :: MAX {
131+ result. push ( state[ i] . cur_rank ) ;
132+ } else {
133+ result. push ( ranks[ & piece[ i..state[ i] . end ] ] ) ;
134+ }
135+ i = state[ i] . end ;
136+ }
137+ result
138+ }
139+
17140fn _byte_pair_merge ( ranks : & HashMap < Vec < u8 > , Rank > , piece : & [ u8 ] ) -> Vec < ( usize , Rank ) > {
18141 // This is a vector of (start, rank).
19142 // The rank is of the pair starting at position start.
@@ -73,21 +196,27 @@ fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize,
73196}
74197
75198pub fn byte_pair_encode ( piece : & [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < Rank > {
76- if piece. len ( ) == 1 {
199+ let piece_len = piece. len ( ) ;
200+
201+ if piece_len == 1 {
77202 return vec ! [ ranks[ piece] ] ;
78203 }
79- _byte_pair_merge ( ranks, piece)
80- . windows ( 2 )
81- . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
82- . collect ( )
204+ if piece_len < 100 {
205+ return _byte_pair_merge ( ranks, piece)
206+ . windows ( 2 )
207+ . map ( |part| ranks[ & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] ] )
208+ . collect ( ) ;
209+ }
210+ return _byte_pair_merge_large ( ranks, piece) ;
83211}
84212
85213pub fn byte_pair_split < ' a > ( piece : & ' a [ u8 ] , ranks : & HashMap < Vec < u8 > , Rank > ) -> Vec < & ' a [ u8 ] > {
86214 assert ! ( piece. len( ) > 1 ) ;
87- _byte_pair_merge ( ranks, piece)
215+ return _byte_pair_merge ( ranks, piece)
88216 . windows ( 2 )
89217 . map ( |part| & piece[ part[ 0 ] . 0 ..part[ 1 ] . 0 ] )
90- . collect ( )
218+ . collect ( ) ;
219+ // TODO: _byte_pair_merge_large
91220}
92221
93222// Various performance notes:
@@ -521,7 +650,7 @@ impl CoreBPE {
521650
522651#[ cfg( test) ]
523652mod tests {
524- use fancy_regex :: Regex ;
653+
525654 use rustc_hash:: FxHashMap as HashMap ;
526655
527656 use crate :: { byte_pair_split, Rank } ;
0 commit comments