Update cuda-jit #1799

Merged 2 commits on May 24, 2024
26 changes: 25 additions & 1 deletion Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -16,7 +16,7 @@ members = [

exclude = [
"examples/notebook",
"crates/burn-cuda" # comment this line to work on burn-cuda
# "crates/burn-cuda" # comment this line to work on burn-cuda
]

[workspace.package]
2 changes: 2 additions & 0 deletions backend-comparison/Cargo.toml
@@ -24,12 +24,14 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
cuda-jit = ["burn-cuda"]

[dependencies]
arboard = { workspace = true }
burn = { path = "../crates/burn", default-features = false }
burn-common = { path = "../crates/burn-common", version = "0.15.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0" }
burn-cuda = { path = "../crates/burn-cuda", version = "0.15.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
derive-new = { workspace = true }
2 changes: 2 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
@@ -78,6 +78,8 @@ enum BackendValues {
Wgpu,
#[strum(to_string = "wgpu-fusion")]
WgpuFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
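
The new variant relies on strum's Display derive to map the enum to its CLI string. A minimal self-contained sketch of that behavior (hypothetical crate setup; the real enum also derives ValueEnum and EnumIter):

// requires strum = { version = "0.26", features = ["derive"] }
use strum::Display;

#[derive(Debug, Clone, PartialEq, Eq, Display)]
enum BackendValues {
    #[strum(to_string = "cuda-jit")]
    CudaJit,
}

fn main() {
    // `to_string` follows the strum attribute, so CLI parsing and display
    // both use the kebab-case backend name.
    assert_eq!(BackendValues::CudaJit.to_string(), "cuda-jit");
}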
9 changes: 9 additions & 0 deletions backend-comparison/src/lib.rs
@@ -57,6 +57,8 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu";
#[cfg(feature = "wgpu-fusion")]
let feature_name = "wgpu-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";

#[cfg(feature = "wgpu")]
{
@@ -129,6 +131,13 @@
let device = CandleDevice::Metal(0);
bench::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
{
use burn_cuda::{Cuda, CudaDevice};

bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
}
};
}

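
The chain of `let feature_name = ...;` bindings above works by shadowing: every enabled backend feature re-binds the variable, so the last matching cfg wins when several features are on at once. A minimal sketch of the pattern outside the macro (hypothetical feature names, assumed to be declared in Cargo.toml):

fn main() {
    let feature_name = "wgpu";
    #[cfg(feature = "wgpu-fusion")]
    let feature_name = "wgpu-fusion"; // shadows the previous binding
    #[cfg(feature = "cuda-jit")]
    let feature_name = "cuda-jit"; // declared last, so it takes precedence
    println!("benchmarking with backend: {feature_name}");
}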
9 changes: 5 additions & 4 deletions crates/burn-common/src/benchmark.rs
@@ -112,7 +112,7 @@ impl BenchmarkComputations {
/// Benchmark trait.
pub trait Benchmark {
/// Benchmark arguments.
type Args;
type Args: Clone;

/// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
/// count as included in the duration.
@@ -149,19 +149,20 @@ pub trait Benchmark {
#[cfg(feature = "std")]
{
// Warmup
self.execute(self.prepare());
let args = self.prepare();

self.execute(args.clone());
self.sync();

let mut durations = Vec::with_capacity(self.num_samples());

for _ in 0..self.num_samples() {
// Prepare
let args = self.prepare();
self.sync();

// Execute the benchmark
let start = Instant::now();
self.execute(args);
self.execute(args.clone());
self.sync();
let end = Instant::now();

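
Requiring `Args: Clone` lets one prepared input feed the warmup pass and each timed sample without the preparation cost leaking into measurements. A self-contained sketch of the idea (simplified trait, not the real burn-common API):

use std::time::{Duration, Instant};

trait Benchmark {
    type Args: Clone;
    fn prepare(&self) -> Self::Args;
    fn execute(&self, args: Self::Args);

    fn run(&self, samples: usize) -> Vec<Duration> {
        let args = self.prepare();
        self.execute(args.clone()); // warmup consumes a clone
        (0..samples)
            .map(|_| {
                let start = Instant::now();
                self.execute(args.clone()); // each sample reuses the input
                start.elapsed()
            })
            .collect()
    }
}

struct SumBench;

impl Benchmark for SumBench {
    type Args = Vec<u64>;
    fn prepare(&self) -> Self::Args {
        (0..1_000_000u64).collect()
    }
    fn execute(&self, args: Self::Args) {
        let _ = args.iter().sum::<u64>();
    }
}

fn main() {
    let durations = SumBench.run(10);
    println!("sample duration: {:?}", durations[durations.len() / 2]);
}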
6 changes: 6 additions & 0 deletions crates/burn-compute/src/tune/tune_benchmark.rs
@@ -15,6 +15,12 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
client: ComputeClient<S, C>,
}

impl Clone for Box<dyn AutotuneOperation> {
fn clone(&self) -> Self {
self.as_ref().clone()
}
}

impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
type Args = Box<dyn AutotuneOperation>;

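
This blanket impl works because `AutotuneOperation` itself exposes a `clone` method returning a boxed trait object, the usual dyn-clone pattern for object-safe traits. A self-contained sketch under that assumption (simplified trait; the real one carries more methods):

trait AutotuneOperation {
    // The trait supplies its own boxed clone, since `Clone` cannot be a
    // supertrait of an object-safe trait.
    fn clone(&self) -> Box<dyn AutotuneOperation>;
}

impl Clone for Box<dyn AutotuneOperation> {
    fn clone(&self) -> Self {
        // Dispatches to the trait's `clone`, duplicating the concrete type.
        self.as_ref().clone()
    }
}

struct NoopOp;

impl AutotuneOperation for NoopOp {
    fn clone(&self) -> Box<dyn AutotuneOperation> {
        Box::new(NoopOp)
    }
}

fn main() {
    let op: Box<dyn AutotuneOperation> = Box::new(NoopOp);
    let _second: Box<dyn AutotuneOperation> = op.clone();
}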
5 changes: 3 additions & 2 deletions crates/burn-cuda/Cargo.toml
@@ -22,9 +22,10 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false
burn-compute = { path = "../burn-compute", version = "0.15.0" }
burn-tensor = { path = "../burn-tensor", version = "0.15.0" }
burn-common = { path = "../burn-common", version = "0.15.0" }
burn-cube = { path = "../burn-cube", version = "0.15.0" }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
half = { workspace = true }

half = { workspace = true }
bytemuck = { workspace = true }
cudarc = "0.10.0"

@@ -37,4 +38,4 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false,
] }

[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]
59 changes: 12 additions & 47 deletions crates/burn-cuda/src/compiler/base.rs
@@ -1,5 +1,6 @@
use burn_cube::{dialect as gpu, Compiler};

use super::Instruction;
use burn_jit::gpu::{self};

#[allow(clippy::too_many_arguments)]
#[derive(new, Clone, Debug, Default)]
@@ -16,15 +17,15 @@ pub struct CudaCompiler {
global_invocation_id: (bool, bool, bool),
}

impl burn_jit::Compiler for CudaCompiler {
impl Compiler for CudaCompiler {
type Representation = super::ComputeShader;

fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation {
fn compile(shader: burn_cube::dialect::ComputeShader) -> Self::Representation {
let compiler = Self::default();
compiler.compile_shader(shader)
}

fn elem_size(elem: burn_jit::gpu::Elem) -> usize {
fn elem_size(elem: gpu::Elem) -> usize {
Self::compile_elem(elem).size()
}

@@ -75,44 +76,7 @@ impl CudaCompiler {

fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec<Instruction> {
let mut instructions = Vec::new();
let mut processing = value.process();

for operation in &mut processing.operations {
if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation {
// Replace all Index operators for global arrays with CheckedIndexAssign procedures
match operands.lhs {
gpu::Variable::GlobalInputArray(_, _)
| gpu::Variable::GlobalOutputArray(_, _) => {
*operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex(
gpu::CheckedIndex {
lhs: operands.lhs,
rhs: operands.rhs,
out: operands.out,
},
));
}
// Cannot perform bound check on non-global arrays, do nothing.
_ => (),
}
}
if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation {
// Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures
match operands.out {
gpu::Variable::GlobalInputArray(_, _)
| gpu::Variable::GlobalOutputArray(_, _) => {
*operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign(
gpu::CheckedIndexAssign {
lhs: operands.lhs,
rhs: operands.rhs,
out: operands.out,
},
));
}
// Cannot perform bound check on non-global arrays, do nothing.
_ => (),
}
}
}
let processing = value.process();

for var in processing.variables {
instructions.push(Instruction::DeclareVariable {
@@ -415,11 +379,12 @@ impl CudaCompiler {
}

fn compile_item(item: gpu::Item) -> super::Item {
match item {
gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)),
gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)),
gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)),
gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)),
match item.vectorization {
4 => super::Item::Vec4(Self::compile_elem(item.elem)),
3 => super::Item::Vec3(Self::compile_elem(item.elem)),
2 => super::Item::Vec2(Self::compile_elem(item.elem)),
1 => super::Item::Scalar(Self::compile_elem(item.elem)),
_ => panic!("Vectorization factor unsupported {:?}", item.vectorization),
}
}

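
The rewrite above tracks a change in burn-cube's Item: as assumed here, it is now a struct carrying an element type plus a numeric vectorization factor rather than Vec4/Vec3/Vec2/Scalar enum variants. A self-contained sketch of the mapping with hypothetical local types (the real ones live in burn-cube and the CUDA compiler module):

#[derive(Debug, Clone, Copy)]
enum Elem {
    F32,
}

// Assumed shape of the new burn-cube item: element + vectorization width.
#[derive(Debug, Clone, Copy)]
struct Item {
    elem: Elem,
    vectorization: u8,
}

// The CUDA side keeps explicit vector variants.
#[derive(Debug)]
enum CudaItem {
    Vec4(Elem),
    Vec3(Elem),
    Vec2(Elem),
    Scalar(Elem),
}

fn compile_item(item: Item) -> CudaItem {
    match item.vectorization {
        4 => CudaItem::Vec4(item.elem),
        3 => CudaItem::Vec3(item.elem),
        2 => CudaItem::Vec2(item.elem),
        1 => CudaItem::Scalar(item.elem),
        other => panic!("Vectorization factor unsupported {other:?}"),
    }
}

fn main() {
    println!("{:?}", compile_item(Item { elem: Elem::F32, vectorization: 4 }));
}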
2 changes: 1 addition & 1 deletion crates/burn-cuda/src/compiler/element.rs
@@ -1,4 +1,4 @@
use burn_jit::gpu;
use burn_cube::dialect as gpu;
use half::{bf16, f16};
use std::fmt::Display;

3 changes: 2 additions & 1 deletion crates/burn-cuda/src/compiler/shader.rs
@@ -1,6 +1,7 @@
use burn_cube::{dialect::WorkgroupSize, CompilerRepresentation};

// use super::{Body, Extension, Item};
use super::{Body, Item};
use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
use std::fmt::Display;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
7 changes: 5 additions & 2 deletions crates/burn-cuda/src/compute/server.rs
@@ -4,8 +4,11 @@ use burn_compute::{
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup};
use burn_jit::gpu::WorkgroupSize;
use burn_cube::dialect::WorkgroupSize;
use burn_cube::JitKernel;
use burn_cube::Kernel;
use burn_cube::WorkGroup;
use burn_jit::JitAutotuneKey;
use cudarc::driver::sys::CUctx_st;
use cudarc::driver::sys::CUfunc_st;
use std::collections::HashMap;
42 changes: 0 additions & 42 deletions crates/burn-cuda/src/element.rs

This file was deleted.

1 change: 0 additions & 1 deletion crates/burn-cuda/src/lib.rs
@@ -4,7 +4,6 @@ extern crate alloc;

mod compute;
mod device;
mod element;
mod runtime;

pub mod compiler;
9 changes: 7 additions & 2 deletions crates/burn-cuda/src/runtime.rs
@@ -6,7 +6,7 @@ use burn_compute::{
tune::Tuner,
ComputeRuntime,
};
use burn_jit::Runtime;
use burn_cube::Runtime;
use std::sync::Arc;

use crate::{
@@ -18,6 +18,11 @@ use crate::{
#[derive(Debug)]
pub struct CudaRuntime;

impl burn_jit::JitRuntime for CudaRuntime {
type JitDevice = CudaDevice;
type JitServer = CudaServer<SimpleMemoryManagement<CudaStorage>>;
}

// static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
ComputeRuntime::new();
@@ -51,7 +56,7 @@ impl Runtime for CudaRuntime {
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(1),
SliceStrategy::Never,
SliceStrategy::Ratio(0.8),
);
CudaContext::new(memory_management, stream, ctx)
}
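
Switching from `SliceStrategy::Never` to `SliceStrategy::Ratio(0.8)` lets the simple memory manager serve a request by slicing an existing allocation when the request is large enough relative to it. A heavily hedged illustration of the assumed semantics (not burn-compute's actual implementation):

// Assumed rule: reuse an existing chunk via slicing only when the requested
// size covers at least `ratio` of the chunk; otherwise allocate fresh.
fn can_slice(chunk_bytes: usize, request_bytes: usize, ratio: f64) -> bool {
    request_bytes <= chunk_bytes
        && (request_bytes as f64) >= (chunk_bytes as f64) * ratio
}

fn main() {
    assert!(can_slice(1024, 900, 0.8)); // 900/1024 ≈ 0.88 → slice the chunk
    assert!(!can_slice(1024, 512, 0.8)); // 512/1024 = 0.5 → allocate instead
}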