Skip to content

Commit

Permalink
Revert "Revert "Implement 3D and transposed 3D convolutions. (#1945)""
Browse files Browse the repository at this point in the history
This reverts commit b8b47ea.
  • Loading branch information
nathanielsimard authored and syl20bnr committed Jul 5, 2024
1 parent 0928a52 commit 882a27c
Show file tree
Hide file tree
Showing 53 changed files with 5,890 additions and 227 deletions.
9 changes: 9 additions & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,19 @@ name = "conv-transpose2d"
path = "benches/conv_transpose2d.rs"
harness = false

[[bench]]
name = "conv-transpose3d"
path = "benches/conv_transpose3d.rs"
harness = false

[[bench]]
name = "conv2d"
harness = false

[[bench]]
name = "conv3d"
harness = false

[[bench]]
name = "matmul"
harness = false
Expand Down
109 changes: 109 additions & 0 deletions backend-comparison/benches/conv3d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv3d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct Conv3dBenchmark<B: Backend> {
input_shape: Shape<5>,
weight_shape: Shape<5>,
bias_shape: Shape<1>,
options: ConvOptions<3>,
device: B::Device,
}

impl<B: Backend> Benchmark for Conv3dBenchmark<B> {
type Args = (Tensor<B, 5>, Tensor<B, 5>, Tensor<B, 1>);

fn name(&self) -> String {
"conv3d".into()
}

fn shapes(&self) -> Vec<Vec<usize>> {
vec![
self.input_shape.dims.into(),
self.weight_shape.dims.into(),
self.bias_shape.dims.into(),
]
}

fn execute(&self, (x, w, b): Self::Args) {
conv3d(x, w, Some(b), self.options.clone());
}

fn prepare(&self) -> Self::Args {
(
Tensor::random(
self.input_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(
self.weight_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(self.bias_shape.clone(), Distribution::Default, &self.device),
)
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
}
}

#[allow(dead_code)]
fn bench<B: Backend>(
device: &B::Device,
feature_name: &str,
url: Option<&str>,
token: Option<&str>,
) {
// Shapes
let batch_size = 16;
let channels_in = 16;
let channels_out = 16;
let depth_in = 16;
let height_in = 128;
let width_in = 128;
let kernel_size_0 = 3;
let kernel_size_1 = 3;
let kernel_size_2 = 3;

// Options
let strides = [1, 1, 1];
let padding = [0, 0, 0];
let dilations = [1, 1, 1];
let groups = 1;
let options = ConvOptions::new(strides, padding, dilations, groups);
let benchmark = Conv3dBenchmark::<B> {
input_shape: [batch_size, channels_in, depth_in, height_in, width_in].into(),
weight_shape: [
channels_in,
channels_out / groups,
kernel_size_0,
kernel_size_1,
kernel_size_2,
]
.into(),
bias_shape: [channels_out].into(),
options,
device: device.clone(),
};

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
}

fn main() {
backend_comparison::bench_on_backend!();
}
111 changes: 111 additions & 0 deletions backend-comparison/benches/conv_transpose3d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv_transpose3d, ops::ConvTransposeOptions, Distribution, Shape,
Tensor,
};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
};

pub struct ConvTranspose3dBenchmark<B: Backend> {
input_shape: Shape<5>,
weight_shape: Shape<5>,
bias_shape: Shape<1>,
options: ConvTransposeOptions<3>,
device: B::Device,
}

impl<B: Backend> Benchmark for ConvTranspose3dBenchmark<B> {
type Args = (Tensor<B, 5>, Tensor<B, 5>, Tensor<B, 1>);

fn name(&self) -> String {
"conv_transpose3d".into()
}

fn shapes(&self) -> Vec<Vec<usize>> {
vec![
self.input_shape.dims.into(),
self.weight_shape.dims.into(),
self.bias_shape.dims.into(),
]
}

fn execute(&self, (x, w, b): Self::Args) {
conv_transpose3d(x, w, Some(b), self.options.clone());
}

fn prepare(&self) -> Self::Args {
(
Tensor::random(
self.input_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(
self.weight_shape.clone(),
Distribution::Default,
&self.device,
),
Tensor::random(self.bias_shape.clone(), Distribution::Default, &self.device),
)
}

fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
}
}

#[allow(dead_code)]
fn bench<B: Backend>(
device: &B::Device,
feature_name: &str,
url: Option<&str>,
token: Option<&str>,
) {
// Shapes
let batch_size = 16;
let channels_in = 16;
let channels_out = 16;
let depth_in = 4;
let height_in = 16;
let width_in = 16;
let kernel_size_0 = 8;
let kernel_size_1 = 8;
let kernel_size_2 = 8;

// Options
let strides = [1, 1, 1];
let padding = [0, 0, 0];
let padding_out = [0, 0, 0];
let dilations = [1, 1, 1];
let groups = 1;
let options = ConvTransposeOptions::new(strides, padding, padding_out, dilations, groups);
let benchmark = ConvTranspose3dBenchmark::<B> {
input_shape: [batch_size, channels_in, depth_in, height_in, width_in].into(),
weight_shape: [
channels_in,
channels_out / groups,
kernel_size_0,
kernel_size_1,
kernel_size_2,
]
.into(),
bias_shape: [channels_out].into(),
options,
device: device.clone(),
};

save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
}

fn main() {
backend_comparison::bench_on_backend!();
}
4 changes: 4 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ enum BenchmarkValues {
Autodiff,
#[strum(to_string = "conv-transpose2d")]
ConvTranspose2d,
#[strum(to_string = "conv-transpose3d")]
ConvTranspose3d,
#[strum(to_string = "conv2d")]
Conv2d,
#[strum(to_string = "conv3d")]
Conv3d,
}

pub fn execute() {
Expand Down
2 changes: 2 additions & 0 deletions burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ Burn comes with built-in modules that you can use to build your own modules.
| ----------------- | -------------------- |
| `Conv1d` | `nn.Conv1d` |
| `Conv2d` | `nn.Conv2d` |
| `Conv3d` | `nn.Conv3d` |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
| `ConvTranspose2d` | `nn.ConvTranspose2d` |
| `ConvTranspose3d` | `nn.ConvTranspose3d` |

### Pooling

Expand Down
Loading

0 comments on commit 882a27c

Please sign in to comment.