Skip to content

Commit 1b9faf2

Browse files
hauntsaninja and l0rinc authored
Simplify byte_pair_merge (#255)
Based on suggestion in #239 (specifically 8f5dd7d) Like that commit, this: - Does the init in a single loop and saves a loop if there are no merges - Simplifies get_rank and no longer uses it in init (so you don't need multiple skip values) Unlike that commit: - We drop optimisations enabled by ignoring single tokens. These didn't show any benefit on benchmarks for me (this makes sense given typical piece sizes, but let me know if that's unexpected!). Given this, I opted for the simpler version. - I preserve some of the comments from the original that I think are still useful Co-authored-by: @paplorinc --------- Co-authored-by: Lőrinc Pap <1841944+paplorinc@users.noreply.github.com>
1 parent 6defed5 commit 1b9faf2

File tree

1 file changed

+36
-60
lines changed

1 file changed

+36
-60
lines changed

src/lib.rs

+36-60
Original file line number | Diff line number | Diff line change
@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;
1515

1616
type Rank = u32;

/// Core BPE merge loop.
///
/// Given `piece` (a byte sequence) and `ranks` (a map from byte-pair bytes to
/// merge priority, where a *lower* rank merges *earlier*), repeatedly merges
/// the adjacent pair with the lowest rank until no mergeable pair remains.
///
/// Returns a vector of `(start, rank)` entries: `start` is the byte offset in
/// `piece` where each resulting token begins, and the final entry is the
/// sentinel `(piece.len(), Rank::MAX)`, so consecutive entries delimit token
/// byte ranges. The `rank` of the last two entries is never a valid rank.
///
/// Precondition: `piece` must be non-empty — `piece.len() - 1` below would
/// underflow on an empty slice (callers are expected to handle length-0/1
/// pieces before reaching the merge loop).
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);

    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    // Two sentinels: the last byte (no pair starts there) and one-past-the-end,
    // so every token's byte range can be read as parts[k].0..parts[k + 1].0.
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
                // parts[i + 1], see comment in the main loop.
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // n is often very small so considerations like cache-locality outweigh the algorithmic
    // complexity downsides of the `parts` vector.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
        // `parts.remove(i + 1)` will thrash the cache.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the next-lowest rank; Rank::MAX is the "no merge" sentinel.
        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}
9975

0 commit comments

Comments
 (0)