Commit

Fixed merge conflict

norabelrose committed Apr 23, 2021
2 parents b70f0ca + 48a64af commit d97dc98

Showing 19 changed files with 355 additions and 243 deletions.
14 changes: 9 additions & 5 deletions docs/index.rst
@@ -1,7 +1,7 @@
 Welcome to Triton's documentation!
 ==================================
 
-Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
+Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
 
 Getting Started
 ---------------
@@ -17,18 +17,22 @@ Getting Started
    getting-started/installation
    getting-started/tutorials/index
 
-Language Reference
+Python API
 -------------------
 
-- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>`
+- :doc:`triton <python-api/triton>`
+- :doc:`triton.language <python-api/triton.language>`
+- :doc:`triton.testing <python-api/triton.testing>`
 
 
 .. toctree::
    :maxdepth: 1
-   :caption: Language Reference
+   :caption: Python API
    :hidden:
 
-   language-reference/python-api/index
+   python-api/triton
+   python-api/triton.language
+   python-api/triton.testing
 
 
 Going Further
docs/language-reference/python-api/index.rst → docs/python-api/triton.language.rst
@@ -1,7 +1,7 @@
-Python API
-===========
+triton.language
+================
 
-.. currentmodule:: triton
+.. currentmodule:: triton.language
 
 
 Programming Model
10 changes: 10 additions & 0 deletions docs/python-api/triton.rst
@@ -0,0 +1,10 @@
+triton
+========
+
+.. currentmodule:: triton
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    jit
12 changes: 12 additions & 0 deletions docs/python-api/triton.testing.rst
@@ -0,0 +1,12 @@
+triton.testing
+================
+
+.. currentmodule:: triton.testing
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    do_bench
+    Benchmark
+    perf_report
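
These three utilities appear throughout the benchmark changes below. For orientation, a minimal do_bench sketch; the (mean, min, max) millisecond return convention follows the tutorials of this era and is an assumption to verify against your version:

import torch
import triton

a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
# time a GPU op; do_bench handles warmup and repetition internally
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f'matmul: {ms:.3f} ms (min {min_ms:.3f}, max {max_ms:.3f})')
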
12 changes: 6 additions & 6 deletions python/bench/bench_blocksparse.py
@@ -10,9 +10,9 @@
     triton.testing.Benchmark(
         x_names = ['M', 'N', 'K'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'block',
-        y_vals = [16, 32, 64],
-        y_lines = ['Block16', 'Block32', 'Block64'],
+        line_arg = 'block',
+        line_vals = [16, 32, 64],
+        line_names = ['Block16', 'Block32', 'Block64'],
         ylabel = 'TFLOPS',
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode,
@@ -60,9 +60,9 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
     triton.testing.Benchmark(
         x_names = ['M', 'N'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'block',
-        y_vals = [16, 32, 64],
-        y_lines = ['Block16', 'Block32', 'Block64'],
+        line_arg = 'block',
+        line_vals = [16, 32, 64],
+        line_names = ['Block16', 'Block32', 'Block64'],
         ylabel = 'GBPS',
         plot_name = f'{layout_mode}-square',
         args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
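
The y_name/y_vals/y_lines keywords are renamed to line_arg/line_vals/line_names here and in the other benchmark scripts below. A minimal sketch of how the renamed arguments fit together; the benchmarked op (torch.softmax), the bandwidth arithmetic, and the exact do_bench/run() signatures are illustrative assumptions rather than code from this commit:

import torch
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],                  # argument(s) swept along the x-axis
        x_vals=[256, 512, 1024, 2048],
        line_arg='provider',            # was y_name: argument that selects a line
        line_vals=['torch'],            # was y_vals: values that argument takes
        line_names=['Torch'],           # was y_lines: legend labels for the lines
        ylabel='GBPS',
        plot_name='softmax-performance',
        args={'M': 2048},               # fixed arguments
    )
)
def bench_softmax(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float16)
    ms, _, _ = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    # effective bandwidth: one read plus one write of the tensor, in GB/s
    return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

bench_softmax.run(print_data=True)
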
6 changes: 3 additions & 3 deletions python/bench/bench_cross_entropy.py
@@ -5,9 +5,9 @@
     triton.testing.Benchmark(
         x_names = ['N'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
-        y_name = 'provider',
-        y_vals = ['triton', 'torch'],
-        y_lines = ['Triton', 'Torch'],
+        line_arg = 'provider',
+        line_vals = ['triton', 'torch'],
+        line_names = ['Triton', 'Torch'],
         ylabel = 'GBPS',
         plot_name = f'{mode}-2048',
         args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
12 changes: 6 additions & 6 deletions python/bench/bench_matmul.py
@@ -16,9 +16,9 @@ def rounded_linspace(low, high, steps, div):
     triton.testing.Benchmark(
         x_names=["M", "N", "K"],
         x_vals=rounded_linspace(512, 8192, 32, 128),
-        y_name="provider",
-        y_vals=["cublas", "triton", "cutlass"],
-        y_lines=["cuBLAS", "Triton", "CUTLASS"],
+        line_arg="provider",
+        line_vals=["cublas", "triton", "cutlass"],
+        line_names=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
         args={"AT": AT, "BT": BT, "dtype": torch.float16},
@@ -30,9 +30,9 @@ def rounded_linspace(low, high, steps, div):
     triton.testing.Benchmark(
         x_names=[x],
         x_vals = rounded_linspace(NK//16, NK, 32, 128),
-        y_name="provider",
-        y_vals=["cublas", "triton", "cutlass"],
-        y_lines=["cuBLAS", "Triton", "CUTLASS"],
+        line_arg="provider",
+        line_vals=["cublas", "triton", "cutlass"],
+        line_names=["cuBLAS", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
         args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
25 changes: 13 additions & 12 deletions python/test/test_code_gen.py → python/test/test_language.py
@@ -1,5 +1,6 @@
 import torch
 import triton
+import triton.language as tl
 import copy
 import pytest
 import ast
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
     # define the kernel / launch-grid
     @triton.jit
     def kernel(Z, X, **meta):
-        off = triton.arange(0, meta['SIZE'])
-        x = triton.load(X + off)
+        off = tl.arange(0, meta['SIZE'])
+        x = tl.load(X + off)
         z = GENERATE_TEST_HERE
-        triton.store(Z + off, z)
+        tl.store(Z + off, z)
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
     # define the kernel / launch-grid
     @triton.jit
     def kernel(Z, X, Y, **meta):
-        off = triton.arange(0, meta['SIZE'])
-        x = triton.load(X + off)
-        y = triton.load(Y + off)
+        off = tl.arange(0, meta['SIZE'])
+        x = tl.load(X + off)
+        y = tl.load(Y + off)
         z = GENERATE_TEST_HERE
-        triton.store(Z + off, z)
+        tl.store(Z + off, z)
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
     stride = 1
     for i in reversed(range(rank)):
         idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
-        offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
+        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
         stride *= shape[i]
     return f"{name} + {' + '.join(offsets)}"
 
@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
     @triton.jit
     def kernel(Z, X, **meta):
         SIZE = meta['SIZE']
-        m = triton.arange(0, SIZE)
-        n = triton.arange(0, SIZE)
-        x = triton.load(X_PTR_EXPR)
+        m = tl.arange(0, SIZE)
+        n = tl.arange(0, SIZE)
+        x = tl.load(X_PTR_EXPR)
         z = GENERATE_TEST_HERE
-        triton.store(Z_PTR_EXPR, z)
+        tl.store(Z_PTR_EXPR, z)
 
     to_replace = {
         'X_PTR_EXPR': make_ptr_str('X', shape_x),
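
Net effect for user code: kernel builtins now live in the triton.language module, conventionally aliased as tl. A minimal copy-kernel sketch in the new style, assuming the launch convention used by these tests (explicit grid tuple, meta-parameters passed as keyword arguments):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(Z, X, **meta):
    # tl.* replaces the old triton.* kernel builtins
    off = tl.arange(0, meta['SIZE'])
    x = tl.load(X + off)
    tl.store(Z + off, x)

x = torch.randn(128, device='cuda')
z = torch.empty_like(x)
copy_kernel[(1, )](z, x, SIZE=128)  # tensors are passed as pointers via .data_ptr()
assert torch.allclose(z, x)
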
4 changes: 2 additions & 2 deletions python/triton/__init__.py
@@ -2,9 +2,9 @@
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import jit, autotune, heuristics, Config, Autotuner
-from .core import *
+from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
 
+from . import language
 from . import code_gen
 from . import testing
 from . import ops
47 changes: 39 additions & 8 deletions python/triton/code_gen.py
@@ -26,21 +26,21 @@ def get_value(self, name):
             ret = self.builtins[name]
         else:
             raise ValueError(f'{name} is not defined')
-        if isinstance(ret, triton.block):
+        if isinstance(ret, triton.language.block):
             handle = self.module.get_value(name)
-            return triton.block(handle)
+            return triton.language.block(handle)
         return ret
 
     def set_value(self, name, value):
         if isinstance(value, _triton.ir.value):
-            value = triton.block(value)
-        if isinstance(value, triton.block):
+            value = triton.language.block(value)
+        if isinstance(value, triton.language.block):
             self.module.set_value(name, value.handle)
             self.module.scope.set_type(name, value.handle.type)
         self.lscope[name] = value
 
     def is_triton_object(self, value):
-        return isinstance(value, triton.block)
+        return isinstance(value, triton.language.block)
 
     def visit_compound_statement(self, stmts, add_scope=False):
         if add_scope:
@@ -63,7 +63,14 @@ def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
         self.constants = constants
         self.kwargs = kwargs
         self.last_node = None
-        self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
+        self.builtins = {
+            'range': range,
+            'min': triton.language.minimum,
+            'float': float,
+            'int': int,
+            'print': print,
+            'getattr': getattr,
+        }
 
     def visit_Module(self, node):
         self.module.add_new_scope()
@@ -303,7 +310,7 @@ def visit_For(self, node):
         pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
         neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
         pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
-        build_cond = lambda: triton.where(self.visit(pos_step_node),\
+        build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
                                           self.visit(pos_cond_node),\
                                           self.visit(neg_cond_node),\
                                           builder=self.builder)
@@ -359,7 +366,7 @@ def visit_Call(self, node):
         if isinstance(fn, JITFunction):
             return fn(*args, generator=self, **kws)
         if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
-           sys.modules[fn.__module__] is triton.core:
+           sys.modules[fn.__module__] is triton.language:
             return fn(*args, builder=self.builder, **kws)
         return fn(*args, **kws)
 
@@ -613,6 +620,11 @@ def __call__(self, *args, generator: CodeGenerator, **meta):
                 raise e
             raise CompilationError(self.src, node, e)
 
+    def __setattr__(self, name, value):
+        if name == 'kernel_decorators':
+            self.kernel = None
+        super(JITFunction, self).__setattr__(name, value)
+
     def _init_kernel(self):
         if self.kernel is None:
             self.kernel = Kernel(self)
@@ -659,4 +671,23 @@ def fun(*args, **meta):
 
 
 def jit(fn):
+    """
+    Decorator for JIT-compiling a function using the Triton compiler.
+
+    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
+
+    :note: This function will be compiled and run on the GPU. It will only have access to:
+
+        * python primitives,
+        * objects within the triton.language package,
+        * arguments to this function,
+        * other jit'd functions
+
+    :param fn: the function to be jit-compiled
+    :type fn: Callable
+    """
     return JITFunction(fn)
+
+
+def cdiv(x, y):
+    return (x + y - 1) // y
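
cdiv, previously defined in triton/core.py (its removal appears at the bottom of the next file's diff), now lives in code_gen.py and is re-exported at the package top level. It is ordinary ceiling division, typically used to size launch grids; a small sketch (the BLOCK meta-parameter name is illustrative):

import triton

# ceiling division: number of BLOCK-sized tiles needed to cover N elements
assert triton.cdiv(1000, 256) == 4   # (1000 + 256 - 1) // 256 == 4

# typical use when computing a 1D launch grid for a jit'd kernel
N = 1000
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
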
22 changes: 9 additions & 13 deletions python/triton/core.py → python/triton/language.py
@@ -1,6 +1,6 @@
-import triton
 from triton._C.libtriton.triton import ir
 from triton._C.libtriton.triton import frontend
+import triton
 from functools import wraps
 
 
@@ -25,7 +25,7 @@ def _from_ir(x):
             if x.type.is_void():
                 return None
             return block(x)
-        return x
+        return tl
 
     def wrapper(*args, **kwargs):
         builder = args[-1]
@@ -547,7 +547,7 @@ def minimum(x, y):
     :param other: the second input block
     :type other: Block
     """
-    return triton.where(x < y, x, y)
+    return triton.language.where(x < y, x, y)
 
 
 @triton.jit
@@ -560,7 +560,7 @@ def maximum(x, y):
     :param other: the second input block
     :type other: Block
    """
-    return triton.where(x > y, x, y)
+    return triton.language.where(x > y, x, y)
 
 
 @triton.jit
@@ -571,7 +571,7 @@ def sigmoid(x):
     :param x: the input block
     :type x: Block
     """
-    return 1 / (1 + np.exp(-x))
+    return 1 / (1 + triton.language.exp(-x))
 
 
 @triton.jit
@@ -582,9 +582,9 @@ def softmax(x):
     :param x: the input block
     :type x: Block
     """
-    z = x - triton.max(x, 0)
-    num = triton.exp(z)
-    den = triton.sum(num, 0)
+    z = x - triton.language.max(x, 0)
+    num = triton.language.exp(z)
+    den = triton.language.sum(num, 0)
     return num / den
 
 
@@ -596,8 +596,4 @@ def ravel(x):
     :param x: the input block
     :type x: Block
     """
-    return triton.reshape(x, [x.type.numel])
-
-
-def cdiv(x, y):
-    return (x + y - 1) // y
+    return triton.language.reshape(x, [x.type.numel])
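
A side note on the softmax body above: subtracting the row max before exponentiating keeps exp() from overflowing without changing the result, since softmax is invariant to adding a constant to its input. A NumPy sketch of the same identity:

import numpy as np

def softmax_ref(x):
    # exp(x - max) never overflows, and the constant factor exp(-max)
    # cancels between numerator and denominator
    z = x - x.max()
    num = np.exp(z)
    return num / num.sum()

x = np.array([1000.0, 1001.0, 1002.0])  # naive np.exp(x) would overflow
print(softmax_ref(x))                   # [0.09003057 0.24472847 0.66524096]
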