Commit

add tvm example, formatting (tinygrad#1813)
* add tvm example

* no realize
geohot committed Sep 7, 2023
1 parent 5b15a97 commit 4613c9e
Showing 7 changed files with 54 additions and 8 deletions.
46 changes: 46 additions & 0 deletions extra/gemm/tvm_gemm.py
@@ -0,0 +1,46 @@
# https://tvm.apache.org/docs/tutorial/tensor_expr_get_started.html#example-2-manually-optimizing-matrix-multiplication-with-te
import tvm
from tvm import te
#print(tvm.target.Target.list_kinds())

M, N, K = 1024, 1024, 1024

# c, opencl
target = tvm.target.Target(target="c")

# TVM Matrix Multiplication using TE
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")

# Default schedule
s = te.create_schedule(C.op)
#print(tvm.lower(s, [A, B, C], simple_mode=True))

# Output C code
func = tvm.build(s, [A, B, C], target=target, name="mmult")
print(func.get_source())

# tinygrad version

import os
from tinygrad.tensor import Tensor

# disable optimizations
os.environ["NOOPT"] = "1"

# define the compute
A = Tensor.rand(M, K, device="clang")
B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)

# capture the kernel. TODO: https://github.com/tinygrad/tinygrad/issues/1812
from tinygrad.jit import CacheCollector
CacheCollector.start()
C.realize()
result = CacheCollector.finish()

print(result[0][0].prg)
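
For reference, a sketch (not part of this commit) of how the TVM half can be executed and checked against numpy, following the linked TVM tutorial. It builds for "llvm" instead of "c", since the "c" target only emits source, and assumes a TVM version with tvm.device and NDArray.numpy() (≥ 0.8):

import numpy as np
import tvm
from tvm import te

M, N, K = 1024, 1024, 1024
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
s = te.create_schedule(C.op)

# build for llvm so the resulting module is directly callable
func = tvm.build(s, [A, B, C], target="llvm", name="mmult")

dev = tvm.device("llvm", 0)
a = tvm.nd.array(np.random.rand(M, K).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(K, N).astype("float32"), dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
func(a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-4)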


4 changes: 2 additions & 2 deletions tinygrad/codegen/kernel.py
@@ -23,10 +23,10 @@ class LinearizerOptions(NamedTuple):
local_max: Optional[List[int]] = None

class Kernel:
-def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
+def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
-self.opts = opts
+self.opts = opts if opts else LinearizerOptions()

# get the output buffers
self.bufs = [output_buffer] + dedup(ast.buffers)
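The hunk above makes opts optional, with the default constructed inside __init__. A minimal standalone sketch of the pattern (simplified names, not tinygrad's real classes): the NamedTuple supplies the field defaults, and building it in the body keeps the signature's default a simple None while callers can still pass an explicit Options.

from typing import NamedTuple, Optional

class Options(NamedTuple):
    supports_float4: bool = True

class Kernel:
    def __init__(self, opts: Optional[Options] = None):
        # fall back to a default-constructed Options when the caller passes none
        self.opts = opts if opts else Options()

Kernel()                                  # uses Options()
Kernel(Options(supports_float4=False))    # explicit override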
4 changes: 2 additions & 2 deletions tinygrad/renderer/cstyle.py
@@ -79,7 +79,7 @@ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix} void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
@@ -101,7 +101,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kernel,prekernel = [],[]
#pend_close = None
bufs = []
-depth = 0
+depth = 1
def kk(s): kernel.append(" "*depth+s)

c: DefaultDict[str, int] = defaultdict(int)
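Both hunks above are formatting fixes to the generated source: kernel_prefix now owns its own trailing space, so a backend with an empty prefix (the clang path) no longer emits a stray leading space before "void", and depth = 1 starts kernel body lines one indentation level inside the function braces. A minimal sketch of the convention (simplified, not the real renderer):

def render(kernel_prefix: str, name: str, body: list) -> str:
    depth = 1                              # body starts one level deep
    lines = ["  " * depth + s for s in body]
    return kernel_prefix + "void " + name + "(float* data0) {\n" + "\n".join(lines) + "\n}"

print(render("", "mmult", ["data0[0] = 0.0f;"]))          # clang: "void mmult(" with no leading space
print(render("__kernel ", "mmult", ["data0[0] = 0.0f;"])) # OpenCL: "__kernel void mmult("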
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_clang.py
@@ -12,7 +12,7 @@
if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore

args = {
-'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
+'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport) '},
'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
}[platform.system()]
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_cuda.py
@@ -86,7 +86,7 @@ def __call__(self, global_size, local_size, *args, wait=False):
return start.time_till(end)*1e-3

renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__global__", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
half_prekernel = """
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_gpu.py
@@ -104,7 +104,7 @@ def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float
return None

renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
kernel_prefix = "__kernel ", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_metal.py
@@ -81,7 +81,7 @@ def __call__(self, global_size, local_size, *bufs, wait=False):
METAL.mtl_buffers_in_flight.append(command_buffer)

renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel ", buffer_prefix = "device ", smem_prefix = "threadgroup ", arg_int_prefix = "constant int&",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
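The four runtime hunks (clang, cuda, gpu, metal) are the other side of that renderer change: each backend's kernel_prefix, and clang's Windows export macro, now ends in the space it needs. For example, the Metal prefix above renders as follows (a sketch of the expected concatenation, not captured output):

kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel "
print(kernel_prefix + "void mmult(device float* data0) {")
# #include <metal_stdlib>
# using namespace metal;
# kernel void mmult(device float* data0) {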
