From 4a2d3b7d798801b0390f884e2f849ed4fb5c0d8f Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Sun, 2 Oct 2022 17:39:52 -0700 Subject: [PATCH] [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) --- python/triton/compiler.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 87875d1d2..43e69035f 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -22,6 +22,7 @@ from filelock import FileLock import triton import triton._C.libtriton.triton as _triton +from .tools.disasm import extract def str_to_ty(name): @@ -1209,14 +1210,20 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i ptx_name = f"{name}.ptx" cubin_name = f"{name}.cubin" data_name = f"{name}.json" + ttir_name = f"{name}.ttir" + llir_name = f"{name}.llir" if not fn_cache_manager.has_file(cubin_name) or \ not fn_cache_manager.has_file(data_name) or \ - not fn_cache_manager.has_file(ptx_name): + not fn_cache_manager.has_file(ptx_name) or \ + not fn_cache_manager.has_file(ttir_name) or \ + not fn_cache_manager.has_file(llir_name): asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin", cc) metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages} fn_cache_manager.put(asm["cubin"], cubin_name) fn_cache_manager.put(asm["ptx"], ptx_name, binary=False) + fn_cache_manager.put(asm["ttir"], ttir_name, binary=False) + fn_cache_manager.put(asm["llir"], llir_name, binary=False) fn_cache_manager.put(json.dumps(metadata), data_name, binary=False) if warm_cache_only: @@ -1246,10 +1253,16 @@ class CompiledKernel: self.asm["cubin"] = f.read() with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f: self.asm["ptx"] = f.read() + with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f: + self.asm["llir"] = f.read() + with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f: + self.asm["ttir"] = f.read() mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device) self.cu_module = mod self.cu_function = func + self.n_regs = n_regs + self.n_spills = n_spills def __getitem__(self, grid): def runner(*args, stream=None): @@ -1257,3 +1270,16 @@ class CompiledKernel: stream = torch.cuda.current_stream().cuda_stream self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args) return runner + + def get_sass(self, fun=None): + if 'sass' in self.asm: + return self.asm['sass'] + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(self.asm['cubin']) + self.sass = extract(path, fun) + finally: + os.remove(path) + self.asm['sass'] = self.sass + return self.sass