[RUNTIME] Dump llvm, ttir, and sass to help debugging (#732)

Author:    Keren Zhou
Date:      2022-10-02 17:39:52 -07:00
Committer: GitHub
Parent:    f55960e773
Commit:    4a2d3b7d79

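This change dumps the Triton IR (.ttir) and LLVM IR (.llir) of every compiled kernel into the compilation cache alongside the existing .ptx, .cubin, and .json artifacts, loads them back into CompiledKernel.asm, records the register and spill counts reported by load_binary as n_regs and n_spills, and adds a get_sass() helper that disassembles the cached cubin on demand. Hedged usage sketches follow each hunk.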

@@ -22,6 +22,7 @@ from filelock import FileLock
 
 import triton
 import triton._C.libtriton.triton as _triton
+from .tools.disasm import extract
 
 
 def str_to_ty(name):
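The newly imported extract helper can also be called on its own; a minimal sketch, assuming only the extract(path, fun) signature visible in get_sass below (the cubin path is a placeholder, and fun=None presumably means disassembling every function):

    from triton.tools.disasm import extract

    # Placeholder path; point this at a real cubin, e.g. one written to the
    # compilation cache by the compile() change below.
    sass = extract("/path/to/kernel.cubin", None)
    print(sass)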
@@ -1209,14 +1210,20 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     ptx_name = f"{name}.ptx"
     cubin_name = f"{name}.cubin"
     data_name = f"{name}.json"
+    ttir_name = f"{name}.ttir"
+    llir_name = f"{name}.llir"
     if not fn_cache_manager.has_file(cubin_name) or \
        not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name):
+       not fn_cache_manager.has_file(ptx_name) or \
+       not fn_cache_manager.has_file(ttir_name) or \
+       not fn_cache_manager.has_file(llir_name):
         asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
                                             extern_libs, "cubin", cc)
         metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
         fn_cache_manager.put(asm["cubin"], cubin_name)
         fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
+        fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
+        fn_cache_manager.put(asm["llir"], llir_name, binary=False)
         fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
 
     if warm_cache_only:
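After this hunk, the intermediate representations of every compiled kernel persist on disk next to the cubin. A sketch for locating them, assuming Triton's default cache directory of ~/.triton/cache and the TRITON_CACHE_DIR override (both assumptions about this branch's cache manager):

    import os

    # Assumed default location; adjust if your cache manager is configured differently.
    cache_dir = os.environ.get("TRITON_CACHE_DIR",
                               os.path.expanduser("~/.triton/cache"))
    for root, _, files in os.walk(cache_dir):
        for name in files:
            if name.endswith((".ttir", ".llir")):
                print(os.path.join(root, name))  # one IR dump per cached kernel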
@@ -1246,10 +1253,16 @@ class CompiledKernel:
         with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
             self.asm["cubin"] = f.read()
         with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
             self.asm["ptx"] = f.read()
+        with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
+            self.asm["llir"] = f.read()
+        with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
+            self.asm["ttir"] = f.read()
         mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func
+        self.n_regs = n_regs
+        self.n_spills = n_spills
 
     def __getitem__(self, grid):
         def runner(*args, stream=None):
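A sketch of reading the new fields, assuming compile() is exported as triton.compile on this branch; the kernel body and signature string are illustrative only:

    import triton
    import triton.language as tl

    @triton.jit
    def _store_zero(X):
        tl.store(X, 0.0)

    kernel_bin = triton.compile(_store_zero, signature="*fp32", device=0)
    print(kernel_bin.n_regs, kernel_bin.n_spills)  # occupancy counters added by this commit
    print(sorted(kernel_bin.asm.keys()))           # now includes "llir" and "ttir"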
@@ -1257,3 +1270,16 @@ class CompiledKernel:
                 stream = torch.cuda.current_stream().cuda_stream
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
         return runner
+
+    def get_sass(self, fun=None):
+        if 'sass' in self.asm:
+            return self.asm['sass']
+        fd, path = tempfile.mkstemp()
+        try:
+            with open(fd, 'wb') as cubin:
+                cubin.write(self.asm['cubin'])
+            self.sass = extract(path, fun)
+        finally:
+            os.remove(path)
+        self.asm['sass'] = self.sass
+        return self.sass
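A usage sketch for get_sass(), assuming the kernel_bin instance from the previous sketch and that extract can find the CUDA binary utilities it relies on (an assumption about triton.tools.disasm):

    sass = kernel_bin.get_sass()           # round-trips asm['cubin'] through a temp file
    print(sass.splitlines()[0])
    assert kernel_bin.get_sass() is sass   # memoized in asm['sass']; no second disassembly

Passing fun (forwarded to extract) presumably restricts the disassembly to a single function in the cubin.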