[RUNTIME] Dump llvm, ttir, and sass to help debugging (#732)
This commit is contained in:
@@ -22,6 +22,7 @@ from filelock import FileLock
|
|||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
|
from .tools.disasm import extract
|
||||||
|
|
||||||
|
|
||||||
def str_to_ty(name):
|
def str_to_ty(name):
|
||||||
@@ -1209,14 +1210,20 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
|||||||
ptx_name = f"{name}.ptx"
|
ptx_name = f"{name}.ptx"
|
||||||
cubin_name = f"{name}.cubin"
|
cubin_name = f"{name}.cubin"
|
||||||
data_name = f"{name}.json"
|
data_name = f"{name}.json"
|
||||||
|
ttir_name = f"{name}.ttir"
|
||||||
|
llir_name = f"{name}.llir"
|
||||||
if not fn_cache_manager.has_file(cubin_name) or \
|
if not fn_cache_manager.has_file(cubin_name) or \
|
||||||
not fn_cache_manager.has_file(data_name) or \
|
not fn_cache_manager.has_file(data_name) or \
|
||||||
not fn_cache_manager.has_file(ptx_name):
|
not fn_cache_manager.has_file(ptx_name) or \
|
||||||
|
not fn_cache_manager.has_file(ttir_name) or \
|
||||||
|
not fn_cache_manager.has_file(llir_name):
|
||||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||||
extern_libs, "cubin", cc)
|
extern_libs, "cubin", cc)
|
||||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||||
fn_cache_manager.put(asm["cubin"], cubin_name)
|
fn_cache_manager.put(asm["cubin"], cubin_name)
|
||||||
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
||||||
|
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
|
||||||
|
fn_cache_manager.put(asm["llir"], llir_name, binary=False)
|
||||||
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||||
|
|
||||||
if warm_cache_only:
|
if warm_cache_only:
|
||||||
@@ -1246,10 +1253,16 @@ class CompiledKernel:
|
|||||||
self.asm["cubin"] = f.read()
|
self.asm["cubin"] = f.read()
|
||||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||||
self.asm["ptx"] = f.read()
|
self.asm["ptx"] = f.read()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.llir"), "r") as f:
|
||||||
|
self.asm["llir"] = f.read()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
|
||||||
|
self.asm["ttir"] = f.read()
|
||||||
|
|
||||||
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||||
self.cu_module = mod
|
self.cu_module = mod
|
||||||
self.cu_function = func
|
self.cu_function = func
|
||||||
|
self.n_regs = n_regs
|
||||||
|
self.n_spills = n_spills
|
||||||
|
|
||||||
def __getitem__(self, grid):
|
def __getitem__(self, grid):
|
||||||
def runner(*args, stream=None):
|
def runner(*args, stream=None):
|
||||||
@@ -1257,3 +1270,16 @@ class CompiledKernel:
|
|||||||
stream = torch.cuda.current_stream().cuda_stream
|
stream = torch.cuda.current_stream().cuda_stream
|
||||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
|
def get_sass(self, fun=None):
    """Return the SASS text for this kernel's cubin, disassembling on demand.

    The cubin held in ``self.asm['cubin']`` is written to a temporary file,
    handed to ``extract(path, fun)``, and the temporary file is removed
    again whether or not extraction succeeds.

    Results are cached per value of ``fun`` so that a cached answer for one
    request is never returned for a different one:

    * ``fun is None`` is cached in ``self.asm['sass']``, alongside the other
      stages stored in ``self.asm``.
    * any other ``fun`` is cached in a private per-instance dict and does
      not overwrite ``self.asm['sass']``.

    :param fun: forwarded unchanged as the second argument of ``extract``;
        presumably a function name used to restrict the dump -- confirm
        against ``tools.disasm.extract``.
    :return: whatever ``extract`` returned for this cubin and ``fun``.
    """
    # The shared 'sass' slot is only valid for the default request.
    if fun is None and 'sass' in self.asm:
        return self.asm['sass']

    # Lazily create the per-``fun`` cache; the constructor is not visible
    # from here, so do not assume it initialised the attribute.
    per_fun = getattr(self, '_sass_by_fun', None)
    if per_fun is None:
        per_fun = self._sass_by_fun = {}
    if fun is not None and fun in per_fun:
        return per_fun[fun]

    fd, path = tempfile.mkstemp()
    try:
        # open() takes ownership of the descriptor and closes it on exit,
        # so the data is flushed before extract() reads the file.
        with open(fd, 'wb') as cubin:
            cubin.write(self.asm['cubin'])
        sass = extract(path, fun)
    finally:
        os.remove(path)

    if fun is None:
        self.asm['sass'] = sass
    else:
        per_fun[fun] = sass
    # Kept for backward compatibility: the attribute mirrors the most
    # recently extracted text.
    self.sass = sass
    return sass
|
||||||
|
Reference in New Issue
Block a user