[RUNTIME] Dump llvm, ttir, and sass to help debugging (#732)
This commit is contained in:
@@ -22,6 +22,7 @@ from filelock import FileLock
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
@@ -1209,14 +1210,20 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
||||
ptx_name = f"{name}.ptx"
|
||||
cubin_name = f"{name}.cubin"
|
||||
data_name = f"{name}.json"
|
||||
ttir_name = f"{name}.ttir"
|
||||
llir_name = f"{name}.llir"
|
||||
if not fn_cache_manager.has_file(cubin_name) or \
|
||||
not fn_cache_manager.has_file(data_name) or \
|
||||
not fn_cache_manager.has_file(ptx_name):
|
||||
not fn_cache_manager.has_file(ptx_name) or \
|
||||
not fn_cache_manager.has_file(ttir_name) or \
|
||||
not fn_cache_manager.has_file(llir_name):
|
||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||
extern_libs, "cubin", cc)
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
fn_cache_manager.put(asm["cubin"], cubin_name)
|
||||
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
||||
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
|
||||
fn_cache_manager.put(asm["llir"], llir_name, binary=False)
|
||||
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||
|
||||
if warm_cache_only:
|
||||
@@ -1246,10 +1253,16 @@ class CompiledKernel:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
|
||||
self.asm["llir"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
|
||||
self.asm["ttir"] = f.read()
|
||||
|
||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
self.n_regs = n_regs
|
||||
self.n_spills = n_spills
|
||||
|
||||
def __getitem__(self, grid):
|
||||
def runner(*args, stream=None):
|
||||
@@ -1257,3 +1270,16 @@ class CompiledKernel:
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||
return runner
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if 'sass' in self.asm:
|
||||
return self.asm['sass']
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
Reference in New Issue
Block a user