diff --git a/python/src/triton.cc b/python/src/triton.cc
index 52468a0bf..17d720f78 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -59,7 +59,10 @@ void init_triton_driver(py::module &&m) {
       });
 
   py::class_<drv::module>(m, "module");
-  //py::class_<drv::cu_module>(m, "cu_module");
+
+  py::class_<drv::cu_module, drv::module>(m, "cu_module")
+      .def("ptx", &drv::cu_module::ptx)
+      .def("llir", &drv::cu_module::llir);
 
   py::class_<drv::kernel>(m, "kernel");
 }
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index a4c42070f..96a1a9d2a 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -393,6 +393,13 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps
 
+    def asm(self, mode):
+        if mode == 'ptx':
+            return self.module.ptx()
+        if mode == 'llir':
+            return self.module.llir()
+        raise ValueError('Unsupported mode ' + mode)
+
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
         stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
 
@@ -523,6 +530,7 @@ class Kernel:
         stream = _triton.driver.cu_stream(cu_stream, False)
         grid = grid(meta) if hasattr(grid, '__call__') else grid
         binary(stream, params, *grid)
+        return binary
 
 
 class Launcher:
@@ -531,7 +539,7 @@ class Launcher:
         self.grid = grid
 
     def __call__(self, *wargs, **kwargs):
-        self.kernel(*wargs, **kwargs, grid=self.grid)
+        return self.kernel(*wargs, **kwargs, grid=self.grid)
 
 
 class Autotuner:
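
Net effect of the change: a kernel launch now returns the Binary object, whose new asm() method exposes the generated PTX or LLVM-IR. A minimal usage sketch, assuming a function already decorated with @triton.jit and launched through the kernel[grid](...) indexing syntax of this version; the kernel name add_kernel, its arguments, and the grid below are hypothetical:

    # Hypothetical launch of a @triton.jit kernel; with this change the
    # launch call returns the compiled Binary for inspection.
    binary = add_kernel[(num_blocks,)](x, y, output, N, BLOCK=1024)
    print(binary.asm('ptx'))    # generated NVIDIA PTX assembly
    print(binary.asm('llir'))   # generated LLVM IR
    # Any other mode raises ValueError('Unsupported mode ...').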