Skip to content

Commit

Permalink
Fix: constant record loading (tracel-ai#1902)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jun 18, 2024
1 parent 263add2 commit e758fd4
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions crates/burn-core/src/module/param/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ where
}

fn load_record(self, record: Self::Record) -> Self {
let is_constant = self.num_params() == 0;

if is_constant {
return self;
}

self.zip(record)
.map(|(module, record)| module.load_record(record))
}
Expand Down Expand Up @@ -89,6 +95,14 @@ where
}

fn load_record(self, record: Self::Record) -> Self {
assert_eq!(
self.len(),
record.len(),
r#"[Load Record Error] The vec record does not the same length as the module.
Make sure you module initialization is compatible with the record being loaded.
"#,
);

self.into_iter()
.zip(record)
.map(|(module, record)| module.load_record(record))
Expand Down Expand Up @@ -267,3 +281,28 @@ impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);
impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;

#[test]
fn dont_override_constant_module_when_loading_record() {
let module = Some(42);

let record = Module::<TestBackend>::into_record(module);
let loaded = Module::<TestBackend>::load_record(module, record);

assert_eq!(loaded, module);
}
#[test]
fn dont_override_constant_module_when_loading_none_record() {
let module = Some(42);

let record = None;
let loaded = Module::<TestBackend>::load_record(module, record);

assert_eq!(loaded, module);
}
}

0 comments on commit e758fd4

Please sign in to comment.