Skip to content

Commit

Permalink
Update cuda-jit (#1799)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed May 24, 2024
1 parent 23c622a commit c7ad25a
Show file tree
Hide file tree
Showing 15 changed files with 80 additions and 104 deletions.
26 changes: 25 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ members = [

exclude = [
"examples/notebook",
"crates/burn-cuda" # comment this line to work on burn-cuda
# "crates/burn-cuda" # comment this line to work on burn-cuda
]

[workspace.package]
Expand Down
2 changes: 2 additions & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
cuda-jit = ["burn-cuda"]

[dependencies]
arboard = { workspace = true }
burn = { path = "../crates/burn", default-features = false }
burn-common = { path = "../crates/burn-common", version = "0.15.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0" }
burn-cuda = { path = "../crates/burn-cuda", version = "0.15.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
derive-new = { workspace = true }
Expand Down
2 changes: 2 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ enum BackendValues {
Wgpu,
#[strum(to_string = "wgpu-fusion")]
WgpuFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
9 changes: 9 additions & 0 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu";
#[cfg(feature = "wgpu-fusion")]
let feature_name = "wgpu-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";

#[cfg(feature = "wgpu")]
{
Expand Down Expand Up @@ -129,6 +131,13 @@ macro_rules! bench_on_backend {
let device = CandleDevice::Metal(0);
bench::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
{
use burn_cuda::{Cuda, CudaDevice};

bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
}
};
}

Expand Down
9 changes: 5 additions & 4 deletions crates/burn-common/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl BenchmarkComputations {
/// Benchmark trait.
pub trait Benchmark {
/// Benchmark arguments.
type Args;
type Args: Clone;

/// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
/// count as included in the duration.
Expand Down Expand Up @@ -149,19 +149,20 @@ pub trait Benchmark {
#[cfg(feature = "std")]
{
// Warmup
self.execute(self.prepare());
let args = self.prepare();

self.execute(args.clone());
self.sync();

let mut durations = Vec::with_capacity(self.num_samples());

for _ in 0..self.num_samples() {
// Prepare
let args = self.prepare();
self.sync();

// Execute the benchmark
let start = Instant::now();
self.execute(args);
self.execute(args.clone());
self.sync();
let end = Instant::now();

Expand Down
6 changes: 6 additions & 0 deletions crates/burn-compute/src/tune/tune_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
client: ComputeClient<S, C>,
}

// Manual `Clone` for boxed autotune operations. `dyn AutotuneOperation` is a
// trait object, so `#[derive(Clone)]` / a blanket `Clone` bound is not
// available; this impl lets `Box<dyn AutotuneOperation>` satisfy the
// `Benchmark::Args: Clone` bound required by the benchmark runner.
impl Clone for Box<dyn AutotuneOperation> {
    fn clone(&self) -> Self {
        // `as_ref()` yields `&dyn AutotuneOperation`; the `.clone()` call is
        // presumably resolving to an object-safe method on the
        // `AutotuneOperation` trait itself that returns a fresh
        // `Box<dyn AutotuneOperation>` — TODO(review): confirm the trait
        // declares such a method (a `Clone` supertrait would not be
        // object-safe, and cloning the `&dyn` reference would not coerce to
        // `Box`).
        self.as_ref().clone()
    }
}

impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
type Args = Box<dyn AutotuneOperation>;

Expand Down
5 changes: 3 additions & 2 deletions crates/burn-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false
burn-compute = { path = "../burn-compute", version = "0.15.0" }
burn-tensor = { path = "../burn-tensor", version = "0.15.0" }
burn-common = { path = "../burn-common", version = "0.15.0" }
burn-cube = { path = "../burn-cube", version = "0.15.0" }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
half = { workspace = true }

half = { workspace = true }
bytemuck = { workspace = true }
cudarc = "0.10.0"

Expand All @@ -37,4 +38,4 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false,
] }

[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]
59 changes: 12 additions & 47 deletions crates/burn-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use burn_cube::{dialect as gpu, Compiler};

use super::Instruction;
use burn_jit::gpu::{self};

#[allow(clippy::too_many_arguments)]
#[derive(new, Clone, Debug, Default)]
Expand All @@ -16,15 +17,15 @@ pub struct CudaCompiler {
global_invocation_id: (bool, bool, bool),
}

impl burn_jit::Compiler for CudaCompiler {
impl Compiler for CudaCompiler {
type Representation = super::ComputeShader;

fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation {
fn compile(shader: burn_cube::dialect::ComputeShader) -> Self::Representation {
let compiler = Self::default();
compiler.compile_shader(shader)
}

fn elem_size(elem: burn_jit::gpu::Elem) -> usize {
fn elem_size(elem: gpu::Elem) -> usize {
Self::compile_elem(elem).size()
}

Expand Down Expand Up @@ -75,44 +76,7 @@ impl CudaCompiler {

fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec<Instruction> {
let mut instructions = Vec::new();
let mut processing = value.process();

for operation in &mut processing.operations {
if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation {
// Replace all Index operators for global arrays with CheckedIndexAssign procedures
match operands.lhs {
gpu::Variable::GlobalInputArray(_, _)
| gpu::Variable::GlobalOutputArray(_, _) => {
*operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex(
gpu::CheckedIndex {
lhs: operands.lhs,
rhs: operands.rhs,
out: operands.out,
},
));
}
// Cannot perform bound check on non-global arrays, do nothing.
_ => (),
}
}
if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation {
// Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures
match operands.out {
gpu::Variable::GlobalInputArray(_, _)
| gpu::Variable::GlobalOutputArray(_, _) => {
*operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign(
gpu::CheckedIndexAssign {
lhs: operands.lhs,
rhs: operands.rhs,
out: operands.out,
},
));
}
// Cannot perform bound check on non-global arrays, do nothing.
_ => (),
}
}
}
let processing = value.process();

for var in processing.variables {
instructions.push(Instruction::DeclareVariable {
Expand Down Expand Up @@ -415,11 +379,12 @@ impl CudaCompiler {
}

fn compile_item(item: gpu::Item) -> super::Item {
match item {
gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)),
gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)),
gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)),
gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)),
match item.vectorization {
4 => super::Item::Vec4(Self::compile_elem(item.elem)),
3 => super::Item::Vec3(Self::compile_elem(item.elem)),
2 => super::Item::Vec2(Self::compile_elem(item.elem)),
1 => super::Item::Scalar(Self::compile_elem(item.elem)),
_ => panic!("Vectorization factor unsupported {:?}", item.vectorization),
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cuda/src/compiler/element.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use burn_jit::gpu;
use burn_cube::dialect as gpu;
use half::{bf16, f16};
use std::fmt::Display;

Expand Down
3 changes: 2 additions & 1 deletion crates/burn-cuda/src/compiler/shader.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use burn_cube::{dialect::WorkgroupSize, CompilerRepresentation};

// use super::{Body, Extension, Item};
use super::{Body, Item};
use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
use std::fmt::Display;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ use burn_compute::{
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup};
use burn_jit::gpu::WorkgroupSize;
use burn_cube::dialect::WorkgroupSize;
use burn_cube::JitKernel;
use burn_cube::Kernel;
use burn_cube::WorkGroup;
use burn_jit::JitAutotuneKey;
use cudarc::driver::sys::CUctx_st;
use cudarc::driver::sys::CUfunc_st;
use std::collections::HashMap;
Expand Down
42 changes: 0 additions & 42 deletions crates/burn-cuda/src/element.rs

This file was deleted.

1 change: 0 additions & 1 deletion crates/burn-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ extern crate alloc;

mod compute;
mod device;
mod element;
mod runtime;

pub mod compiler;
Expand Down
9 changes: 7 additions & 2 deletions crates/burn-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use burn_compute::{
tune::Tuner,
ComputeRuntime,
};
use burn_jit::Runtime;
use burn_cube::Runtime;
use std::sync::Arc;

use crate::{
Expand All @@ -18,6 +18,11 @@ use crate::{
#[derive(Debug)]
pub struct CudaRuntime;

impl burn_jit::JitRuntime for CudaRuntime {
type JitDevice = CudaDevice;
type JitServer = CudaServer<SimpleMemoryManagement<CudaStorage>>;
}

// static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
ComputeRuntime::new();
Expand Down Expand Up @@ -51,7 +56,7 @@ impl Runtime for CudaRuntime {
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(1),
SliceStrategy::Never,
SliceStrategy::Ratio(0.8),
);
CudaContext::new(memory_management, stream, ctx)
}
Expand Down

0 comments on commit c7ad25a

Please sign in to comment.