Skip to content

Commit

Permalink
improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
eisber committed Feb 22, 2023
1 parent 1d3f707 commit 97d8dea
Showing 1 changed file with 77 additions and 55 deletions.
132 changes: 77 additions & 55 deletions jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,62 @@ use jni::sys::{jarray, jlong};

use _tiktoken_core::{self, CoreBPENative};

#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) {
// First, we have to get the string out of Java. Check out the `strings`
// module for more info on how this works.
let model_name: String = env
.get_string(model_name)
.expect("Unable to get Java model name")
.into();

let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING
.get(&model_name)
.expect("Unable to find model");
use jni::errors::Error;

// TODO: this is actually mergeable_ranks (lazy)
let mut encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name)
.expect("Unable to find encoding");

// TODO: initialize the CoreBPE object
/// Turns a JNI `Result` into a plain value, reporting failures to Java as exceptions.
///
/// * If a Java exception is already pending on `env`, nothing new is thrown and
///   `default` is returned (even when `result` is `Ok`), so the pending
///   exception is what the Java caller observes.
/// * `Ok(value)` yields `value`.
/// * `Err(error)` throws a `java.lang.Exception` whose message is the error's
///   `Display` text, then yields `default`.
///
/// This helper never panics: it is called from `extern "system"` entry points,
/// where a Rust panic must not reach the JNI boundary.
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
    // Throwing while another exception is pending is not allowed, so leave the
    // pending one in place. If the check itself fails we cannot tell whether it
    // is safe to throw, so take the same conservative path.
    if env.exception_check().unwrap_or(true) {
        return default;
    }

    match result {
        Ok(value) => value,
        Err(error) => {
            // `throw_new` resolves the class name itself, so no separate
            // `find_class` call (and no extra failure point) is needed.
            if let Err(throw_error) = env.throw_new("java/lang/Exception", error.to_string()) {
                // Last resort: the JVM could not raise the exception. Record both
                // errors on stderr rather than aborting the process with a panic.
                eprintln!(
                    "unable to throw Java exception for `{}`: {}",
                    error, throw_error
                );
            }
            default
        }
    }
}

let bpe_native = CoreBPENative::new(
encoding.get().unwrap(),
encoding.special_tokens.clone(),
&encoding.pat_str,
)
.unwrap();
#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) {
let result = || -> Result<(), Error> {
// First, we have to get the string out of Java. Check out the `strings`
// module for more info on how this works.
let model_name: String = env
.get_string(model_name)?
.into();

let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING
.get(&model_name)
.expect("Unable to find model");

// TODO: this is actually mergable_ranks (lazy)
let mut encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name)
.expect("Unable to find encoding");

// TODO: initialize the CoreBPE object

// TODO: this should be CoreBPE

let bpe_native = CoreBPENative::new(
encoding.get().unwrap(),
encoding.special_tokens.clone(),
&encoding.pat_str,
)
.unwrap();

unsafe {
env.set_rust_field(obj, "handle", bpe_native).unwrap();
}
Ok(unsafe {
env.set_rust_field(obj, "handle", bpe_native).unwrap();
})
}();

// env.set_field(obj, "handle", "J", jni::objects::JValue::Long(encoding_ptr)).expect("Unable to store handle");
unwrap_or_throw(&env, result, ())
}

#[no_mangle]
Expand All @@ -66,35 +88,35 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
allowedSpecialTokens: jarray,
maxTokenLength: jlong,
) -> jarray {
let encoding: MutexGuard<CoreBPENative> = unsafe { env.get_rust_field(obj, "handle").unwrap() };

let enc = encoding;
let input: String = env
.get_string(text)
.expect("Couldn't get java string!")
.into();

let len = env.get_array_length(allowedSpecialTokens).unwrap();
let mut strings: Vec<String> = Vec::with_capacity(len as usize);
for i in 0..len {
let element: JObject = env
.get_object_array_element(allowedSpecialTokens, i)
.unwrap();
let current: String = env.get_string(element.into()).unwrap().into();
strings.push(current);
}
let result = || -> Result<jarray, Error> {
let encoding: MutexGuard<CoreBPENative> = unsafe { env.get_rust_field(obj, "handle")? };

let v2: HashSet<&str> = strings.iter().map(|s| &**s).collect();
let enc = encoding;
let input: String = env
.get_string(text)?
.into();

let (tokens, _, _) = enc._encode_native(&input, &v2, Some(maxTokenLength as usize));
let len = env.get_array_length(allowedSpecialTokens)?;
let mut strings: Vec<String> = Vec::with_capacity(len as usize);
for i in 0..len {
let element: JObject = env
.get_object_array_element(allowedSpecialTokens, i)?;
let current: String = env.get_string(element.into())?.into();
strings.push(current);
}

let mut output = env
.new_long_array(tokens.len().try_into().unwrap())
.unwrap();
let v2: HashSet<&str> = strings.iter().map(|s| &**s).collect();

let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::<Vec<i64>>();
env.set_long_array_region(output, 0, array_of_u64.as_slice())
.unwrap();
let (tokens, _, _) = enc._encode_native(&input, &v2, Some(maxTokenLength as usize));

let mut output = env
.new_long_array(tokens.len().try_into().unwrap())?;

let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::<Vec<i64>>();
env.set_long_array_region(output, 0, array_of_u64.as_slice())?;

Ok(output)
}();

output
unwrap_or_throw(&env, result, JObject::null().into_raw())
}

0 comments on commit 97d8dea

Please sign in to comment.