Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: Split code rust, for run as python lib and rust lib #167

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions .github/workflows/build_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ jobs:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64)
build_wheels_aarch64_glibc:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/glibc)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -52,6 +52,38 @@ jobs:
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*musllinux*"
CIBW_BUILD_VERBOSITY: 3
# https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
- uses: actions/upload-artifact@v3
with:
name: dist
path: ./wheelhouse/*.whl

build_wheels_aarch64_musl:
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/musl)
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [38, 39, 310, 311]

steps:
- uses: actions/checkout@v3

- name: Setup up QEMU
uses: docker/setup-qemu-action@v2
with:
platforms: arm64

- name: Build wheels
uses: pypa/[email protected]
env:
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
CIBW_ARCHS: aarch64
CIBW_SKIP: "*manylinux*"
CIBW_BUILD_VERBOSITY: 3
# https://github.com/rust-lang/cargo/issues/10583
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
Expand Down
27 changes: 9 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
[package]
name = "tiktoken"
[workspace]
resolver = "2"
members = [
"rs-tiktoken",
"py-tiktoken",
]

[workspace.package]
version = "0.5.1"
edition = "2021"
rust-version = "1.57.0"

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }

# tiktoken dependencies
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"

[profile.release]
incremental = true
incremental = true
6 changes: 5 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ include *.svg
include *.toml
include *.md
include Makefile
include py-tiktoken/*.toml
include rs-tiktoken/*.toml
include rs-tiktoken/tests/gpt2_encoder
global-include py.typed
recursive-include scripts *.py
recursive-include tests *.py
recursive-include src *.rs
recursive-include py-tiktoken *.rs
recursive-include rs-tiktoken *.rs
20 changes: 20 additions & 0 deletions py-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Manifest for the Python extension crate: builds the `_tiktoken` cdylib on
# top of the pure-Rust `tiktoken` crate in ../rs-tiktoken.
[package]
name = "py-tiktoken"
# version / edition / rust-version are all declared once in the root
# [workspace.package] table; inherit them instead of repeating the values.
version.workspace = true
edition.workspace = true
rust-version.workspace = true

[lib]
name = "_tiktoken"
crate-type = ["cdylib"]

[dependencies]
tiktoken = { path = "../rs-tiktoken" }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
once_cell = "1.18.0"

# Dependencies shared with the core `tiktoken` crate.
# NOTE(review): src/tiktoken_py.rs only imports pyo3, rustc-hash and tiktoken
# directly; confirm whether once_cell, fancy-regex, regex and bstr are still
# needed here.
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
1 change: 1 addition & 0 deletions py-tiktoken/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// Crate root of the `_tiktoken` extension library: everything lives in the
// `tiktoken_py` module, which defines the PyO3 class and module initializer.
pub mod tiktoken_py;
134 changes: 134 additions & 0 deletions py-tiktoken/src/tiktoken_py.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// This check is new and seems buggy (possibly with PyO3 interaction)
#![allow(clippy::borrow_deref_ref)]

use std::collections::HashSet;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

use tiktoken::core::{byte_pair_encode, CoreBPE};

// Python-visible wrapper (via `#[pyclass]`) around the pure-Rust `CoreBPE`.
// The methods exposed to Python are defined in the `#[pymethods]` impl and
// forward to this inner value.
#[pyclass]
pub struct PyCoreBPE {
// The wrapped tokenizer from the `tiktoken` crate. `pub` makes it reachable
// from Rust code; it is not annotated for exposure as a Python attribute.
pub core_bpe: CoreBPE,
}


// Python-facing methods of `PyCoreBPE`. Each method is a thin adapter that
// converts arguments/results between Python and Rust and forwards to the
// wrapped `CoreBPE`. Plain `//` comments are used on purpose so the Python
// docstrings generated by PyO3 are left untouched.
#[pymethods]
impl PyCoreBPE {
    // Builds the wrapper by delegating to `CoreBPE::new`. Any construction
    // error is surfaced to Python as a `ValueError` carrying the error's
    // string form.
    #[new]
    fn new(
        encoder: HashMap<Vec<u8>, usize>,
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
        // No debug output here: dumping the whole `encoder` map to stdout on
        // every construction is slow and pollutes the host process's output.
        CoreBPE::new(encoder, special_tokens_encoder, pattern)
            .map(|core_bpe| PyCoreBPE { core_bpe })
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
    }

    // ====================
    // Encoding
    // ====================

    // Encodes `text` via `CoreBPE::_encode_ordinary_native`. The work runs
    // inside `allow_threads`, i.e. with the GIL released.
    fn encode_ordinary(&self, py: Python, text: &str) -> Vec<usize> {
        py.allow_threads(|| self.core_bpe._encode_ordinary_native(text))
    }

    // Encodes `text` via `CoreBPE::_encode_native`, passing `allowed_special`
    // through. Only the first element of the returned tuple (the tokens) is
    // handed back to Python.
    fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
        py.allow_threads(|| self.core_bpe._encode_native(text, &allowed_special).0)
    }

    // Encodes a raw byte string via `CoreBPE::_encode_bytes` with the GIL
    // released.
    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
        py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
    }

    // Calls `CoreBPE::_encode_unstable_native` and returns a Python tuple
    // `(tokens, completions)`, where `completions` is converted into a Python
    // list of lists.
    fn encode_with_unstable(
        &self,
        py: Python,
        text: &str,
        allowed_special: HashSet<&str>,
    ) -> Py<PyTuple> {
        let (tokens, completions) =
            py.allow_threads(|| self.core_bpe._encode_unstable_native(text, &allowed_special));
        // Python objects can only be created while holding the GIL, so the
        // conversion happens after `allow_threads` has returned.
        let py_completions =
            PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
        (tokens, py_completions).into_py(py)
    }

    // Looks up the token for exactly `piece`: first in the regular encoder,
    // then (if `piece` is valid UTF-8) in the special-token encoder.
    // Raises `KeyError` with the piece bytes as payload when neither matches.
    fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
        if let Some(token) = self.core_bpe.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.core_bpe.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
    }

    // Encodes a single piece: a direct encoder hit yields a one-element
    // vector, otherwise the piece is passed to `byte_pair_encode`.
    fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
        if let Some(token) = self.core_bpe.encoder.get(piece) {
            return vec![*token];
        }
        byte_pair_encode(piece, &self.core_bpe.encoder)
    }

    // ====================
    // Decoding
    // ====================

    // Decodes `tokens` via `CoreBPE::_decode_native` (GIL released) and wraps
    // the resulting bytes in a Python `bytes` object.
    fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
        let bytes = py.allow_threads(|| self.core_bpe._decode_native(&tokens));
        PyBytes::new(py, &bytes).into()
    }

    // Returns the bytes for one token, checking the regular decoder first and
    // the special-token decoder second. Raises `KeyError` (payload: the token
    // rendered as a string) when the token is in neither table.
    fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
        if let Some(bytes) = self.core_bpe.decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        if let Some(bytes) = self.core_bpe.special_tokens_decoder.get(&token) {
            return Ok(PyBytes::new(py, bytes).into());
        }
        Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
    }

    // ====================
    // Miscellaneous
    // ====================

    // Returns every entry of `sorted_token_bytes` as a Python `bytes` object,
    // in the order stored on the core.
    fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
        self.core_bpe
            .sorted_token_bytes
            .iter()
            .map(|x| PyBytes::new(py, x).into())
            .collect()
    }
}

// Initializer for the `_tiktoken` Python extension module: registers the
// `PyCoreBPE` class on the module and propagates any registration error.
#[pymodule]
pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyCoreBPE>()
}

#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use tiktoken::core::byte_pair_split;

    // Smoke test: with ranks defined for "ab" and "cd", splitting "abcd"
    // must yield exactly those two pieces, in order.
    #[test]
    fn very_simple_test() {
        let ranks: HashMap<Vec<u8>, _> = [(b"ab".to_vec(), 1), (b"cd".to_vec(), 2)]
            .into_iter()
            .collect();

        let pieces = byte_pair_split(b"abcd", &ranks);
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }
}
13 changes: 13 additions & 0 deletions rs-tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Manifest for the pure-Rust `tiktoken` library crate (no Python bindings).
[package]
name = "tiktoken"
# version / edition / rust-version are all declared once in the root
# [workspace.package] table; inherit them instead of repeating the values.
version.workspace = true
edition.workspace = true
rust-version.workspace = true

[dependencies]
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.5.0"
once_cell = "1.18.0"
parse-display = "0.8.2"
Loading