diff --git a/crates/burn-cube-macros/src/codegen_type/base.rs b/crates/burn-cube-macros/src/codegen_type/base.rs index 32a29b621c..ab7fbc6e34 100644 --- a/crates/burn-cube-macros/src/codegen_type/base.rs +++ b/crates/burn-cube-macros/src/codegen_type/base.rs @@ -10,7 +10,9 @@ struct TypeCodegen { name_expand: syn::Ident, fields: Vec, generics: GenericsCodegen, + vis: syn::Visibility, } + impl TypeCodegen { pub fn expand_ty(&self) -> proc_macro2::TokenStream { let mut fields = quote::quote! {}; @@ -19,17 +21,19 @@ impl TypeCodegen { for field in self.fields.iter() { let ident = &field.ident; let ty = &field.ty; + let vis = &field.vis; fields.extend(quote! { - #ident: <#ty as CubeType>::ExpandType, + #vis #ident: <#ty as CubeType>::ExpandType, }); } let generics = self.generics.type_definitions(); + let vis = &self.vis; quote! { #[derive(Clone)] - struct #name #generics { + #vis struct #name #generics { #fields } } @@ -42,9 +46,10 @@ impl TypeCodegen { for field in self.fields.iter() { let ident = &field.ident; let ty = &field.ty; + let vis = &field.vis; fields.extend(quote! { - #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, + #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, }); } @@ -65,9 +70,10 @@ impl TypeCodegen { for field in self.fields.iter() { let ident = &field.ident; let ty = &field.ty; + let vis = &field.vis; args.extend(quote! { - #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, + #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, }); fields.extend(quote! { #ident, @@ -76,11 +82,12 @@ impl TypeCodegen { let generics_impl = self.generics.all_definitions(); let generics_use = self.generics.all_in_use(); + let vis = &self.vis; quote! { impl #generics_impl #name #generics_use { /// New kernel - pub fn new(#args) -> Self { + #vis fn new(#args) -> Self { Self { #fields } @@ -137,12 +144,13 @@ impl TypeCodegen { for field in self.fields.iter() { let ident = &field.ident; let ty = &field.ty; + let vis = &field.vis; body_input.extend(quote! { - #ident: <#ty as LaunchArg>::compile_input(builder, vectorization), + #vis #ident: <#ty as LaunchArg>::compile_input(builder, vectorization), }); body_output.extend(quote! { - #ident: <#ty as LaunchArg>::compile_output(builder, vectorization), + #vis #ident: <#ty as LaunchArg>::compile_output(builder, vectorization), }); } @@ -194,9 +202,12 @@ impl TypeCodegen { pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream { let name = ast.ident.clone(); let generics = ast.generics.clone(); + let visibility = ast.vis.clone(); + let name_string = name.to_string(); let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span()); let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span()); + let mut fields = Vec::new(); match &ast.data { @@ -215,6 +226,7 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> T name_expand, fields, generics: GenericsCodegen::new(generics), + vis: visibility, }; let expand_ty = codegen.expand_ty(); diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 4adf137027..39038e7fb4 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -106,7 +106,7 @@ fn codegen_cube( func: &syn::ItemFn, variable_tracker: &mut VariableTracker, ) -> Result { - let signature = expand_sig(&func.sig, variable_tracker); + let signature = expand_sig(&func.sig, &func.vis, variable_tracker); let mut body = quote::quote! {}; for statement in func.block.stmts.iter() { @@ -148,6 +148,7 @@ fn codegen_cube( fn expand_sig( sig: &syn::Signature, + visibility: &syn::Visibility, variable_tracker: &mut VariableTracker, ) -> proc_macro2::TokenStream { let mut inputs = quote::quote!(); @@ -188,6 +189,6 @@ fn expand_sig( quote::quote! { /// Expanded Cube function - pub fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output + #visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output } }