[PYTHON] Now triton.code_gen.Binary can print PTX and LLIR (#88)

Author: Philippe Tillet
Date: 2021-04-22 21:50:19 -04:00
Committed by: Philippe Tillet
Parent: 29e33e50b7
Commit: d9112144b4
2 changed files with 13 additions and 2 deletions


@@ -59,7 +59,10 @@ void init_triton_driver(py::module &&m) {
   });
   py::class_<drv::module>(m, "module");
-  //py::class_<drv::cu_module, drv::module>(m, "cu_module");
+  py::class_<drv::cu_module, drv::module>(m, "cu_module")
+      .def("ptx", &drv::cu_module::ptx)
+      .def("llir", &drv::cu_module::llir);
   py::class_<drv::kernel>(m, "kernel");
 }
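These bindings make the PTX and LLVM-IR of a compiled CUDA module readable from Python. A minimal sketch of the resulting driver-level API, assuming the extension is imported as _triton (as in code_gen.py below) and that `mod` is a cu_module handle obtained during compilation:

    # mod: a _triton.driver.cu_module handle (assumed obtained elsewhere)
    ptx_text = mod.ptx()    # PTX assembly of the module, as a string
    llir_text = mod.llir()  # LLVM-IR of the module, as a string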


@@ -393,6 +393,13 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
+
+    def asm(self, mode):
+        if mode == 'ptx':
+            return self.module.ptx()
+        if mode == 'llir':
+            return self.module.llir()
+        raise ValueError('Unsupported mode ' + mode)
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
         stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
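Binary.asm is the user-facing wrapper over those bindings. A short usage sketch, assuming `binary` is a Binary instance (the Kernel change below shows how a launch hands one back):

    ptx = binary.asm('ptx')    # forwards to cu_module.ptx()
    llir = binary.asm('llir')  # forwards to cu_module.llir()
    # binary.asm('sass')       # any other mode raises ValueError('Unsupported mode sass')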
@@ -523,6 +530,7 @@ class Kernel:
         stream = _triton.driver.cu_stream(cu_stream, False)
         grid = grid(meta) if hasattr(grid, '__call__') else grid
         binary(stream, params, *grid)
+        return binary
 class Launcher:
@@ -531,7 +539,7 @@ class Launcher:
         self.grid = grid
     def __call__(self, *wargs, **kwargs):
-        self.kernel(*wargs, **kwargs, grid=self.grid)
+        return self.kernel(*wargs, **kwargs, grid=self.grid)
 class Autotuner:
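With Kernel.__call__ now returning the binary and Launcher.__call__ forwarding that return value, a launch hands the compiled Binary back to the caller. An end-to-end sketch; `add_kernel`, `grid`, `x` and `y` are hypothetical stand-ins for a JIT'd Triton kernel and its launch parameters, and only the return-value plumbing comes from this commit:

    # add_kernel: a hypothetical Triton kernel launched through a Launcher
    # bound to `grid`; x, y are its (hypothetical) arguments.
    binary = add_kernel[grid](x, y)  # the launch now returns the Binary
    print(binary.asm('ptx'))         # inspect the generated PTX
    print(binary.asm('llir'))        # or the LLVM-IR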