Skip to content

Commit

Permalink
[PYTHON] Now triton.code_gen.Binary can print PTX and LLIR (triton-…
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Apr 23, 2021
1 parent 4c1b69b commit 74fe327
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ void init_triton_driver(py::module &&m) {
});

py::class_<drv::module>(m, "module");
//py::class_<drv::cu_module, drv::module>(m, "cu_module");

py::class_<drv::cu_module, drv::module>(m, "cu_module")
.def("ptx", &drv::cu_module::ptx)
.def("llir", &drv::cu_module::llir);

py::class_<drv::kernel>(m, "kernel");
}
Expand Down
10 changes: 9 additions & 1 deletion python/triton/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ def __init__(self, module, kernel, num_warps, shared_mem):
self.shared_mem = shared_mem
self.num_warps = num_warps

def asm(self, mode):
    """Return an intermediate representation of this compiled binary.

    :param mode: which form to fetch -- ``'ptx'`` for the generated PTX
        assembly, ``'llir'`` for the LLVM IR.
    :raises ValueError: if *mode* is not one of the supported forms.
    """
    supported = ('ptx', 'llir')
    if mode not in supported:
        raise ValueError('Unsupported mode ' + mode)
    # Both representations are exposed by the native module object under
    # the same name as the mode, so dispatch dynamically.
    return getattr(self.module, mode)()

def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
    """Enqueue this binary's kernel on *stream*.

    The launch grid is ``(grid_0, grid_1, grid_2)`` blocks; each block
    runs ``num_warps`` warps of 32 threads along the first dimension.
    """
    threads_per_block = self.num_warps * 32
    stream.enqueue(self.kernel, grid_0, grid_1, grid_2,
                   threads_per_block, 1, 1, args, self.shared_mem)

Expand Down Expand Up @@ -523,6 +530,7 @@ def __call__(self, *wargs, grid, num_warps=4, **meta):
stream = _triton.driver.cu_stream(cu_stream, False)
grid = grid(meta) if hasattr(grid, '__call__') else grid
binary(stream, params, *grid)
return binary


class Launcher:
Expand All @@ -531,7 +539,7 @@ def __init__(self, kernel, grid):
self.grid = grid

def __call__(self, *wargs, **kwargs):
    """Invoke the wrapped kernel with the grid bound at construction.

    Launches exactly once and propagates the kernel's return value so
    callers can inspect the launched binary. (The rendered diff showed
    the pre-change call and the new ``return`` line back to back;
    keeping both statements would launch the kernel twice.)
    """
    return self.kernel(*wargs, **kwargs, grid=self.grid)


class Autotuner:
Expand Down

0 comments on commit 74fe327

Please sign in to comment.