Skip to content

Commit

Permalink
Fix: launch without generics (#1932)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jun 26, 2024
1 parent 4c90970 commit d772a1c
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
12 changes: 9 additions & 3 deletions crates/burn-cube-macros/src/codegen_function/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ impl Codegen {
})
}

let generics = self.generics.split_for_impl().1;

let mut format_str = "{:?}-{}".to_string();
for _ in 0..self.state_comptimes.len() {
format_str.push_str("-{:?}");
Expand All @@ -166,14 +164,22 @@ impl Codegen {
format_args.extend(quote::quote! { self.#ident, });
}

let expand_func = match self.generics.params.is_empty() {
true => quote::quote! { #expand },
false => {
let generics = self.generics.split_for_impl().1;
quote::quote! { #expand::#generics }
}
};

quote::quote! {
impl #impl_gen Kernel for #ident #ty_gen #where_gen {
fn define(&self) -> KernelDefinition {
let mut builder = KernelBuilder::default();

#variables

#expand::#generics(#expand_args);
#expand_func(#expand_args);

builder.build(self.settings.clone())
}
Expand Down
68 changes: 68 additions & 0 deletions crates/burn-cube/src/runtime_tests/launch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use crate as burn_cube;
use burn_cube::prelude::*;

#[cube(launch)]
pub fn kernel_with_generics<F: Float>(mut output: Array<F>) {
if UNIT_POS == UInt::new(0) {
output[0] = F::new(5.0);
}
}

#[cube(launch)]
pub fn kernel_without_generics(mut output: Array<F32>) {
if UNIT_POS == UInt::new(0) {
output[0] = F32::new(5.0);
}
}

pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

kernel_with_generics_launch::<F32, R>(
client.clone(),
CubeCount::new(1, 1, 1),
KernelSettings::default(),
ArrayHandle::new(&handle, 2),
);

let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 5.0);
}

pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

kernel_without_generics_launch::<R>(
client.clone(),
CubeCount::new(1, 1, 1),
KernelSettings::default(),
ArrayHandle::new(&handle, 2),
);

let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 5.0);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_launch {
() => {
use super::*;

#[test]
fn test_launch_with_generics() {
let client = TestRuntime::client(&Default::default());
burn_cube::runtime_tests::launch::test_kernel_with_generics::<TestRuntime>(client);
}

#[test]
fn test_launch_without_generics() {
let client = TestRuntime::client(&Default::default());
burn_cube::runtime_tests::launch::test_kernel_without_generics::<TestRuntime>(client);
}
};
}
2 changes: 2 additions & 0 deletions crates/burn-cube/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod launch;
pub mod subcube;

#[allow(missing_docs)]
Expand All @@ -7,5 +8,6 @@ macro_rules! testgen_all {
use burn_cube::prelude::*;

burn_cube::testgen_subcube!();
burn_cube::testgen_launch!();
};
}

0 comments on commit d772a1c

Please sign in to comment.