[DOCS] Added matrix multiplication tutorial
ptillet committed Mar 15, 2021
1 parent d1c0bf2 commit 2f8f004
Showing 9 changed files with 395 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -40,7 +40,7 @@ if(BUILD_PYTHON_MODULE)
     if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
         set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
         add_definitions(-DWITH_CUTLASS_BINDINGS)
-        set(CUTLASS_LIBRARIES "cutlass")
+        set(CUTLASS_LIBRARIES "cutlass.a")
     endif()
     message(STATUS ${CUTLASS_INCLUDE_PATH})
     set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
2 changes: 1 addition & 1 deletion lib/ir/module.cc
@@ -30,7 +30,7 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::value
   if(it != metadatas_.end()){
     x->set_metadata(it->second.first, it->second.second);
   }
-  value->set_name(name);
+  // value->set_name(name);
 }

 void module::set_value(const std::string& name, ir::value *value){
3 changes: 3 additions & 0 deletions lib/runtime/function.cc
@@ -288,6 +288,9 @@ function::function(const std::string& src, const options_t &opt, driver::device
   // find indices of autotune keys
   for(const std::string& name: tune_key){
     auto pred = [&](ir::argument* arg) { return arg->get_name() == name; };
+    // std::cout << "----" << std::endl;
+    // for(ir::argument* arg: args)
+    // std::cout << arg->get_name() << std::endl;
     auto it = std::find_if(args.begin(), args.end(), pred);
     if(it == args.end())
       throw std::runtime_error(name + " is not a valid argument name");
8 changes: 2 additions & 6 deletions python/setup.py
@@ -43,8 +43,7 @@ def run(self):
             out = subprocess.check_output(["cmake", "--version"])
         except OSError:
             raise RuntimeError(
-                "CMake must be installed to build the following extensions: " +
-                ", ".join(e.name for e in self.extensions)
+                "CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions)
             )

         if platform.system() == "Windows":
@@ -107,10 +106,7 @@ def build_extension(self, ext):
     long_description="",
     packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
     install_requires=["numpy", "torch"],
-    package_data={
-        "triton/ops": ["*.c"],
-        "triton/ops/blocksparse": ["*.c"]
-    },
+    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
     cmdclass={"build_ext": CMakeBuild},
4 changes: 2 additions & 2 deletions python/triton/ops/matmul.c
@@ -1,5 +1,5 @@
-#define STM 8
-#define STN 8
+#define STM 1
+#define STN 1

 __global__ void matmul(TYPE *A __noalias __readonly,
                        TYPE *B __noalias __readonly,
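Note on the two macros touched above: in the Triton-C matmul kernel, STM and STN appear to control how program instances are "super-grouped", i.e. re-ordered so that programs computing nearby tiles of C run close together in time and reuse tiles of A and B from the L2 cache. The snippet below is only a rough Python sketch of that re-ordering idea; the names (pid, grid_m, grid_n, group_m) are illustrative assumptions and are not taken from the kernel itself.

# Hypothetical sketch of grouped ("super-tiled") program ordering, not the kernel's code.
# A flat program id is remapped so that group_m consecutive tile-rows are visited
# before moving on to the next tile-column, which improves L2 reuse of A and B.
def grouped_pid(pid, grid_m, grid_n, group_m=8):
    width = group_m * grid_n                           # programs per super-group
    group_id = pid // width                            # which super-group this program falls in
    group_size = min(grid_m - group_id * group_m, group_m)
    pid_m = group_id * group_m + (pid % group_size)    # row index of the C tile
    pid_n = (pid % width) // group_size                # column index of the C tile
    return pid_m, pid_n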
2 changes: 1 addition & 1 deletion python/triton/testing.py
@@ -40,7 +40,7 @@ def allclose(x, y):
     return err < tol


-def do_bench(fn, warmup=50, rep=50, grad_to_none=None, percentiles=[0.2, 0.8]):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
    # Estimate the runtime of the function
    fn()
    torch.cuda.synchronize()
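For reference, the helper changed above can be called directly. A minimal sketch, assuming a CUDA device and assuming that, with the default percentiles, it returns a typical runtime plus the 20th/80th percentile runtimes in milliseconds:

import torch
import triton

x = torch.randn(4096, 4096, device='cuda')
# Assumed return convention: (typical ms, low-percentile ms, high-percentile ms).
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))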
8 changes: 4 additions & 4 deletions python/tutorials/01-vector-add.py
@@ -50,7 +50,7 @@
 # The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the `MAPL'2019 Triton paper <http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf>`_.

 # %%
-# Torch bindings
+# Torch Bindings
 # --------------------------
 # The only thing that matters when it comes to Triton and Torch is the :code:`triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify :code:`torch.tensor` objects. To create a :code:`triton.kernel`, you only need three things:
 #
@@ -127,7 +127,7 @@ def forward(ctx, x, y):

 # %%
 # Unit Test
-# --------------------------
+# -----------
 #
 # Of course, the first thing that we should check is that whether kernel is correct. This is pretty easy to test, as shown below:

@@ -144,8 +144,8 @@ def forward(ctx, x, y):
 # Seems like we're good to go!

 # %%
-# Benchmarking
-# --------------------------
+# Benchmark
+# -----------
 # We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does relative to PyTorch.
 # To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom op.
 # for different problem sizes.
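A rough illustration of the benchmark described in the comments above, as a minimal sketch that calls do_bench directly rather than the tutorial's plotting utilities. Here `add` stands for the custom op defined earlier in the tutorial, and the do_bench return convention is assumed as above.

import torch
import triton

# Time the custom op for growing vector sizes and convert runtime to effective bandwidth.
for size in [2**i for i in range(12, 28, 2)]:
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    ms, _, _ = triton.testing.do_bench(lambda: add(x, y))
    # Three tensors are touched per call: x and y are read, the output is written.
    gbps = 3 * x.numel() * x.element_size() / (ms * 1e-3) / 1e9
    print(f'{size:>12} elements: {gbps:.2f} GB/s')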
15 changes: 12 additions & 3 deletions python/tutorials/02-fused-softmax.py
@@ -126,10 +126,19 @@ def make_kernel(N, device):
     # Now are kernels are indexed not only by the provided device but also
     # by the rounded number of columns in the input matrix
     BLOCK = next_power_of_2(N)
-    key = (BLOCK, device)
+    # Another trick we can use is to ask the compiler to parallelize each
+    # row-normalization more aggressively -- i.e., with more warps -- vectors
+    # that are longer
+    # You will see in the next tutorial how to auto-tune this value in a more natural
+    # way so you don't have to come up with manual heuristics yourself
+    num_warps = 4
+    if BLOCK >= 2048: num_warps = 8
+    if BLOCK >= 4096: num_warps = 16
+    # Each (BLOCK, num_warps, device) results in a different kernel
+    key = (BLOCK, num_warps, device)
     if key not in cache:
         defines = {'BLOCK': BLOCK}
-        cache[key] = triton.kernel(_src, device=device, defines=defines)
+        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
     return cache[key]


@@ -174,7 +183,7 @@ def forward(ctx, x):
 # As expected, the results are identical.

 # %%
-# Benchmarking
+# Benchmark
 # -------------
 # Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
 # We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.
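A comparable sketch for the benchmark described above, again calling do_bench directly rather than the tutorial's plotting utilities. Here `softmax` stands for the fused op defined earlier in this tutorial, and the do_bench return convention is assumed as before.

import torch
import triton

M = 4096  # fixed number of rows, as in the tutorial
for N in [256, 1024, 4096]:
    x = torch.randn(M, N, device='cuda')
    ms_triton, _, _ = triton.testing.do_bench(lambda: softmax(x))
    ms_torch, _, _ = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))
    # The fused kernel reads and writes each element roughly once, hence 2 * nbytes.
    nbytes = 2 * x.numel() * x.element_size()
    print(f'N={N:>5}  triton: {nbytes / (ms_triton * 1e-3) / 1e9:6.1f} GB/s'
          f'  torch: {nbytes / (ms_torch * 1e-3) / 1e9:6.1f} GB/s')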