[PYTHON] Now triton.code_gen.Binary can print PTX and LLIR (#88)
This commit is contained in:
committed by
Philippe Tillet
parent
29e33e50b7
commit
d9112144b4
@@ -59,7 +59,10 @@ void init_triton_driver(py::module &&m) {
|
||||
});
|
||||
|
||||
py::class_<drv::module>(m, "module");
|
||||
//py::class_<drv::cu_module, drv::module>(m, "cu_module");
|
||||
|
||||
py::class_<drv::cu_module, drv::module>(m, "cu_module")
|
||||
.def("ptx", &drv::cu_module::ptx)
|
||||
.def("llir", &drv::cu_module::llir);
|
||||
|
||||
py::class_<drv::kernel>(m, "kernel");
|
||||
}
|
||||
|
@@ -393,6 +393,13 @@ class Binary:
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
|
||||
def asm(self, mode):
    """Return a textual dump of the compiled module in the requested form.

    :param mode: ``'ptx'`` to get the result of ``self.module.ptx()``, or
        ``'llir'`` to get the result of ``self.module.llir()``.
    :return: whatever the selected ``self.module`` accessor returns
        (presumably the assembly / IR text -- confirm against the driver
        bindings).
    :raises ValueError: if ``mode`` is anything other than ``'ptx'`` or
        ``'llir'``.
    """
    if mode == 'ptx':
        return self.module.ptx()
    if mode == 'llir':
        return self.module.llir()
    # str() keeps the message identical for string modes while making sure a
    # non-string mode (e.g. None or an int) is reported as the documented
    # ValueError rather than a TypeError from str + non-str concatenation.
    raise ValueError('Unsupported mode ' + str(mode))
|
||||
|
||||
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
    """Enqueue this binary's kernel on ``stream``.

    :param stream: object exposing ``enqueue``; receives the launch request.
    :param args: kernel arguments, forwarded to ``stream.enqueue`` unchanged.
    :param grid_0: first grid extent.
    :param grid_1: second grid extent (defaults to 1).
    :param grid_2: third grid extent (defaults to 1).
    :return: None; the result of ``stream.enqueue`` is discarded.
    """
    # NOTE(review): 32 is presumably the number of threads per warp, making
    # this the first block extent -- confirm against the driver's enqueue API.
    block_0 = self.num_warps * 32
    stream.enqueue(
        self.kernel,
        grid_0, grid_1, grid_2,
        block_0, 1, 1,
        args,
        self.shared_mem,
    )
|
||||
|
||||
@@ -523,6 +530,7 @@ class Kernel:
|
||||
stream = _triton.driver.cu_stream(cu_stream, False)
|
||||
grid = grid(meta) if hasattr(grid, '__call__') else grid
|
||||
binary(stream, params, *grid)
|
||||
return binary
|
||||
|
||||
|
||||
class Launcher:
|
||||
@@ -531,7 +539,7 @@ class Launcher:
|
||||
self.grid = grid
|
||||
|
||||
def __call__(self, *wargs, **kwargs):
    """Invoke the wrapped kernel on the stored grid and return its result.

    :param wargs: positional arguments forwarded to ``self.kernel``.
    :param kwargs: keyword arguments forwarded to ``self.kernel``; ``grid``
        is supplied from ``self.grid`` and must not be passed by the caller.
    :return: whatever ``self.kernel(...)`` returns (presumably the compiled
        binary that was launched -- confirm against the kernel's ``__call__``).
    """
    # Call the kernel exactly once: a bare call followed by a returning call
    # would launch it twice for every invocation of the launcher.
    return self.kernel(*wargs, **kwargs, grid=self.grid)
|
||||
|
||||
|
||||
class Autotuner:
|
||||
|
Reference in New Issue
Block a user