Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOCS] Updates and improvements #87

Merged
merged 6 commits into from
Apr 22, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[DOCS] more improvements
  • Loading branch information
ptillet committed Apr 21, 2021
commit dd0b9a8df0f117adf46e61f3d93b850647f52047
156 changes: 126 additions & 30 deletions python/triton/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def wrapper(*args, **kwargs):

if wrapper.__doc__:
wrapper.__doc__ += """\
:param builder: IR builder to generate code into, optional from within @triton.jit functions
:type builder: triton.ir.builder
:param builder: IR builder to generate code into
:type builder: triton.ir.builder, optional from within JIT'ed functions
"""
return wrapper

Expand Down Expand Up @@ -237,7 +237,7 @@ def to(self, dtype, builder=None):
def program_id(axis, builder=None):
"""
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.
Triton uses an SPMD model in which different JIT'ed functions run in parallel with different `program_id`s.

:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
Expand All @@ -249,6 +249,7 @@ def program_id(axis, builder=None):
def num_programs(axis, builder=None):
"""
Returns the number of program instances launched along the given `axis`.
Triton uses an SPMD model in which a 3D grid of JIT'ed functions run in parallel.

:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
Expand All @@ -266,9 +267,9 @@ def arange(start, end, builder=None):
"""
Returns contiguous values within the half-open interval [start, end).

:param start: Start of the interval.
:param start: Start of the interval. Must be a power of two.
:type start: int
:param stop: End of the interval.
:param end: End of the interval. Must be a power of two >= start.
:type end: int
"""
return frontend.arange(start, end, builder)
Expand All @@ -282,7 +283,7 @@ def zeros(shape, dtype, builder=None):
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., triton.float16
:type dtype: triton.ir.dtype
:type dtype: DType
"""
return frontend.zeros(shape, dtype, builder)

Expand All @@ -295,23 +296,23 @@ def zeros(shape, dtype, builder=None):
@builtin
def broadcast(input, other, builder=None):
"""
Tries to broadcast two blocks to a common compatible shape.
Tries to broadcast the two given blocks to a common compatible shape.

:param input: The first input block.
:type input: triton.ir.value
:type input: Block
:param other: The second input block.
:type other: triton.ir.value
:type other: Block
"""
return frontend.broadcast(input, other, builder)


@builtin
def broadcast_to(input, shape, builder=None):
"""
Tries to broadcast a block to a new shape.
Tries to broadcast the given block to a new shape.

:param input: The input block.
:type input: triton.value
:type input: Block
:param shape: The new shape.
:type shape: tuple of int
"""
Expand All @@ -321,7 +322,10 @@ def broadcast_to(input, shape, builder=None):
@builtin
def reshape(input, shape, builder=None):
"""
Reshapes a block to a new shape.
Tries to reshape the given block to a new shape.

:param input: The input block.
:type input: Block
:param shape: The new shape.
:type shape: tuple of ints
"""
return frontend.reshape(input, shape, builder)

Expand Down Expand Up @@ -354,13 +358,15 @@ def dot(input, other, builder=None):
def load(pointer, mask=None, other=None, builder=None):
"""
Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`.

:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]`
:type other: Block of triton.value, optional
`mask` and `other` are implicitly broadcast to `pointer.shape`.
`other` is implicitly typecast to `pointer.dtype.element_ty`.

:param pointer: Pointers to the data to be loaded.
:type pointer: Block of dtype=triton.PointerDType
:param mask: if mask[idx] is false, do not load the data at address `pointer[idx]`.
:type mask: Block of triton.int1, optional
:param other: if mask[idx] is false, return other[idx]
:type other: Block, optional
"""
return frontend.load(pointer, mask, other, builder)

Expand All @@ -369,24 +375,44 @@ def load(pointer, mask=None, other=None, builder=None):
def store(pointer, value, mask=None, builder=None):
"""
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.

:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:type pointer: Block of dtype=triton.PointerDType
:param value: The block of elements to be stored.
:type value: Block of triton.value
:type value: Block
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:type mask: Block of triton.int1, optional
"""
return frontend.store(pointer, value, mask, builder)


@builtin
def atomic_cas(pointer, cmp, val, builder=None):
    """
    Performs an atomic "compare-and-swap" at the memory locations specified by `pointer`.

    :param pointer: The memory locations to compare-and-swap.
    :type pointer: Block of dtype=triton.PointerDType
    :param cmp: The values expected to be found in the atomic object
    :type cmp: Block of dtype=`pointer.dtype.element_ty`
    :param val: The values to copy in case the expected value matches the contained value.
    :type val: Block of dtype=`pointer.dtype.element_ty`
    """
    # The argument must be forwarded under its declared name `pointer`;
    # referencing any other name here raises NameError at call time.
    return frontend.atomic_cas(pointer, cmp, val, builder)


@builtin
def atomic_xchg(pointer, val, builder=None):
    """
    Swaps the *old* values stored at location `pointer` with the new values given by `val`. Returns the old values.

    :param pointer: The memory locations which contain the old values
    :type pointer: Block of dtype=triton.PointerDType
    :param val: The new values to store
    :type val: Block of dtype=`pointer.dtype.element_ty`
    """
    # The argument must be forwarded under its declared name `pointer`;
    # referencing any other name here raises NameError at call time.
    return frontend.atomic_xchg(pointer, val, builder)


Expand Down Expand Up @@ -419,11 +445,25 @@ def where(condition, x, y, builder=None):

@builtin
def exp(x, builder=None):
    """
    Computes the exponential of every element of the block :code:`x`.

    :param x: the input values
    :type x: Block
    """
    result = frontend.exp(x, builder)
    return result


@builtin
def log(x, builder=None):
    """
    Computes the natural logarithm of every element of the block :code:`x`.

    :param x: the input values
    :type x: Block
    """
    result = frontend.log(x, builder)
    return result


Expand All @@ -434,16 +474,35 @@ def log(x, builder=None):

@builtin
def max(input, axis, builder=None):
    """
    Reduces the :code:`input` block along :code:`axis`, keeping the maximum value.

    :param input: the input values
    :param axis: the dimension along which the reduction should be done
    """
    reduced = frontend.max(input, axis, builder)
    return reduced


@builtin
def min(input, axis, builder=None):
    """
    Reduces the :code:`input` block along :code:`axis`, keeping the minimum value.

    :param input: the input values
    :param axis: the dimension along which the reduction should be done
    """
    reduced = frontend.min(input, axis, builder)
    return reduced


@builtin
def sum(input, axis, builder=None):
    """
    Reduces the :code:`input` block along :code:`axis` by summing its elements.

    :param input: the input values
    :param axis: the dimension along which the reduction should be done
    """
    reduced = frontend.sum(input, axis, builder)
    return reduced


Expand All @@ -458,7 +517,10 @@ def debug_barrier(builder=None):


@builtin
def multiple_of(input, value, builder=None):
    """
    Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.

    :param input: the values the hint applies to
    :param value: the factor that every element of :code:`input` is declared to be a multiple of
    """
    # The argument must be forwarded under its declared name `input`;
    # referencing any other name here raises NameError at call time.
    return frontend.multiple_of(input, value, builder)


Expand All @@ -469,31 +531,65 @@ def multiple_of(x, value, builder=None):

@triton.jit
def minimum(x, y):
    """
    Computes the element-wise minimum of :code:`x` and :code:`y`.

    :param x: the first input block
    :type x: Block
    :param y: the second input block
    :type y: Block
    """
    return triton.where(x < y, x, y)


@triton.jit
def maximum(x, y):
    """
    Computes the element-wise maximum of :code:`x` and :code:`y`.

    :param x: the first input block
    :type x: Block
    :param y: the second input block
    :type y: Block
    """
    return triton.where(x > y, x, y)


@triton.jit
def sigmoid(x):
return 1 / (1 + np.exp(-x))

"""
Computes the element-wise sigmoid of :code:`x`.

@triton.jit
def ravel(x):
return triton.reshape(x, [x.type.numel])
:param x: the input block
:type x: Block
"""
return 1 / (1 + np.exp(-x))


@triton.jit
def softmax(x):
    """
    Computes the softmax of :code:`x` along axis 0.

    :param x: the input block
    :type x: Block
    """
    # Shift by the maximum along axis 0 before exponentiating; the shift cancels
    # in the ratio below (presumably done to keep `exp` from overflowing).
    z = x - triton.max(x, 0)
    num = triton.exp(z)
    # Normalise by the sum of the exponentials along the same axis.
    den = triton.sum(num, 0)
    return num / den


@triton.jit
def ravel(x):
    """
    Returns a contiguous flattened view of :code:`x`.

    :param x: the input block
    :type x: Block
    """
    # Reshape to a single dimension holding `x.type.numel` elements.
    return triton.reshape(x, [x.type.numel])


def cdiv(x, y):
    """
    Return ``(x + y - 1) // y``.

    For a non-negative integer ``x`` and a positive integer ``y`` this is
    ``x / y`` rounded up to the next integer.
    """
    padded = x + y - 1
    return padded // y