Skip to content

Commit

Permalink
Add vectorization support into cube (tracel-ai#1830)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed May 27, 2024
1 parent dc85daa commit fd54a8b
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 36 deletions.
8 changes: 4 additions & 4 deletions crates/burn-cube-macros/src/codegen/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ impl Codegen {

let mut variables = quote::quote! {};

for (ident, ty) in self.state_inputs.iter() {
for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() {
variables.extend(quote::quote! {
let #ident = <#ty as LaunchArg>::compile_input(&mut builder);
let #ident = <#ty as LaunchArg>::compile_input(&mut builder, self.settings.vectorization_input(#pos));
});
}

for (ident, ty) in self.state_outputs.iter() {
for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() {
variables.extend(quote::quote! {
let #ident = <#ty as LaunchArg>::compile_output(&mut builder);
let #ident = <#ty as LaunchArg>::compile_output(&mut builder, self.settings.vectorization_output(#pos));
});
}

Expand Down
106 changes: 94 additions & 12 deletions crates/burn-cube/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,23 @@ pub struct InplaceMapping {
pub pos_output: usize,
}

#[derive(Clone, Copy, Debug)]
enum VectorizationPartial {
Input {
pos: usize,
vectorization: Vectorization,
},
Output {
pos: usize,
vectorization: Vectorization,
},
}

#[derive(Default, Clone)]
pub struct CompilationSettings {
pub mappings: Vec<InplaceMapping>,
vectorization: Option<Vectorization>,
vectorization_global: Option<Vectorization>,
vectorization_partial: Vec<VectorizationPartial>,
workgroup_size: WorkgroupSize,
pub reading_strategy: Vec<(u16, ReadingStrategy)>,
}
Expand All @@ -59,7 +72,9 @@ impl core::fmt::Display for CompilationSettings {
// * Output layout: o
// * Plain: p
//
// * Vectorization: v
// * Vectorization Global: vg{factor}
// * Vectorization Partial Input: v{factor}i{pos}
// * Vectorization Partial Output: vo
// * Workgroup Size X: x
// * Workgroup Size Y: y
// * Workgroup Size Z: z
Expand All @@ -80,11 +95,22 @@ impl core::fmt::Display for CompilationSettings {
}?;
}

match self.vectorization {
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
match self.vectorization_global {
Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?,
None => f.write_str("vn")?,
};

for vectorization in self.vectorization_partial.iter() {
match vectorization {
VectorizationPartial::Input { pos, vectorization } => {
f.write_fmt(format_args!("v{vectorization}i{pos}"))?
}
VectorizationPartial::Output { pos, vectorization } => {
f.write_fmt(format_args!("v{vectorization}o{pos}"))?
}
};
}

f.write_fmt(format_args!(
"x{}y{}z{}",
self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.x
Expand All @@ -93,13 +119,69 @@ impl core::fmt::Display for CompilationSettings {
}

impl CompilationSettings {
/// Compile the shader with vectorization enabled.
/// Compile the shader with vectorization enabled for all inputs and outputs.
#[allow(dead_code)]
pub fn vectorize_global(mut self, vectorization: Vectorization) -> Self {
self.vectorization_global = Some(vectorization);
self
}

/// Compile the shader with vectorization enabled for an input.
#[allow(dead_code)]
pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
self.vectorization_partial
.push(VectorizationPartial::Input {
pos: position,
vectorization,
});
self
}

/// Compile the shader with vectorization enabled for an output.
#[allow(dead_code)]
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
self.vectorization = Some(vectorization);
pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
self.vectorization_partial
.push(VectorizationPartial::Output {
pos: position,
vectorization,
});
self
}

/// Fetch the vectorization for the provided input position.
pub fn vectorization_input(&self, position: usize) -> Vectorization {
if let Some(vec) = self.vectorization_global {
return vec;
}

for partial in self.vectorization_partial.iter() {
if let VectorizationPartial::Input { pos, vectorization } = partial {
if *pos == position {
return *vectorization;
}
}
}

1
}

/// Fetch the vectorization for the provided output position.
pub fn vectorization_output(&self, position: usize) -> Vectorization {
if let Some(vec) = self.vectorization_global {
return vec;
}

for partial in self.vectorization_partial.iter() {
if let VectorizationPartial::Output { pos, vectorization } = partial {
if *pos == position {
return *vectorization;
}
}
}

1
}

/// Compile the shader with inplace enabled by the given [mapping](InplaceMapping).
///
/// Notes:
Expand Down Expand Up @@ -233,7 +315,7 @@ impl Compilation {

/// Performs the compilation with the provided [settings](CompilationSettings).
pub fn compile(mut self, mut settings: CompilationSettings) -> ComputeShader {
if let Some(vectorization) = settings.vectorization {
if let Some(vectorization) = settings.vectorization_global {
self.info.scope.vectorize(vectorization);
}

Expand Down Expand Up @@ -276,7 +358,7 @@ impl Compilation {
for input in self.info.inputs.drain(..) {
match input {
InputInfo::Array { item, visibility } => {
let item = if let Some(vectorization) = settings.vectorization {
let item = if let Some(vectorization) = settings.vectorization_global {
item.vectorize(vectorization)
} else {
item
Expand Down Expand Up @@ -325,7 +407,7 @@ impl Compilation {
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
let item = if let Some(vectorization) = settings.vectorization_global {
item.vectorize(vectorization)
} else {
item
Expand All @@ -351,7 +433,7 @@ impl Compilation {
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
let item = if let Some(vectorization) = settings.vectorization_global {
item.vectorize(vectorization)
} else {
item
Expand All @@ -364,7 +446,7 @@ impl Compilation {
);
}
OutputInfo::Array { item } => {
let item = if let Some(vectorization) = settings.vectorization {
let item = if let Some(vectorization) = settings.vectorization_global {
item.vectorize(vectorization)
} else {
item
Expand Down
9 changes: 6 additions & 3 deletions crates/burn-cube/src/language/element/base.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{dialect::Variable, KernelBuilder, KernelLauncher, Runtime};
use crate::{
dialect::{Variable, Vectorization},
KernelBuilder, KernelLauncher, Runtime,
};
use alloc::rc::Rc;

/// Types used in a cube function must implement this trait
Expand All @@ -23,9 +26,9 @@ pub trait LaunchArg {
type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;

/// Register an input variable during compilation that fill the [KernelBuilder].
fn compile_input(builder: &mut KernelBuilder) -> ExpandElement;
fn compile_input(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement;
/// Register an output variable during compilation that fill the [KernelBuilder].
fn compile_output(builder: &mut KernelBuilder) -> ExpandElement;
fn compile_output(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement;
}

/// Defines the argument settings used to launch a kernel.
Expand Down
16 changes: 11 additions & 5 deletions crates/burn-cube/src/language/element/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
dialect::{Elem, Item, Metadata},
dialect::{Elem, Item, Metadata, Vectorization},
language::{CubeType, ExpandElement},
unexpanded, ArgSettings, CubeContext, CubeElem, KernelLauncher, LaunchArg, Runtime, UInt,
};
Expand All @@ -17,12 +17,18 @@ impl<T: CubeType> CubeType for Tensor<T> {
impl<C: CubeElem> LaunchArg for Tensor<C> {
type RuntimeArg<'a, R: Runtime> = TensorHandle<'a, R>;

fn compile_input(builder: &mut crate::KernelBuilder) -> ExpandElement {
builder.input_array(Item::new(C::as_elem()))
fn compile_input(
builder: &mut crate::KernelBuilder,
vectorization: Vectorization,
) -> ExpandElement {
builder.input_array(Item::vectorized(C::as_elem(), vectorization))
}

fn compile_output(builder: &mut crate::KernelBuilder) -> ExpandElement {
builder.output_array(Item::new(C::as_elem()))
fn compile_output(
builder: &mut crate::KernelBuilder,
vectorization: Vectorization,
) -> ExpandElement {
builder.output_array(Item::vectorized(C::as_elem(), vectorization))
}
}

Expand Down
14 changes: 11 additions & 3 deletions crates/burn-cube/src/language/element/uint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::dialect::{Elem, Variable};
use crate::dialect::{Elem, Variable, Vectorization};
use crate::language::{CubeContext, CubeElem, CubeType, ExpandElement, Numeric};
use crate::{ArgSettings, KernelLauncher, LaunchArg, Runtime};

Expand All @@ -23,11 +23,19 @@ impl CubeElem for UInt {
impl LaunchArg for UInt {
type RuntimeArg<'a, R: Runtime> = u32;

fn compile_input(builder: &mut crate::KernelBuilder) -> ExpandElement {
fn compile_input(
builder: &mut crate::KernelBuilder,
vectorization: Vectorization,
) -> ExpandElement {
assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
builder.scalar(Self::as_elem())
}

fn compile_output(builder: &mut crate::KernelBuilder) -> ExpandElement {
fn compile_output(
builder: &mut crate::KernelBuilder,
vectorization: Vectorization,
) -> ExpandElement {
assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
builder.scalar(Self::as_elem())
}
}
Expand Down
18 changes: 12 additions & 6 deletions crates/burn-cube/src/language/operation/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ where
let item_lhs = lhs.item();
let item_rhs = rhs.item();

check_vectorization(item_lhs.vectorization, item_rhs.vectorization);
let vectorization = check_vectorization(item_lhs.vectorization, item_rhs.vectorization);
let item = Item::vectorized(item_lhs.elem, vectorization);

// We can only reuse rhs.
let out = if lhs.can_mut() {
let out = if lhs.can_mut() && item_lhs == item {
lhs
} else if item_rhs == item_lhs && rhs.can_mut() {
} else if rhs.can_mut() && item_rhs == item {
rhs
} else {
context.create_local(item_lhs)
context.create_local(item)
};

let out_var = *out;
Expand Down Expand Up @@ -131,12 +132,17 @@ where
out
}

fn check_vectorization(lhs: Vectorization, rhs: Vectorization) {
fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
let output = u8::max(lhs, rhs);

if lhs == 1 || rhs == 1 {
return;
return output;
}

assert!(
lhs == rhs,
"Tried to perform binary operation on different vectorization schemes."
);

output
}
4 changes: 2 additions & 2 deletions crates/burn-jit/src/fusion/elemwise/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
);

if vectorize_4 {
settings = settings.vectorize(4);
settings = settings.vectorize_global(4);
factor = 4;
} else if vectorize_2 {
settings = settings.vectorize(2);
settings = settings.vectorize_global(2);
factor = 2;
}

Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/kernel/conv/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(

let num_elems_output = output.shape.num_elements();
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
let settings = CompilationSettings::default();
let settings = CompilationSettings::default()
.vectorize_input(0, 1)
.vectorize_output(0, 1);

kernel_launch::<E::CubeElement, R>(
input.client,
Expand Down

0 comments on commit fd54a8b

Please sign in to comment.