Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is no longer generic over the element type, no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.
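For example, a hypothetical test sketch (not taken from this commit) now needs an explicit conversion of the expected values to the backend's element type:

```rust
use burn_tensor::{backend::Backend, Tensor, TensorData};

// Hypothetical sketch: TensorData carries no element type parameter, so the
// expected values must be converted to the backend's float element type
// before comparing against the tensor's data.
fn check_exp<B: Backend>(tensor: Tensor<B, 1>) {
    let output = tensor.exp();
    let expected = TensorData::from([1.0f32, 2.7183, 7.3891]).convert::<B::FloatElem>();
    output.into_data().assert_approx_eq(&expected, 3);
}
```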

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprintln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
laggui committed Jun 27, 2024
1 parent 1c7780a commit cdd1fa1
Showing 288 changed files with 4,800 additions and 3,697 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


48 changes: 45 additions & 3 deletions README.md
@@ -574,6 +574,48 @@ leads to more reliable, bug-free solutions built faster (after some practice

<br />

> **Deprecation Note**<br />Since `0.14.0`, the internal structure for tensor data has changed. The
> previous `Data` struct is being deprecated in favor of the new `TensorData` struct, which allows
> for more flexibility by storing the underlying data as bytes and keeping the data type as a field.
> If you are using `Data` in your code, make sure to switch to `TensorData`.
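>
> For illustration only (a minimal sketch, assuming a backend `B` and a `device` are already in
> scope), the switch might look like this:
>
> ```rust
> // Before (deprecated): the element type and rank were generic parameters.
> // let data = Data::<f32, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
>
> // After: TensorData stores the values as bytes and keeps the data type in a `dtype` field.
> let data = TensorData::from([[1.0f32, 2.0], [3.0, 4.0]]);
> let tensor = Tensor::<B, 2>::from_data(data, &device);
> ```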
<!-- >
> In the event that you are trying to load a model record saved in a previous version, make sure to
> enable the `record-backward-compat` feature. Otherwise, the record won't be deserialized correctly
> and you will get an error message (which will also point you to the backward compatible feature
> flag). The backward compatibility is maintained for deserialization (loading), so as soon as you
> have saved the record again it will be saved according to the new structure and you won't need the
> backward compatible feature flag anymore. Please note that binary formats are not backward
> compatible. Thus, you will need to load your record in a previous version and save it to another
> of the self-describing record formats before using the new version with the
> `record-backward-compat` feature flag. -->

<details id="deprecation">
<summary>
Loading Model Records From Previous Versions ⚠️
</summary>
<br />

In the event that you are trying to load a model record saved in a previous version, make sure to
enable the `record-backward-compat` feature flag.

```
features = [..., "record-backward-compat"]
```

Otherwise, the record won't be deserialized correctly and you will get an error message. This error
will also point you to the backward compatible feature flag.

The backward compatibility is maintained for deserialization (loading) only. Therefore, once you
have saved the record again, it will be written according to the new structure and you won't need
the backward compatible feature flag anymore.

Please note that binary formats are not backward compatible. Thus, you will need to load your record
in a previous version and save it to any of the other self-describing record formats (e.g., using the
`NamedMpkFileRecorder`) before using the new version with the `record-backward-compat` feature flag.
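
For illustration only (the config, backend, and file names below are hypothetical), loading an old
record with the feature flag enabled and re-saving it in the new structure could look roughly like
this:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};

// Load the old record (backward-compatible deserialization), then save it again
// so that it is written according to the new structure.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let model = config
    .init::<MyBackend>(&device)
    .load_file("old_model", &recorder, &device)?;
model.save_file("new_model", &recorder)?;
```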

</details>

## Community

<div align="left">
@@ -592,9 +634,9 @@ any background. You can ask your questions and share what you built with the com
Before contributing, please take a moment to review our
[code of conduct](https://github.com/tracel-ai/burn/tree/main/CODE-OF-CONDUCT.md). It's also highly
recommended to read the
[architecture overview](https://github.com/tracel-ai/burn/tree/main/contributor-book/src/project-architecture), which explains
some of our architectural decisions. Refer to our [contributing guide](/CONTRIBUTING.md) for more
details.
[architecture overview](https://github.com/tracel-ai/burn/tree/main/contributor-book/src/project-architecture),
which explains some of our architectural decisions. Refer to our
[contributing guide](/CONTRIBUTING.md) for more details.

## Status

6 changes: 3 additions & 3 deletions backend-comparison/benches/data.rs
@@ -1,5 +1,5 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor, TensorData};
use burn_common::{
benchmark::{run_benchmark, Benchmark},
sync_type::SyncType,
@@ -43,7 +43,7 @@ struct FromDataBenchmark<B: Backend, const D: usize> {
}

impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {
type Args = (Data<B::FloatElem, D>, B::Device);
type Args = (TensorData, B::Device);

fn name(&self) -> String {
"from_data".into()
@@ -59,7 +59,7 @@ impl<B: Backend, const D: usize> Benchmark for FromDataBenchmark<B, D> {

fn prepare(&self) -> Self::Args {
(
Data::random(
TensorData::random::<B::FloatElem, _, _>(
self.shape.clone(),
Distribution::Default,
&mut rand::thread_rng(),
20 changes: 10 additions & 10 deletions burn-book/src/basic-workflow/data.md
@@ -46,12 +46,12 @@ Next, we need to actually implement the batching logic.
# data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
# prelude::*,
# };
#
#
# #[derive(Clone)]
# pub struct MnistBatcher<B: Backend> {
# device: B::Device,
# }
#
#
# impl<B: Backend> MnistBatcher<B> {
# pub fn new(device: B::Device) -> Self {
# Self { device }
@@ -68,7 +68,7 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| Data::<f32, 2>::from(item.image))
.map(|item| TensorData::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device))
.map(|tensor| tensor.reshape([1, 28, 28]))
// Normalize: make between [0,1] and make the mean=0 and std=1
@@ -80,7 +80,7 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
let targets = items
.iter()
.map(|item| Tensor::<B, 1, Int>::from_data(
Data::from([(item.label as i64).elem()]),
TensorData::from([(item.label as i64).elem()]),
&self.device
))
.collect();
@@ -119,7 +119,7 @@ images.
```rust, ignore
let images = items // take items Vec<MnistItem>
.iter() // create an iterator over it
.map(|item| Data::<f32, 2>::from(item.image)) // for each item, convert the image to float32 data struct
.map(|item| TensorData::from(item.image)) // for each item, convert the image to float32 data struct
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)) // for each data struct, create a tensor on the device
.map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W]
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization
@@ -135,8 +135,8 @@ Book.
In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a
single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with
a targets tensor that contains the indexes of the correct digit class. The first step is to parse
the image array into a `Data` struct. Burn provides the `Data` struct to encapsulate tensor storage
information without being specific for a backend. When creating a tensor from data, we often need to
convert the data precision to the current backend in use. This can be done with the `.convert()`
method. While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()` on a
specific number to convert it to the current backend element type in use.
the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate
tensor storage information without being specific for a backend. When creating a tensor from data,
we often need to convert the data precision to the current backend in use. This can be done with the
`.convert()` method. While importing the `burn::tensor::ElementConversion` trait, you can call
`.elem()` on a specific number to convert it to the current backend element type in use.
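
As a minimal sketch (assuming a backend type `B`, a `device`, and the prelude imports used above),
the element-wise conversion looks like this:

```rust, ignore
use burn::tensor::ElementConversion;

// Convert a plain number to the backend's int element type before building the target tensor.
let label = 7;
let target = Tensor::<B, 1, Int>::from_data(
    TensorData::from([(label as i64).elem::<B::IntElem>()]),
    &device,
);
```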
2 changes: 1 addition & 1 deletion burn-book/src/building-blocks/metric.md
@@ -89,7 +89,7 @@ impl<B: Backend> Metric for LossMetric<B> {
type Input = LossInput<B>;
fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]);
let loss = loss.tensor.clone().mean().into_scalar().elem::<f64>();
self.state
.update(loss, 1, FormatOptions::new("Loss").precision(2))
52 changes: 23 additions & 29 deletions burn-book/src/building-blocks/tensor.md
@@ -32,28 +32,29 @@ let tensor_1 = Tensor::<Backend, 1>::from_floats(floats, &device);

### Initialization

Burn Tensors are primarily initialized using the `from_data()` method which takes the `Data` struct
as input. The `Data` struct has two fields: value & shape. To retrieve the data from a tensor, the
method `.to_data()` should be employed when intending to reuse the tensor afterward. Alternatively,
`.into_data()` is recommended for one-time use. Let's look at a couple of examples for initializing
a tensor from different inputs.
Burn Tensors are primarily initialized using the `from_data()` method which takes the `TensorData`
struct as input. The `TensorData` struct has two public fields: `shape` and `dtype`. The `value`,
now stored as bytes, is private but can be accessed via any of the following methods: `as_slice`,
`as_mut_slice`, `to_vec` and `iter`. To retrieve the data from a tensor, the method `.to_data()`
should be employed when intending to reuse the tensor afterward. Alternatively, `.into_data()` is
recommended for one-time use. Let's look at a couple of examples for initializing a tensor from
different inputs.

```rust, ignore
// Initialization from a given Backend (Wgpu)
let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0], &device);
// Initialization from a generic Backend
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert(), &device);
let tensor_2 = Tensor::<Backend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);
// Initialization using from_floats (Recommended for f32 ElementType)
// Will be converted to Data internally.
// `.convert()` not needed as from_floats() defined for fixed ElementType
// Will be converted to TensorData internally.
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);
// Initialization of Int Tensor from array slices
let arr: [i32; 6] = [1, 2, 3, 4, 5, 6];
let tensor_4 = Tensor::<Backend, 1, Int>::from_data(Data::from(&arr[0..3]).convert(), &device);
let tensor_4 = Tensor::<Backend, 1, Int>::from_data(TensorData::from(&arr[0..3]), &device);
// Initialization from a custom type
@@ -68,18 +69,11 @@ let bmi = BodyMetrics{
height: 180,
weight: 80.0
};
let data = Data::from([bmi.age as f32, bmi.height as f32, bmi.weight]).convert();
let data = TensorData::from([bmi.age as f32, bmi.height as f32, bmi.weight]);
let tensor_5 = Tensor::<Backend, 1>::from_data(data, &device);
```
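
As a small sketch of the accessor methods mentioned above (the requested element type must match
the stored `dtype`):

```rust, ignore
let data = TensorData::from([1.0f32, 2.0, 3.0]);

// Borrow the values as a typed slice, or copy them out into a Vec.
let slice: &[f32] = data.as_slice::<f32>().unwrap();
let values: Vec<f32> = data.to_vec::<f32>().unwrap();
```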

The `.convert()` method for Data struct is called to ensure that the data's primitive type is
consistent across all backends. With `.from_floats()` method the ElementType is fixed as f32 and
therefore no convert operation is required across backends. This operation can also be done at
element wise level as:
`let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(item.age as i64).elem()])`. The
`ElementConversion` trait however needs to be imported for the element wise operation.

## Ownership and Cloning

Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple
@@ -105,7 +99,7 @@ let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.clone().min();
let max = input.clone().max();
let input = (input.clone() - min.clone()).div(max - min);
println!("{:?}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]
println!("{}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]
// Notice that max, min have been moved in last operation so
// the below print will give an error.
@@ -285,7 +279,7 @@ Those operations are only available for `Int` tensors.

| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device) ` | `tensor.arange(start=5, end=10, device=device)` |
| `tensor.arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` |
| `tensor.arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
@@ -296,16 +290,16 @@ Those operations are only available for `Int` tensors.

Those operations are only available for `Bool` tensors.

| Burn API | PyTorch Equivalent |
| ----------------------------------- | ------------------------------- |
| `Tensor::diag_mask(shape, diagonal)`| N/A |
| `Tensor::tril_mask(shape, diagonal)`| N/A |
| `Tensor::triu_mask(shape, diagonal)`| N/A |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
| `tensor.not()` | `tensor.logical_not()` |
| Burn API | PyTorch Equivalent |
| ------------------------------------ | ------------------------------- |
| `Tensor::diag_mask(shape, diagonal)` | N/A |
| `Tensor::tril_mask(shape, diagonal)` | N/A |
| `Tensor::triu_mask(shape, diagonal)` | N/A |
| `tensor.argwhere()` | `tensor.argwhere()` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.int()` | `tensor.to(torch.long)` |
| `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` |
| `tensor.not()` | `tensor.logical_not()` |

## Activation Functions

3 changes: 2 additions & 1 deletion burn-book/src/getting-started.md
@@ -188,7 +188,8 @@ use burn::{
module::Module,
nn,
tensor::{
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
backend::Backend, Bool, Device, ElementConversion, Float, Int, Shape, Tensor,
TensorData,
},
};
```
2 changes: 1 addition & 1 deletion contributor-book/src/project-architecture/serialization.md
@@ -57,7 +57,7 @@ Then, a `Recorder` instance can be used to serialize any record. The `Recorder`
provided at the creation of the `Recorder` instance. Note that tensors implement record, and their
item is just a wrapper struct that contains information about the precision in which the tensor
should be saved or loaded. No actual copy of the tensor is made until this point. The tensor is
converted to the `Data` struct and then converted into the specified precision only when
converted to the `TensorData` struct and then converted into the specified precision only when
`serialize()` or `deserialize()` are called, which makes the whole process lazy.
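
As an illustrative sketch (recorder choice, module bound, and file name here are assumptions, not
part of this section), nothing is copied or converted until the record is serialized:

```rust
use burn::module::Module;
use burn::record::{BinFileRecorder, HalfPrecisionSettings, RecorderError};
use burn::tensor::backend::Backend;

// The module's tensors become `TensorData` and are cast to half precision only
// when `save_file` actually serializes the record.
fn save_half_precision<B: Backend, M: Module<B>>(model: M) -> Result<(), RecorderError> {
    let recorder = BinFileRecorder::<HalfPrecisionSettings>::new();
    model.save_file("model", &recorder)
}
```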

To recapitulate, the `Module` trait has an associated type that implements `Record`, which only
8 changes: 4 additions & 4 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -3,23 +3,23 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, BoolTensorOps, IntTensor},
Data, Device, Reader, Shape,
Device, Reader, Shape, TensorData,
};

impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &Device<B>) -> BoolTensor<B, D> {
fn bool_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> BoolTensor<B, D> {
B::bool_from_data(data, device)
}

fn bool_shape<const D: usize>(tensor: &BoolTensor<B, D>) -> Shape<D> {
B::bool_shape(tensor)
}

fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<Data<bool, D>> {
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<TensorData> {
B::bool_to_data(tensor)
}

fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<Data<bool, D>> {
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<TensorData> {
B::bool_into_data(tensor)
}

11 changes: 4 additions & 7 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -3,26 +3,23 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, IntTensor, IntTensorOps},
Data, Device, Distribution, Reader, Shape,
Device, Distribution, Reader, Shape, TensorData,
};

impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_from_data<const D: usize>(
data: Data<B::IntElem, D>,
device: &Device<Self>,
) -> IntTensor<B, D> {
fn int_from_data<const D: usize>(data: TensorData, device: &Device<Self>) -> IntTensor<B, D> {
B::int_from_data(data, device)
}

fn int_shape<const D: usize>(tensor: &IntTensor<B, D>) -> Shape<D> {
B::int_shape(tensor)
}

fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<TensorData> {
B::int_to_data(tensor)
}

fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<TensorData> {
B::int_into_data(tensor)
}

12 changes: 4 additions & 8 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -17,14 +17,14 @@ use crate::{
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
Data, Device, ElementConversion, Reader, Shape, Tensor,
Device, ElementConversion, Reader, Shape, Tensor, TensorData,
};

use super::maxmin::MaxMinDim;

impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C> {
fn float_from_data<const D: usize>(
data: Data<FloatElem<B>, D>,
data: TensorData,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
AutodiffTensor::new(B::float_from_data(data, device))
@@ -50,15 +50,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_shape(&tensor.primitive)
}

fn float_to_data<const D: usize>(
tensor: &FloatTensor<Self, D>,
) -> Reader<Data<FloatElem<B>, D>> {
fn float_to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Reader<TensorData> {
B::float_to_data(&tensor.primitive)
}

fn float_into_data<const D: usize>(
tensor: FloatTensor<Self, D>,
) -> Reader<Data<FloatElem<B>, D>> {
fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
B::float_into_data(tensor.primitive)
}
