@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;
15
15
16
16
type Rank = u32;

/// Core BPE merge loop: repeatedly merges the lowest-ranked adjacent pair in `piece`
/// until no mergeable pair remains.
///
/// Returns the surviving part boundaries as `(start, rank)` tuples, where `rank` is the
/// rank of the pair beginning at `start`. The last two entries are sentinels carrying
/// `Rank::MAX` (the final byte and the one-past-the-end boundary never start a pair).
///
/// Precondition: `piece` must be non-empty — `piece.len() - 1` underflows on an empty
/// slice. NOTE(review): callers appear to guarantee `piece.len() > 1`; confirm at the
/// call sites.
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);

    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we
    // train BPE the way we currently do, this is equivalent. An easy way to break this
    // would be to decouple merge priority from token index or to prevent specific token
    // merges.
    //
    // While seeding the initial per-pair ranks, also track the overall minimum so the
    // main loop below can start merging immediately.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    // Sentinel entries; Rank::MAX is never a valid rank.
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet
                // deleted parts[i + 1], see comment in the main loop.
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // n is often very small so considerations like cache-locality outweigh the
    // algorithmic complexity downsides of the `parts` vector.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
        // `parts.remove(i + 1)` will thrash the cache.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the next lowest-ranked pair; the merge above may have changed the
        // ranks adjacent to position i.
        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }

    parts
}
99
75
0 commit comments