Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle visibility in cube #1929

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Handle visibility in cube
  • Loading branch information
nathanielsimard committed Jun 25, 2024
commit a426f3e86096b49d50e24862cf0d3e9fd5cfb2f9
26 changes: 19 additions & 7 deletions crates/burn-cube-macros/src/codegen_type/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ struct TypeCodegen {
name_expand: syn::Ident,
fields: Vec<syn::Field>,
generics: GenericsCodegen,
vis: syn::Visibility,
}

impl TypeCodegen {
pub fn expand_ty(&self) -> proc_macro2::TokenStream {
let mut fields = quote::quote! {};
Expand All @@ -19,17 +21,19 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

fields.extend(quote! {
#ident: <#ty as CubeType>::ExpandType,
#vis #ident: <#ty as CubeType>::ExpandType,
});
}

let generics = self.generics.type_definitions();
let vis = &self.vis;

quote! {
#[derive(Clone)]
struct #name #generics {
#vis struct #name #generics {
#fields
}
}
Expand All @@ -42,9 +46,10 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

fields.extend(quote! {
#ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
#vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
});
}

Expand All @@ -65,9 +70,10 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

args.extend(quote! {
#ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
#vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
});
fields.extend(quote! {
#ident,
Expand All @@ -76,11 +82,12 @@ impl TypeCodegen {

let generics_impl = self.generics.all_definitions();
let generics_use = self.generics.all_in_use();
let vis = &self.vis;

quote! {
impl #generics_impl #name #generics_use {
/// New kernel
pub fn new(#args) -> Self {
#vis fn new(#args) -> Self {
Self {
#fields
}
Expand Down Expand Up @@ -137,12 +144,13 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

body_input.extend(quote! {
#ident: <#ty as LaunchArg>::compile_input(builder, vectorization),
#vis #ident: <#ty as LaunchArg>::compile_input(builder, vectorization),
});
body_output.extend(quote! {
#ident: <#ty as LaunchArg>::compile_output(builder, vectorization),
#vis #ident: <#ty as LaunchArg>::compile_output(builder, vectorization),
});
}

Expand Down Expand Up @@ -194,9 +202,12 @@ impl TypeCodegen {
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
let name = ast.ident.clone();
let generics = ast.generics.clone();
let visibility = ast.vis.clone();

let name_string = name.to_string();
let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span());
let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span());

let mut fields = Vec::new();

match &ast.data {
Expand All @@ -215,6 +226,7 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> T
name_expand,
fields,
generics: GenericsCodegen::new(generics),
vis: visibility,
};

let expand_ty = codegen.expand_ty();
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-cube-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn codegen_cube(
func: &syn::ItemFn,
variable_tracker: &mut VariableTracker,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
let signature = expand_sig(&func.sig, variable_tracker);
let signature = expand_sig(&func.sig, &func.vis, variable_tracker);
let mut body = quote::quote! {};

for statement in func.block.stmts.iter() {
Expand Down Expand Up @@ -148,6 +148,7 @@ fn codegen_cube(

fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
variable_tracker: &mut VariableTracker,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();
Expand Down Expand Up @@ -188,6 +189,6 @@ fn expand_sig(

quote::quote! {
/// Expanded Cube function
pub fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
}
}
Loading