Skip to content

Commit

Permalink
Improve error handling in JNI functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 24, 2023
1 parent 97d8dea commit 53628a4
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ use jni::sys::{jarray, jlong};

use _tiktoken_core::{self, CoreBPENative};

use jni::errors::Error;
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T>, default: T) -> T {
// Check if an exception is already thrown
if env.exception_check().unwrap() {
if env.exception_check().expect("exception_check() failed") {
return default;
}

Expand All @@ -28,45 +28,42 @@ fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
Err(error) => {
let exception_class = env
.find_class("java/lang/Exception")
.unwrap();
.expect("Unable to find exception class");
env.throw_new(exception_class, format!("{}", error))
.unwrap();
.expect("Unable to throw exception");
default
}
}
}

#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) {
let result = || -> Result<(), Error> {
let result = || -> Result<()> {
// First, we have to get the string out of Java. Check out the `strings`
// module for more info on how this works.
let model_name: String = env
.get_string(model_name)?
.into();

let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING
.get(&model_name)
.expect("Unable to find model");
.get(&model_name).ok_or("Unable to find model")?;

// TODO: this is actually mergable_ranks (lazy)
let mut encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name)
.expect("Unable to find encoding");
let encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name).ok_or("Unable to find encoding")?;

// TODO: initialize the CoreBPE object

// TODO: this should be CoreBPE

let bpe_native = CoreBPENative::new(
encoding.get().unwrap(),
encoding.get()?,
encoding.special_tokens.clone(),
&encoding.pat_str,
)
.unwrap();
)?;

Ok(unsafe {
env.set_rust_field(obj, "handle", bpe_native).unwrap();
env.set_rust_field(obj, "handle", bpe_native)?;
})
}();

Expand All @@ -76,7 +73,7 @@ pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, mo
#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_destroy(env: JNIEnv, obj: JObject) {
unsafe {
let _: CoreBPENative = env.take_rust_field(obj, "handle").unwrap();
let _: CoreBPENative = env.take_rust_field(obj, "handle").expect("Unable to get handle during destruction");
}
}

Expand All @@ -88,7 +85,7 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
allowedSpecialTokens: jarray,
maxTokenLength: jlong,
) -> jarray {
let result = || -> Result<jarray, Error> {
let result = || -> Result<jarray> {
let encoding: MutexGuard<CoreBPENative> = unsafe { env.get_rust_field(obj, "handle")? };

let enc = encoding;
Expand All @@ -109,8 +106,8 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(

let (tokens, _, _) = enc._encode_native(&input, &v2, Some(maxTokenLength as usize));

let mut output = env
.new_long_array(tokens.len().try_into().unwrap())?;
let output = env
.new_long_array(tokens.len().try_into()?)?;

let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::<Vec<i64>>();
env.set_long_array_region(output, 0, array_of_u64.as_slice())?;
Expand Down

0 comments on commit 53628a4

Please sign in to comment.