Fix repeat for dims > 1 (tracel-ai#1713)

louisfd committed May 1, 2024
1 parent 3a02a54 commit 2e4c82f
Showing 12 changed files with 136 additions and 90 deletions.
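
In short, `Tensor::repeat(dim, times)` used to panic unless the repeated dimension had size 1, because every implementation set the output size along `dim` to `times` instead of scaling it. After this commit the dimension may have any size and the result concatenates `times` copies along `dim`. A minimal sketch of the new behavior, in the style of the test suite at the bottom of the diff:

    // Before this commit: panics with "Can only repeat dimension with dim=1".
    // After: concatenates `times` copies of the tensor along `dim`.
    let data = Data::from([[1, 2], [3, 4]]);
    let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());

    let repeated = tensor.repeat(1, 2); // shape [2, 2] -> [2, 4]

    assert_eq!(
        Data::from([[1, 2, 1, 2], [3, 4, 3, 4]]),
        repeated.into_data()
    );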
2 changes: 1 addition & 1 deletion crates/burn-fusion/src/ops/boolean.rs
@@ -564,7 +564,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {

         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);

         let desc = RepeatOperationDescription {
2 changes: 1 addition & 1 deletion crates/burn-fusion/src/ops/float.rs
@@ -1620,7 +1620,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {

         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);

         let desc = RepeatOperationDescription {
2 changes: 1 addition & 1 deletion crates/burn-fusion/src/ops/int.rs
@@ -1665,7 +1665,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);

         let desc = RepeatOperationDescription {
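
The same one-line fix lands in boolean.rs, float.rs, and int.rs above: the fused repeat description previously overwrote the output size along `dim` with `times`, which is only correct when that dimension already has size 1. A hypothetical helper (not part of the diff) showing the corrected shape arithmetic:

    // Output-shape rule for repeat along `dim`.
    // Old: shape[dim] = times  -> [4, 2, 2] repeated 3x along dim 2 gave [4, 2, 3] (wrong)
    // New: shape[dim] *= times -> [4, 2, 2] repeated 3x along dim 2 gives [4, 2, 6]
    fn repeat_output_shape(mut shape: Vec<usize>, dim: usize, times: usize) -> Vec<usize> {
        shape[dim] *= times;
        shape
    }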
25 changes: 12 additions & 13 deletions crates/burn-jit/src/kernel/index/repeat.rs
@@ -38,19 +38,21 @@ impl RepeatComputeShader {

         let stride_input = scope.create_local(Elem::UInt);
         let stride_output = scope.create_local(Elem::UInt);
-        let shape_output = scope.create_local(Elem::UInt);
+        let shape = scope.create_local(Elem::UInt);

         for i in 0..self.rank {
+            gpu!(scope, stride_input = stride(input, i));
+            gpu!(scope, stride_output = stride(output, i));
             if i != self.dim {
-                gpu!(scope, stride_input = stride(input, i));
-                gpu!(scope, stride_output = stride(output, i));
-                gpu!(scope, shape_output = shape(output, i));
-
-                gpu!(scope, offset_local = id / stride_output);
-                gpu!(scope, offset_local = offset_local % shape_output);
-                gpu!(scope, offset_local = offset_local * stride_input);
-                gpu!(scope, offset_input += offset_local);
+                gpu!(scope, shape = shape(output, i));
+            } else {
+                gpu!(scope, shape = shape(input, i));
             }
+
+            gpu!(scope, offset_local = id / stride_output);
+            gpu!(scope, offset_local = offset_local % shape);
+            gpu!(scope, offset_local = offset_local * stride_input);
+            gpu!(scope, offset_input += offset_local);
         }

         let result = scope.create_local(input.item());
@@ -108,12 +110,9 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
     times: usize,
 ) -> JitTensor<R, E, D1> {
     let mut shape = input.shape.clone();
-    if shape.dims[dim] != 1 {
-        panic!("Can only repeat dimension with dim=1");
-    }

     // Create output handle
-    shape.dims[dim] = times;
+    shape.dims[dim] *= times;
     let num_elems_output = shape.num_elements();
     let handle = input
         .client
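
The reworked shader now computes a source offset for every dimension, not just the non-repeated ones. On the repeated dimension the output coordinate is reduced modulo the input's size, so each copy reads the same source elements. A CPU-side sketch of the mapping the generated kernel performs (hypothetical helper, not part of the kernel):

    // For output element `id`, recover the linear input offset it copies from.
    fn input_offset(
        id: usize,
        strides_in: &[usize],
        strides_out: &[usize],
        shape_in: &[usize],
        shape_out: &[usize],
        dim: usize,
    ) -> usize {
        let mut offset_input = 0;
        for i in 0..strides_out.len() {
            // Wrap by the output size everywhere except the repeated
            // dimension, where wrapping by the input size folds all
            // copies back onto the same source elements.
            let shape = if i != dim { shape_out[i] } else { shape_in[i] };
            let coord = (id / strides_out[i]) % shape;
            offset_input += coord * strides_in[i];
        }
        offset_input
    }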
4 changes: 0 additions & 4 deletions crates/burn-tensor/src/tensor/api/base.rs
@@ -564,10 +564,6 @@ where
     }

     /// Repeat the tensor along the given dimension.
-    ///
-    /// # Panics
-    ///
-    /// If the selected dimension more than one item.
     pub fn repeat(self, dim: usize, times: usize) -> Self {
         Self::new(K::repeat(self.primitive, dim, times))
     }
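
The removed "# Panics" section documented the old size-1 restriction. Behavior on size-1 dimensions is unchanged, so existing callers keep working; a sketch of both cases, mirroring the test suite's style:

    // Size-1 dim (the only case supported before): [1, 3] -> [4, 3].
    let data = Data::from([[0.0, 1.0, 2.0]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
    let stacked = tensor.clone().repeat(0, 4);

    // Dim of size 3 (newly supported): [1, 3] -> [1, 6].
    let tiled = tensor.repeat(1, 2); // [[0.0, 1.0, 2.0, 0.0, 1.0, 2.0]]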
33 changes: 10 additions & 23 deletions crates/burn-tensor/src/tensor/ops/bool_tensor.rs
@@ -1,4 +1,7 @@
-use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor};
+use super::{
+    cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
+    IntTensor,
+};
 use crate::{
     backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor,
 };
@@ -174,28 +177,12 @@ pub trait BoolTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> BoolTensor<B, D> {
-        let mut shape = Self::bool_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let ranges_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
-        for i in 0..times {
-            let mut ranges = ranges_select_all.clone();
-            ranges[dim] = i..i + 1;
-            tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Bool>(
+            Tensor::<B, D, Bool>::from_primitive(tensor),
+            dim,
+            times,
+        )
+        .into_primitive()
     }

     /// Concatenates the tensors along the given dimension.
29 changes: 7 additions & 22 deletions crates/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -1,4 +1,5 @@
 use super::cat::cat_with_slice_assign;
+use super::repeat::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
 use crate::Tensor;
 use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
@@ -270,28 +271,12 @@ pub trait IntTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> IntTensor<B, D> {
-        let mut shape = Self::int_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let indices_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor));
-        for i in 0..times {
-            let mut indices = indices_select_all.clone();
-            indices[dim] = i..i + 1;
-            tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Int>(
+            Tensor::<B, D, Int>::from_primitive(tensor),
+            dim,
+            times,
+        )
+        .into_primitive()
     }

     /// Concatenates the given tensors along the given dimension.
4 changes: 1 addition & 3 deletions crates/burn-tensor/src/tensor/ops/modules/cat.rs
@@ -19,10 +19,8 @@ pub(crate) fn cat_with_slice_assign<B: Backend, const D: usize, K: TensorKind<B>

     let mut i = 0;
     let indices_select_all = [0; D].map(|_| {
-        let start = 0;
-        let end = shape.dims[i];
         i += 1;
-        start..end
+        0..shape.dims[i - 1]
     });

     let mut output_index = 0;
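
This tightens the range-building closure without changing behavior: `array::map` visits elements in order, so after `i += 1` the current dimension index is `i - 1`. A standalone sketch of the idiom with a hypothetical shape:

    // Builds one select-all range per dimension of `dims`.
    let dims = [4usize, 2, 3];
    let mut i = 0;
    let select_all = [0usize; 3].map(|_| {
        i += 1;
        0..dims[i - 1]
    });
    assert_eq!(select_all, [0..4, 0..2, 0..3]);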
2 changes: 2 additions & 0 deletions crates/burn-tensor/src/tensor/ops/modules/mod.rs
@@ -3,6 +3,8 @@ pub mod conv;

 /// Module with cat operation
 pub(crate) mod cat;
+/// Module with repeat operation
+pub(crate) mod repeat;
 /// Module with unfold operations.
 pub(crate) mod unfold;

36 changes: 36 additions & 0 deletions crates/burn-tensor/src/tensor/ops/modules/repeat.rs
@@ -0,0 +1,36 @@
+use crate::{backend::Backend, BasicOps, Tensor, TensorKind};
+
+pub(crate) fn repeat_with_slice_assign<
+    B: Backend,
+    const D: usize,
+    K: TensorKind<B> + BasicOps<B>,
+>(
+    tensor: Tensor<B, D, K>,
+    dim: usize,
+    times: usize,
+) -> Tensor<B, D, K> {
+    let mut shape = tensor.shape();
+    let device = tensor.device();
+
+    let original_dim_length = shape.dims[dim];
+    shape.dims[dim] *= times;
+
+    let mut tensor_output = Tensor::empty(shape.clone(), &device);
+
+    let mut i = 0;
+    let indices_select_all = [0; D].map(|_| {
+        i += 1;
+        0..shape.dims[i - 1]
+    });
+
+    let mut output_index = 0;
+    for _ in 0..times {
+        let mut indices = indices_select_all.clone();
+        indices[dim] = output_index..output_index + original_dim_length;
+        output_index += original_dim_length;
+
+        tensor_output = tensor_output.slice_assign(indices, tensor.clone());
+    }
+
+    tensor_output
+}
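
Concretely, each pass of the loop writes one copy of the source into the next `original_dim_length`-wide slice of the output. A standalone trace of the slice bookkeeping with hypothetical values (repeating dim 1 of a [2, 3] tensor twice, so the output is [2, 6]):

    let (dim, times, original_dim_length) = (1usize, 2usize, 3usize);
    let mut output_index = 0;
    let mut passes = Vec::new();
    for _ in 0..times {
        // Select-all ranges on the [2, 6] output, then narrow `dim`.
        let mut indices = [0..2, 0..6];
        indices[dim] = output_index..output_index + original_dim_length;
        output_index += original_dim_length;
        passes.push(indices);
    }
    assert_eq!(passes, vec![[0..2, 0..3], [0..2, 3..6]]);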
25 changes: 3 additions & 22 deletions crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -1,4 +1,5 @@
 use super::cat::cat_with_slice_assign;
+use super::repeat::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
 use crate::backend::BackendBridge;
 use crate::Tensor;
@@ -193,28 +194,8 @@ pub trait FloatTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> FloatTensor<B, D> {
-        let mut shape = B::float_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let indices_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = B::float_empty(shape, &B::float_device(&tensor));
-        for i in 0..times {
-            let mut indices = indices_select_all.clone();
-            indices[dim] = i..i + 1;
-            tensor_output = B::float_slice_assign(tensor_output, indices, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Float>(Tensor::<B, D>::from_primitive(tensor), dim, times)
+            .into_primitive()
     }

     /// Adds two tensors together.
62 changes: 62 additions & 0 deletions crates/burn-tensor/src/tests/ops/repeat.rs
@@ -45,4 +45,66 @@ mod tests {
         let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn should_support_float_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[1.0, 2.0], [3.0, 4.0]],
+            [[5.0, 6.0], [7.0, 8.0]],
+            [[9.0, 10.0], [11.0, 12.0]],
+            [[13.0, 14.0], [15.0, 16.0]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(2, 2).into_data();
+
+        let data_expected = Data::from([
+            [[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
+            [[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
+            [[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
+            [[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_int_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[1, 2], [3, 4]],
+            [[5, 6], [7, 8]],
+            [[9, 10], [11, 12]],
+            [[13, 14], [15, 16]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(2, 3).into_data();
+
+        let data_expected = Data::from([
+            [[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
+            [[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
+            [[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
+            [[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_bool_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[false, true], [true, false]],
+            [[true, true], [false, false]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(1, 2).into_data();
+
+        let data_expected = Data::from([
+            [[false, true], [true, false], [false, true], [true, false]],
+            [[true, true], [false, false], [true, true], [false, false]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
 }
