From e758fd43db42e62fe6ea36548cfe2b59e3900eba Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 18 Jun 2024 16:45:21 -0400 Subject: [PATCH] Fix: constant record loading (#1902) --- .../burn-core/src/module/param/primitive.rs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/crates/burn-core/src/module/param/primitive.rs b/crates/burn-core/src/module/param/primitive.rs index 69c3cb2f78..840732fd46 100644 --- a/crates/burn-core/src/module/param/primitive.rs +++ b/crates/burn-core/src/module/param/primitive.rs @@ -21,6 +21,12 @@ where } fn load_record(self, record: Self::Record) -> Self { + let is_constant = self.num_params() == 0; + + if is_constant { + return self; + } + self.zip(record) .map(|(module, record)| module.load_record(record)) } @@ -89,6 +95,14 @@ where } fn load_record(self, record: Self::Record) -> Self { + assert_eq!( + self.len(), + record.len(), + r#"[Load Record Error] The vec record does not the same length as the module. + Make sure you module initialization is compatible with the record being loaded. + "#, + ); + self.into_iter() .zip(record) .map(|(module, record)| module.load_record(record)) @@ -267,3 +281,28 @@ impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn dont_override_constant_module_when_loading_record() { + let module = Some(42); + + let record = Module::::into_record(module); + let loaded = Module::::load_record(module, record); + + assert_eq!(loaded, module); + } + #[test] + fn dont_override_constant_module_when_loading_none_record() { + let module = Some(42); + + let record = None; + let loaded = Module::::load_record(module, record); + + assert_eq!(loaded, module); + } +}