Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
eisber committed Feb 27, 2023
2 parents 98070f1 + f5fbc9c commit 3b39f6b
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 23 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ from setuptools import setup, find_namespace_packages

setup(
name="my_tiktoken_extension",
    packages=find_namespace_packages(include=['tiktoken_ext*']),
install_requires=["tiktoken"],
...
)
```

Then simply `pip install ./my_tiktoken_extension` and you should be able to use your
custom encodings! Make sure **not** to use an editable install.

104 changes: 84 additions & 20 deletions core/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
use rustc_hash::FxHashMap as HashMap;

/// Core BPE merge loop: repeatedly merges the adjacent pair of parts with the
/// lowest rank in `ranks` until no mergeable pair remains, then maps each final
/// span of `piece` through `f` and returns the results in order.
///
/// `piece` must be non-empty (callers guard the single-byte case separately;
/// an empty `piece` would underflow the index arithmetic below).
fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // This is a vector of (start, rank).
    // The rank is of the byte pair starting at position start.
    // The rank of the last item in the vector is not a valid value.
    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();

    // NOTE: using a macro here because a closure fails to get inlined
    // according to optimization remarks.
    // A closure also cannot capture a reference to `piece` without
    // the borrow checker complaining about the mutable borrows during
    // the assignments later in this code.
    macro_rules! get_rank {
        ($start_idx:expr, $skip:expr) => {{
            let start_idx: usize = $start_idx;
            let skip: usize = $skip;
            if (start_idx + skip + 2) < parts.len() {
                ranks
                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
                    .copied()
            } else {
                None
            }
        }};
        ($idx:expr) => {{
            get_rank!($idx, 0)
        }};
    }

    // We look up the ranks once in the beginning and iteratively update
    // them during each merge, which reduces the number of rank lookups.
    for i in 0..parts.len() - 2 {
        match get_rank!(i) {
            Some(rank) => {
                // usize::MAX is a sentinel value and cannot be a valid rank
                debug_assert!(rank != usize::MAX);
                parts[i].1 = rank;
            }
            None => {
                continue;
            }
        };
    }

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // It is important to consider that n is often small (<100), and as such
    // the cache-locality benefits outweigh the algorithmic complexity downsides
    // of the `parts` vector data structure above.
    loop {
        if parts.len() == 1 {
            break;
        }

        // usize::MAX is a sentinel rank value allowing us to
        // take the min more quickly
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }

        if min_rank.0 != usize::MAX {
            let i = min_rank.1;

            // NOTE: We are about to remove parts[i + 1]. We do not do it
            // yet because there are cache-locality benefits to updating
            // parts[i] and parts[i-1] before removing, which could thrash
            // the cache. Thus, we update the rank calculation by skipping over
            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
            parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX);
            if i > 0 {
                parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX);
            }

            parts.remove(i + 1);
        } else {
            break;
        }
    }

    // Map each final [start, next_start) span of `piece` through `f`.
    let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
    for i in 0..parts.len() - 1 {
        out.push(f(parts[i].0..parts[i + 1].0));
    }
    out
}

pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
if piece.len() == 1 {
return vec![ranks[piece]];
}
_byte_pair_merge(piece, ranks)
.iter()
.map(|p| ranks[&piece[p.start..p.end]])
.collect()
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
}

/// Splits `piece` into the byte sub-slices produced by BPE merging against
/// `ranks`, borrowing each slice from `piece`.
///
/// A single byte is returned as-is without consulting `ranks`.
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    // A lone byte can never be merged; short-circuit the merge machinery.
    if piece.len() == 1 {
        return vec![piece];
    }
    _byte_pair_merge(piece, ranks, |span| &piece[span.start..span.end])
}

#[cfg(test)]
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
[project]
name = "tiktoken"
version = "0.2.0"
description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
readme = "README.md"
license = {file = "LICENSE"}
authors = [{name = "Shantanu Jain"}, {email = "[email protected]"}]
dependencies = ["blobfile>=2", "regex>=2022.1.18", "requests>=2.26.0"]
requires-python = ">=3.8"

[project.urls]
homepage = "https://github.com/openai/tiktoken"
repository = "https://github.com/openai/tiktoken"
changelog = "https://github.com/openai/tiktoken/blob/main/CHANGELOG.md"

[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=62.4", "wheel", "setuptools-rust>=1.5.2"]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_simple_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ def test_encoding_for_model():
assert enc.name == "gpt2"
enc = tiktoken.encoding_for_model("text-davinci-003")
assert enc.name == "p50k_base"
enc = tiktoken.encoding_for_model("text-davinci-edit-001")
assert enc.name == "p50k_edit"

0 comments on commit 3b39f6b

Please sign in to comment.