Skip to content

Commit

Permalink
✨ feat(rust): Convert project to a multi-crate workspace
Browse files Browse the repository at this point in the history
This commit restructures the project from a single-crate workspace into a multi-crate workspace with two members, 'rs-tiktoken' and 'py-tiktoken'. Separating the core Rust library from the Python bindings clarifies the codebase organization and makes each module easier to maintain independently. The setup.py is also updated to reflect the new directory structure.

Refs: #24
  • Loading branch information
Miuler committed Sep 15, 2023
1 parent b7af705 commit 6b7624a
Show file tree
Hide file tree
Showing 14 changed files with 50,591 additions and 94 deletions.
36 changes: 34 additions & 2 deletions .github/workflows/build_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ jobs:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64)
build_wheels_aarch64_glibc:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/glibc)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -52,6 +52,38 @@ jobs:
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*musllinux*"
CIBW_BUILD_VERBOSITY: 3
            # https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
- uses: actions/upload-artifact@v3
with:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64_musl:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/musl)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [38, 39, 310, 311]

steps:
- uses: actions/checkout@v3

      - name: Set up QEMU
uses: docker/setup-qemu-action@v2
with:
platforms: arm64

- name: Build wheels
uses: pypa/[email protected]
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*manylinux*"
CIBW_BUILD_VERBOSITY: 3
            # https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
Expand Down
27 changes: 9 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
[package]
name = "tiktoken"
[workspace]
resolver = "2"
members = [
"rs-tiktoken",
"py-tiktoken",
]

[workspace.package]
version = "0.5.1"
edition = "2021"
rust-version = "1.57.0"

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"

[profile.release]
incremental = true
incremental = true
6 changes: 5 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ include *.svg
include *.toml
include *.md
include Makefile
include py-tiktoken/*.toml
include rs-tiktoken/*.toml
include rs-tiktoken/tests/gpt2_encoder
global-include py.typed
recursive-include scripts *.py
recursive-include tests *.py
recursive-include src *.rs
recursive-include py-tiktoken *.rs
recursive-include rs-tiktoken *.rs
20 changes: 20 additions & 0 deletions py-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
name = "py-tiktoken"
version.workspace = true
edition = "2021"
rust-version = "1.57.0"

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
tiktoken = { path = "../rs-tiktoken" }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
once_cell = "1.18.0"

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
1 change: 1 addition & 0 deletions py-tiktoken/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod tiktoken_py;
75 changes: 7 additions & 68 deletions src/tiktoken_py.rs → py-tiktoken/src/tiktoken_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

use std::collections::HashSet;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
use tiktoken::core::{byte_pair_encode, CoreBPE};

#[pyclass]
pub struct PyCoreBPE {
Expand All @@ -26,47 +25,10 @@ impl PyCoreBPE {
special_tokens_encoder: HashMap<String, usize>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

let special_regex = {
let _parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
};

let decoder: HashMap<usize, Vec<u8>> =
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

assert!(
encoder.len() == decoder.len(),
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
);

let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();

// Clone because I don't know how to tell Rust I'm not going to change the map
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();

let core_bpe = CoreBPE {
encoder,
special_tokens_encoder,
decoder,
special_tokens_decoder,
regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
special_regex_tls: (0..MAX_NUM_THREADS)
.map(|_| special_regex.clone())
.collect(),
sorted_token_bytes,
};
Ok(PyCoreBPE { core_bpe })
println!("encoder: {:?}", encoder);
CoreBPE::new(encoder, special_tokens_encoder, pattern)
.map(|core_bpe| PyCoreBPE { core_bpe })
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
}

// ====================
Expand All @@ -82,30 +44,7 @@ impl PyCoreBPE {
}

fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
py.allow_threads(|| {
match std::str::from_utf8(bytes) {
Ok(text) => self.core_bpe._encode_ordinary_native(text),
Err(e) => {
let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
let (mut tokens, last_piece_token_len) =
self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
if !tokens.is_empty() && last_piece_token_len > 0 {
// Lop off the tokens from the last piece and run BPE on the remaining bytes
// Somewhat niche, but this may not be correct if we'd have had a regex
// split between the valid UTF-8 and the invalid bytes, which is why this
// method is private
let mut unstable_bytes =
self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
}
tokens
}
}
})
py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
}

fn encode_with_unstable(
Expand Down Expand Up @@ -181,7 +120,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
mod tests {
use rustc_hash::FxHashMap as HashMap;

use crate::tiktoken::byte_pair_split;
use tiktoken::core::byte_pair_split;

#[test]
fn very_simple_test() {
Expand Down
13 changes: 13 additions & 0 deletions rs-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "tiktoken"
version.workspace = true
edition = "2021"
rust-version = "1.57.0"

[dependencies]
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
once_cell = "1.18.0"
parse-display = "0.8.2"

0 comments on commit 6b7624a

Please sign in to comment.