Burn - A Flexible and Comprehensive Deep Learning Framework in Rust


This library strives to serve as a comprehensive deep learning framework, offering exceptional flexibility and written in Rust. Our objective is to cater to both researchers and practitioners by simplifying the process of experimenting, training, and deploying models.

Features

  • Customizable, user-friendly neural network module 🔥
  • Comprehensive training tools, inclusive of metrics, logging, and checkpointing 📈
  • Versatile Tensor crate equipped with pluggable backends 🔧
    • Torch backend, supporting both CPU and GPU 🚀
    • Ndarray backend with no_std compatibility, ensuring universal platform adaptability 👌
    • WebGPU backend, offering cross-platform, browser-inclusive, GPU-based computations 🌐
    • Autodiff backend that enables differentiability across all backends 🌟
  • Dataset crate containing a diverse range of utilities and sources 📚
  • Import crate that simplifies the integration of pretrained models 📦

Supported Platforms

Burn-ndarray Backend

Option      CPU  GPU  Linux  MacOS  Windows  Android  iOS  WASM
Pure Rust   Yes  No   Yes    Yes    Yes      Yes      Yes  Yes
Accelerate  Yes  No   No     Yes    No       No       Yes  No
Netlib      Yes  No   Yes    Yes    Yes      No       No   No
Openblas    Yes  No   Yes    Yes    Yes      Yes      Yes  No

Burn-tch Backend

Option  CPU  GPU  Linux  MacOS  Windows  Android  iOS  WASM
CPU     Yes  No   Yes    Yes    Yes      Yes      Yes  No
CUDA    No   Yes  Yes    No     Yes      No       No   No
MPS     No   Yes  No     Yes    No       No       No   No
Vulkan  Yes  Yes  Yes    Yes    Yes      Yes      No   No

Burn-wgpu Backend

Option     CPU  GPU  Linux  MacOS  Windows  Android  iOS  WASM
Metal      No   Yes  No     Yes    No       No       Yes  No
Vulkan     Yes  Yes  Yes    Yes    Yes      Yes      Yes  No
OpenGL     No   Yes  Yes    Yes    Yes      Yes      Yes  No
WebGpu     No   Yes  No     No     No       No       No   Yes
Dx11/Dx12  No   Yes  No     No     Yes      No       No   No

Pre-trained Models

We keep an updated and curated list of models and examples built with Burn; see the burn-rs/models repository for more details.

Get Started

The best way to get started with burn is to clone the repo and play with the examples. It may also be a good idea to take a look at the main components of burn to get a quick overview of the fundamental building blocks. If you're interested in how the framework works, you can read our architecture document.

Examples

Components

Understanding the key components and philosophy of burn can greatly help when beginning to work with the framework.

Backend

Nearly everything in burn is based on the Backend trait, which enables you to run tensor operations using different implementations without having to modify your code. While a backend may not necessarily have autodiff capabilities, the ADBackend trait specifies when autodiff is needed. This trait abstracts not only operations but also tensor, device and element types, giving each backend the flexibility it needs. It's worth noting that the trait assumes eager mode since burn fully supports dynamic graphs. However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.
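
The idea of writing code once against a trait can be shown without Burn itself. The following is a minimal sketch in plain Rust, not Burn's actual Backend trait: `ToyBackend`, `VecBackend` and `double` are hypothetical names invented for this illustration, and the real trait covers far more (devices, tensor kinds, many more operations). It only shows how associated types let user code stay generic over the implementation.

```rust
/// Toy stand-in for a backend trait (hypothetical, not Burn's `Backend`).
/// It abstracts the element type, the tensor storage and the operations.
trait ToyBackend {
    type Elem: Copy;
    type TensorPrimitive: Clone;

    fn from_vec(data: Vec<Self::Elem>) -> Self::TensorPrimitive;
    fn add(lhs: Self::TensorPrimitive, rhs: Self::TensorPrimitive) -> Self::TensorPrimitive;
    fn to_vec(tensor: Self::TensorPrimitive) -> Vec<Self::Elem>;
}

/// One possible implementation, backed by a plain `Vec<f32>`.
struct VecBackend;

impl ToyBackend for VecBackend {
    type Elem = f32;
    type TensorPrimitive = Vec<f32>;

    fn from_vec(data: Vec<f32>) -> Vec<f32> {
        data
    }

    fn add(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        assert_eq!(lhs.len(), rhs.len(), "shapes must match");
        lhs.into_iter().zip(rhs).map(|(a, b)| a + b).collect()
    }

    fn to_vec(tensor: Vec<f32>) -> Vec<f32> {
        tensor
    }
}

/// User code written once against the trait; it never names a concrete backend.
fn double<B: ToyBackend>(x: B::TensorPrimitive) -> B::TensorPrimitive {
    B::add(x.clone(), x)
}
```

Calling `double::<VecBackend>(...)` selects the implementation at compile time; a second type implementing `ToyBackend` could be swapped in without touching `double`.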

Tensor

At the core of burn lies the Tensor struct, which encompasses multiple types of tensors, including Float, Int, and Bool. The element types of these tensors are specified by the backend and are usually designated as a generic argument (e.g., NdArrayBackend<f32>). Although the same struct is used for all tensors, the available methods differ depending on the tensor kind. You can specify the desired tensor kind by setting the third generic argument, which defaults to Float. The first generic argument specifies the backend, while the second specifies the number of dimensions.

use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, Int};

fn function<B: Backend>(tensor_float: Tensor<B, 2>) {
    let _tensor_bool = tensor_float.clone().equal_elem(2.0); // Tensor<B, 2, Bool>
    let _tensor_int = tensor_float.argmax(1); // Tensor<B, 2, Int>
}

As demonstrated in the previous example, nearly all operations require owned tensors as parameters, which means that calling Clone explicitly is necessary when reusing the same tensor multiple times. Cloning is cheap, however: the tensor's data is not copied. Instead, it is flagged as read-only when multiple tensors use the same allocated memory. This enables backends to reuse tensor data when possible, similar to a copy-on-write pattern, while remaining completely transparent to the user.
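
One way to picture this sharing is reference-counted storage with copy-on-write. The sketch below assumes such a scheme, built on `std::rc::Rc`. `SharedTensor` and its methods are hypothetical names for illustration, and Burn's backends may manage memory differently.

```rust
use std::rc::Rc;

/// Toy tensor whose buffer is shared between clones (hypothetical type).
#[derive(Clone)]
struct SharedTensor {
    data: Rc<Vec<f32>>,
}

impl SharedTensor {
    fn new(data: Vec<f32>) -> Self {
        Self { data: Rc::new(data) }
    }

    /// Takes ownership of the tensor, like the operations described above.
    /// If this is the only handle to the buffer, it is mutated in place;
    /// otherwise `Rc::make_mut` copies the buffer before writing.
    fn add_scalar(mut self, value: f32) -> Self {
        for x in Rc::make_mut(&mut self.data).iter_mut() {
            *x += value;
        }
        self
    }

    /// True when both tensors point at the same allocation.
    fn shares_buffer_with(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.data, &other.data)
    }
}
```

Cloning a `SharedTensor` only increments a reference count. The buffer is duplicated lazily, and only when an operation needs to write to memory that another handle still uses.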

Autodiff

The 'Backend' trait is highly flexible, enabling backpropagation to be implemented using a simple backend decorator, which makes any backend differentiable.

use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Distribution, Tensor};
use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;

fn linear<B: Backend>(x: Tensor<B, 2>, weight: Tensor<B, 2>, bias: Tensor<B, 2>) -> Tensor<B, 2> {
    x.matmul(weight) + bias
}

fn main() {
    type Backend = NdArrayBackend<f32>;

    let weight = Tensor::random([3, 3], Distribution::Standard);
    let bias = Tensor::zeros([1, 3]);
    let x = Tensor::random([3, 3], Distribution::Standard);

    let y = linear::<Backend>(x.clone(), weight.clone(), bias.clone());
    // y.backward() // Method backward doesn't exist

    let y = linear::<ADBackendDecorator<Backend>>(
        Tensor::from_inner(x),
        Tensor::from_inner(weight).require_grad(),
        Tensor::from_inner(bias).require_grad(),
    );
    let grads = y.backward(); // Method exists
}
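
To show why a decorator is enough, here is a self-contained toy in plain Rust. It is a sketch under strong assumptions, not how burn-autodiff is implemented: every name (`ToyBackend`, `Scalar`, `Autodiff`, `AdTensor`) is hypothetical, tensors are single `f64` values, and the reverse pass naively follows every path through the graph instead of using a topological order.

```rust
use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;

/// Toy backend trait (hypothetical; far smaller than Burn's `Backend`).
trait ToyBackend {
    type Tensor: Clone;
    fn add(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor;
    fn mul(a: Self::Tensor, b: Self::Tensor) -> Self::Tensor;
    fn ones_like(t: &Self::Tensor) -> Self::Tensor;
}

/// A plain backend whose "tensors" are single `f64` values.
struct Scalar;

impl ToyBackend for Scalar {
    type Tensor = f64;
    fn add(a: f64, b: f64) -> f64 { a + b }
    fn mul(a: f64, b: f64) -> f64 { a * b }
    fn ones_like(_t: &f64) -> f64 { 1.0 }
}

/// Graph node: the inner value, its accumulated gradient, and for each
/// parent the local derivative d(self)/d(parent).
struct Node<B: ToyBackend> {
    value: B::Tensor,
    grad: RefCell<Option<B::Tensor>>,
    parents: Vec<(AdTensor<B>, B::Tensor)>,
}

/// Tensor type of the decorated backend: a shared handle to a node.
struct AdTensor<B: ToyBackend>(Rc<Node<B>>);

impl<B: ToyBackend> Clone for AdTensor<B> {
    fn clone(&self) -> Self { AdTensor(Rc::clone(&self.0)) }
}

impl<B: ToyBackend> AdTensor<B> {
    fn new(value: B::Tensor, parents: Vec<(AdTensor<B>, B::Tensor)>) -> Self {
        AdTensor(Rc::new(Node { value, grad: RefCell::new(None), parents }))
    }
    fn from_inner(value: B::Tensor) -> Self { Self::new(value, Vec::new()) }
    fn value(&self) -> B::Tensor { self.0.value.clone() }
    fn grad(&self) -> Option<B::Tensor> { self.0.grad.borrow().clone() }

    /// Starts the reverse pass with a gradient of one at this node.
    fn backward(&self) {
        self.accumulate(B::ones_like(&self.0.value));
    }

    /// Adds `upstream` to this node's gradient, then pushes
    /// `upstream * local derivative` to each parent (chain rule).
    fn accumulate(&self, upstream: B::Tensor) {
        let mut slot = self.0.grad.borrow_mut();
        let total = match slot.take() {
            Some(previous) => B::add(previous, upstream.clone()),
            None => upstream.clone(),
        };
        *slot = Some(total);
        drop(slot);
        for (parent, local) in &self.0.parents {
            parent.accumulate(B::mul(upstream.clone(), local.clone()));
        }
    }
}

/// The decorator: a backend built entirely from another backend's operations.
struct Autodiff<B>(PhantomData<B>);

impl<B: ToyBackend> ToyBackend for Autodiff<B> {
    type Tensor = AdTensor<B>;
    fn add(a: AdTensor<B>, b: AdTensor<B>) -> AdTensor<B> {
        let value = B::add(a.value(), b.value());
        let (da, db) = (B::ones_like(&a.0.value), B::ones_like(&b.0.value));
        AdTensor::new(value, vec![(a, da), (b, db)])
    }
    fn mul(a: AdTensor<B>, b: AdTensor<B>) -> AdTensor<B> {
        let value = B::mul(a.value(), b.value());
        let (da, db) = (b.value(), a.value());
        AdTensor::new(value, vec![(a, da), (b, db)])
    }
    fn ones_like(t: &AdTensor<B>) -> AdTensor<B> {
        AdTensor::from_inner(B::ones_like(&t.0.value))
    }
}

/// Model code written once; it works on both the plain and the decorated backend.
fn linear<B: ToyBackend>(x: B::Tensor, weight: B::Tensor, bias: B::Tensor) -> B::Tensor {
    B::add(B::mul(x, weight), bias)
}
```

As in the Burn example above, `linear::<Scalar>` returns a plain value with no gradient support, while `linear::<Autodiff<Scalar>>` returns a value on which `backward()` can be called.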

Module

The Module derive allows you to create your own neural network modules, similar to PyTorch. The derive only generates the methods needed for your type to act as a parameter container; it makes no assumptions about how the forward pass is declared.

use burn::module::Module;
use burn::nn::{Dropout, Linear, GELU};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

Note that all fields declared in the struct must also implement the Module trait. The Tensor struct doesn't implement Module, but Param<Tensor<B, D>> does.
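
The "parameter container" idea, and the reason every field must itself be a module, can be sketched by hand in plain Rust. Everything below is hypothetical (`ToyModule`, `ToyParam`, `ToyLinear`, `ToyFeedForward`) and much simpler than what the real derive generates. It assumes a module trait with a single method and shows each container delegating to its fields.

```rust
/// Hypothetical, much-reduced stand-in for a module trait.
trait ToyModule {
    fn num_params(&self) -> usize;
}

/// Wrapper that marks a buffer as a learnable parameter.
struct ToyParam(Vec<f32>);

impl ToyModule for ToyParam {
    fn num_params(&self) -> usize {
        self.0.len()
    }
}

struct ToyLinear {
    weight: ToyParam,
    bias: ToyParam,
}

impl ToyLinear {
    fn new(d_input: usize, d_output: usize) -> Self {
        Self {
            weight: ToyParam(vec![0.0; d_input * d_output]),
            bias: ToyParam(vec![0.0; d_output]),
        }
    }
}

// Hand-written version of what a derive could generate:
// delegate to every field, which is why each field must be a module too.
impl ToyModule for ToyLinear {
    fn num_params(&self) -> usize {
        self.weight.num_params() + self.bias.num_params()
    }
}

struct ToyFeedForward {
    linear_inner: ToyLinear,
    linear_outer: ToyLinear,
}

impl ToyFeedForward {
    fn new(d_model: usize, d_ff: usize) -> Self {
        Self {
            linear_inner: ToyLinear::new(d_model, d_ff),
            linear_outer: ToyLinear::new(d_ff, d_model),
        }
    }
}

impl ToyModule for ToyFeedForward {
    fn num_params(&self) -> usize {
        self.linear_inner.num_params() + self.linear_outer.num_params()
    }
}
```

A raw `Vec<f32>` field would break the delegation chain because it has no `num_params`. Wrapping it in `ToyParam` restores it, which mirrors the role the text above gives to Param.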

Config

The Config derive lets you define serializable and deserializable configurations or hyper-parameters for your modules or any components.

use burn::config::Config;

#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    #[config(default = 0.1)]
    pub dropout: f64,
}

The derive also adds useful methods to your config, similar to a builder pattern.

fn main() {
    let config = PositionWiseFeedForwardConfig::new(512, 2048);
    println!("{}", config.d_model); // 512
    println!("{}", config.d_ff); // 2048
    println!("{}", config.dropout); // 0.1
    let config =  config.with_dropout(0.2);
    println!("{}", config.dropout); // 0.2
}
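
For intuition, the behaviour shown above can be reproduced by hand in plain Rust. This is only a sketch of the builder-style API described in this section: `HandWrittenConfig` is a hypothetical name, serialization is left out, and the code the derive actually generates may look different.

```rust
/// Hand-written equivalent of a config with one defaulted field.
#[derive(Debug, Clone, PartialEq)]
pub struct HandWrittenConfig {
    pub d_model: usize,
    pub d_ff: usize,
    pub dropout: f64,
}

impl HandWrittenConfig {
    /// Required fields are constructor arguments; defaulted fields are filled in.
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self { d_model, d_ff, dropout: 0.1 }
    }

    /// Builder-style setter: consumes the config and returns the updated one.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}
```

Because `with_dropout` takes `self` by value, calls can be chained directly after `new`, for example `HandWrittenConfig::new(512, 2048).with_dropout(0.2)`.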

Learner

The Learner is the main struct that lets you train a neural network with support for logging, metrics, checkpointing and more. In order to create a learner, you must use the LearnerBuilder.

use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::record::DefaultRecordSettings;

fn main() {
    let dataloader_train = ...;
    let dataloader_valid = ...;

    let model = ...;
    let optim = ...;

    let learner = LearnerBuilder::new("/tmp/artifact_dir")
        .metric_train_plot(AccuracyMetric::new())
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
        .with_file_checkpointer::<DefaultRecordSettings>(2)
        .num_epochs(10)
        .build(model, optim);

    let _model_trained = learner.fit(dataloader_train, dataloader_valid);
}

See this example for real usage.

Support for no_std

Burn, including its burn-ndarray backend, can work in a no_std environment, provided alloc is available for the inference mode. To accomplish this, simply turn off the default features in burn and burn-ndarray (which is the minimum requirement for running the inference mode). You can find a reference example in burn-no-std-tests.

The burn-core and burn-tensor crates also support no_std with alloc. These crates can be directly added as dependencies if necessary, as they are reexported by the burn crate.

Please be aware that when using the no_std mode, a random seed will be generated at build time if one hasn't been set using the Backend::seed method. Also, the spin::mutex::Mutex is used instead of std::sync::Mutex in this mode.

Contributing

Before contributing, please take a moment to review our code of conduct. It's also highly recommended to read our architecture document, which explains our architectural decisions. Please see more details in our contributing guide.

CI

Publish crates

Compile scripts/publish.rs using this command:

rustc scripts/publish.rs --crate-type bin --out-dir scripts

Disclaimer

Burn is currently in active development, and there will be breaking changes. While any resulting issues are likely to be easy to fix, there are no guarantees at this stage.

Sponsors

You can sponsor the founder of Burn from his GitHub Sponsors profile. The Burn-rs organization doesn't yet have a fiscal entity, but other sponsor methods might become available as the project grows.

Thanks to all current sponsors 🙏.

nathanielsimard

License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.
