-
Notifications
You must be signed in to change notification settings - Fork 392
/
matmul.rs
65 lines (51 loc) · 1.68 KB
/
matmul.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
/// Benchmark harness for batched matrix multiplication (`Tensor::matmul`).
///
/// `B` is the tensor backend under test and `D` the tensor rank.
/// The two shapes must be matmul-compatible (matching batch dims,
/// lhs inner dim == rhs outer dim) — this is not validated here;
/// an incompatible pair panics inside `matmul` at execution time.
///
/// `#[derive(new)]` generates `MatmulBenchmark::new(shape_lhs, shape_rhs, device)`.
#[derive(new)]
struct MatmulBenchmark<B: Backend, const D: usize> {
// Shape of the left-hand operand.
shape_lhs: Shape<D>,
// Shape of the right-hand operand.
shape_rhs: Shape<D>,
// Device the random inputs are allocated on and the matmul runs on.
device: B::Device,
}
/// [`Benchmark`] implementation: measures `lhs.matmul(rhs)` on freshly
/// sampled random tensors living on the configured device.
impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
    /// The prepared operand pair handed to each `execute` call.
    type Args = (Tensor<B, D>, Tensor<B, D>);

    /// Identifier used when persisting/reporting results.
    fn name(&self) -> String {
        String::from("matmul")
    }

    /// Backend-specific configuration label, if the backend provides one.
    fn backend_config_name(&self) -> Option<String> {
        B::config_name(&self.device)
    }

    /// Operand shapes, reported as plain dimension vectors.
    fn shapes(&self) -> Vec<Vec<usize>> {
        let lhs_dims: Vec<usize> = self.shape_lhs.dims.into();
        let rhs_dims: Vec<usize> = self.shape_rhs.dims.into();
        vec![lhs_dims, rhs_dims]
    }

    /// Number of timed repetitions per run.
    fn num_samples(&self) -> usize {
        10
    }

    /// The timed body. The clones are intentionally inside the timed
    /// region (kept as in the original measurement).
    fn execute(&self, (lhs, rhs): Self::Args) {
        lhs.clone().matmul(rhs.clone());
    }

    /// Untimed setup: sample both operands from the default distribution
    /// on the benchmark's device.
    fn prepare(&self) -> Self::Args {
        let left = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
        let right = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);
        (left, right)
    }

    /// Block until all queued device work has completed, so timings
    /// cover the actual computation on async backends.
    fn sync(&self) {
        B::sync(&self.device)
    }
}
/// Runs the matmul benchmark on `device` and persists the results.
///
/// * `device` — backend device to allocate on and execute against.
/// * `url` / `token` — forwarded to the persistence layer for optional
///   upload of the timing results.
///
/// Panics if saving the results fails (`unwrap` on `save`).
#[allow(dead_code)] // invoked indirectly through `bench_on_backend!` in `main`
fn bench<B: Backend>(device: &B::Device, url: Option<&str>, token: Option<&str>) {
    // Rank-3 problem: [batch, m, k] x [batch, k, n] -> [batch, m, n].
    const D: usize = 3;
    let (batch_size, m, k, n) = (3, 1024, 2048, 1024);
    let benchmark = MatmulBenchmark::<B, D>::new(
        [batch_size, m, k].into(),
        [batch_size, k, n].into(),
        device.clone(),
    );
    save::<B>(vec![run_benchmark(benchmark)], device, url, token).unwrap();
}
/// Entry point. The macro expands to a dispatcher that selects the
/// backend/device at compile time (via cargo features) and calls the
/// matching `bench::<B>` — presumably the `bench` function above; the
/// macro body is defined elsewhere, so confirm in `backend_comparison`.
fn main() {
backend_comparison::bench_on_backend!();
}