Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: Split the Rust code so it can be used as both a Python lib and a Rust lib #167

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Prev Previous commit
✨ feat(rust): Convert project to a multi-crate workspace
This commit restructures the project from a single crate into a multi-crate workspace with two members: 'rs-tiktoken' (the core Rust library) and 'py-tiktoken' (the Python extension module, which depends on it). Separating the Rust and Python modules makes the organization of the codebase clearer and the code easier to maintain. The setup.py is also updated to reflect these changes in the directory structure.

Refs: #24
  • Loading branch information
Miuler committed Sep 15, 2023
commit 6b7624aa9ca7f042a2bdb947214070d8005f5be7
36 changes: 34 additions & 2 deletions .github/workflows/build_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ jobs:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64)
build_wheels_aarch64_glibc:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/glibc)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -52,6 +52,38 @@ jobs:
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*musllinux*"
CIBW_BUILD_VERBOSITY: 3
# https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
- uses: actions/upload-artifact@v3
with:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64_musl:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/musl)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [38, 39, 310, 311]

steps:
- uses: actions/checkout@v3

- name: Setup up QEMU
uses: docker/setup-qemu-action@v2
with:
platforms: arm64

- name: Build wheels
uses: pypa/[email protected]
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*manylinux*"
CIBW_BUILD_VERBOSITY: 3
# https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
Expand Down
27 changes: 9 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
[package]
name = "tiktoken"
[workspace]
resolver = "2"
members = [
"rs-tiktoken",
"py-tiktoken",
]

[workspace.package]
version = "0.5.1"
edition = "2021"
rust-version = "1.57.0"

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"

[profile.release]
incremental = true
incremental = true
6 changes: 5 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ include *.svg
include *.toml
include *.md
include Makefile
include py-tiktoken/*.toml
include rs-tiktoken/*.toml
include rs-tiktoken/tests/gpt2_encoder
global-include py.typed
recursive-include scripts *.py
recursive-include tests *.py
recursive-include src *.rs
recursive-include py-tiktoken *.rs
recursive-include rs-tiktoken *.rs
20 changes: 20 additions & 0 deletions py-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Python extension crate: compiles to the `_tiktoken` cdylib and wraps the
# core `tiktoken` crate located in ../rs-tiktoken.
[package]
name = "py-tiktoken"
# Inherited from `[workspace.package]` in the workspace root manifest.
version.workspace = true
edition = "2021"
# `version.workspace = true` (workspace inheritance) is only understood by
# Cargo 1.64 and newer, so the declared MSRV cannot be lower than that.
rust-version = "1.64.0"

[lib]
# Library name of the compiled artifact (`_tiktoken`).
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
# Core tokenizer implementation shared with pure-Rust consumers.
tiktoken = { path = "../rs-tiktoken" }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
once_cell = "1.18.0"

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
1 change: 1 addition & 0 deletions py-tiktoken/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod tiktoken_py;
75 changes: 7 additions & 68 deletions src/tiktoken_py.rs → py-tiktoken/src/tiktoken_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

use std::collections::HashSet;

use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
use tiktoken::core::{byte_pair_encode, CoreBPE};

#[pyclass]
pub struct PyCoreBPE {
Expand All @@ -26,47 +25,10 @@ impl PyCoreBPE {
special_tokens_encoder: HashMap<String, usize>,
pattern: &str,
) -> PyResult<Self> {
let regex = Regex::new(pattern)
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

let special_regex = {
let _parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&_parts.join("|"))
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
};

let decoder: HashMap<usize, Vec<u8>> =
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

assert!(
encoder.len() == decoder.len(),
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
);

let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();

// Clone because I don't know how to tell Rust I'm not going to change the map
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();

let core_bpe = CoreBPE {
encoder,
special_tokens_encoder,
decoder,
special_tokens_decoder,
regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
special_regex_tls: (0..MAX_NUM_THREADS)
.map(|_| special_regex.clone())
.collect(),
sorted_token_bytes,
};
Ok(PyCoreBPE { core_bpe })
// Delegate construction to the core crate's `CoreBPE::new` and surface any
// error it reports to Python as a `ValueError`. (No debug printing here: the
// encoder map can be very large and stdout belongs to the host application.)
CoreBPE::new(encoder, special_tokens_encoder, pattern)
    .map(|core_bpe| PyCoreBPE { core_bpe })
    .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
}

// ====================
Expand All @@ -82,30 +44,7 @@ impl PyCoreBPE {
}

fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
py.allow_threads(|| {
match std::str::from_utf8(bytes) {
Ok(text) => self.core_bpe._encode_ordinary_native(text),
Err(e) => {
let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
let (mut tokens, last_piece_token_len) =
self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
if !tokens.is_empty() && last_piece_token_len > 0 {
// Lop off the tokens from the last piece and run BPE on the remaining bytes
// Somewhat niche, but this may not be correct if we'd have had a regex
// split between the valid UTF-8 and the invalid bytes, which is why this
// method is private
let mut unstable_bytes =
self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
}
tokens
}
}
})
py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
}

fn encode_with_unstable(
Expand Down Expand Up @@ -181,7 +120,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
mod tests {
use rustc_hash::FxHashMap as HashMap;

use crate::tiktoken::byte_pair_split;
use tiktoken::core::byte_pair_split;

#[test]
fn very_simple_test() {
Expand Down
13 changes: 13 additions & 0 deletions rs-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Core tokenizer crate, published under the package name `tiktoken`.
[package]
name = "tiktoken"
# Inherited from `[workspace.package]` in the workspace root manifest.
version.workspace = true
edition = "2021"
# `version.workspace = true` (workspace inheritance) is only understood by
# Cargo 1.64 and newer, so the declared MSRV cannot be lower than that.
rust-version = "1.64.0"

[dependencies]
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
once_cell = "1.18.0"
parse-display = "0.8.2"