[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)
Also deprecates a couple of tests
This commit is contained in:
@@ -16,8 +16,9 @@ import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from sysconfig import get_paths
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from pathlib import Path
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, specialization, constants):
|
||||
def build_triton_ir(fn, signature, specialization, constants):
|
||||
# canonicalize signature
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
|
||||
ret.context = context
|
||||
return ret, generator
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
def ast_to_ttir(fn, signature, specialization, constants):
|
||||
mod, _ = build_triton_ir(fn, signature, specialization, constants)
|
||||
return optimize_triton_ir(mod)
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
def ttir_to_ttgir(mod, num_warps, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
# Get error in backend due to wrong conversion in expanding async-related instruction.
|
||||
# TODO[Superjomn]: Open it when fixed.
|
||||
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def make_llvm_ir(mod):
|
||||
def ttgir_to_llir(mod, extern_libs):
|
||||
if extern_libs:
|
||||
add_external_libs(mod, extern_libs)
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod)
|
||||
|
||||
|
||||
def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
|
||||
def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
|
||||
- PTX code
|
||||
- shared memory alloaction size
|
||||
'''
|
||||
if compute_capability is None:
|
||||
device = torch.cuda.current_device()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
|
||||
|
||||
|
||||
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
|
||||
|
||||
def ptx_to_cubin(ptx: str, device: int):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
:param device: CUDA device
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
|
||||
|
||||
|
||||
@@ -978,46 +991,6 @@ def path_to_ptxas():
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
|
||||
|
||||
|
||||
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
|
||||
# triton-ir
|
||||
module, _ = make_triton_ir(fn, signature, specialization, constants)
|
||||
module = optimize_triton_ir(module)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
|
||||
if extern_libs:
|
||||
add_external_libs(module, extern_libs)
|
||||
|
||||
# llvm-ir
|
||||
llvm_ir = make_llvm_ir(module)
|
||||
|
||||
assert device >= 0, "device should be provided."
|
||||
ptxas, cuda_version = path_to_ptxas()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
|
||||
shem_size = _triton.get_shared_memory_size(module)
|
||||
kernel_name = ptx_get_kernel_name(ptx)
|
||||
if output == "ptx":
|
||||
return ptx, shem_size, kernel_name
|
||||
|
||||
cubin = make_cubin(ptx, ptxas, compute_capability)
|
||||
if output == "cubin":
|
||||
return cubin, ptx, shem_size, kernel_name
|
||||
|
||||
assert False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
|
||||
return key
|
||||
|
||||
|
||||
def read_or_execute(cache_manager, force_compile, file_name, metadata,
|
||||
run_if_found: Callable[[str], bytes] = None,
|
||||
run_if_not_found: Callable = None):
|
||||
if not force_compile and cache_manager.has_file(file_name):
|
||||
module = run_if_found(cache_manager._make_path(file_name))
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
suffix = file_name.split(".")[1]
|
||||
has_changed = metadata and md5 != metadata["md5"][suffix]
|
||||
return module, md5, has_changed, True
|
||||
module = run_if_not_found()
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
cache_manager.put(data, file_name, True)
|
||||
return module, md5, True, False
|
||||
|
||||
|
||||
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
||||
so = _build(fn.__name__, src_path, tmpdir)
|
||||
with open(so, "rb") as f:
|
||||
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
|
||||
# retrieve cached shared object if it exists
|
||||
so_path = so_cache_manager._make_path(so_name)
|
||||
# create cache manager
|
||||
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||
fn_cache_manager = CacheManager(fn_cache_key)
|
||||
ptx_name = f"{name}.ptx"
|
||||
cubin_name = f"{name}.cubin"
|
||||
data_name = f"{name}.json"
|
||||
if not fn_cache_manager.has_file(cubin_name) or \
|
||||
not fn_cache_manager.has_file(data_name) or \
|
||||
not fn_cache_manager.has_file(ptx_name):
|
||||
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
fn_cache_manager.put(cubin, cubin_name)
|
||||
fn_cache_manager.put(ptx, ptx_name, binary=False)
|
||||
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||
# load metadata if any
|
||||
metadata = None
|
||||
if fn_cache_manager.has_file(f'{name}.json'):
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
context = _triton.ir.context()
|
||||
force_compile = False
|
||||
# ast -> triton-ir (or read from cache)
|
||||
ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
|
||||
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
|
||||
# triton-ir -> triton-gpu-ir (or read from cache)
|
||||
ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
|
||||
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
|
||||
# triton-gpu-ir -> llvm-ir (or read from cache)
|
||||
llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
|
||||
run_if_found = lambda path: Path(path).read_bytes(),
|
||||
run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
|
||||
if llvm_cached:
|
||||
shmem_size = metadata["shared"]
|
||||
else:
|
||||
shmem_size = _triton.get_shared_memory_size(ttgir)
|
||||
# llvm-ir -> ptx (or read from cache)
|
||||
ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
|
||||
run_if_found = lambda path: Path(path).read_text(),
|
||||
run_if_not_found = lambda: llir_to_ptx(llir))
|
||||
# ptx -> cubin (or read from cache)
|
||||
cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
|
||||
run_if_found = lambda path: Path(path).read_bytes(),
|
||||
run_if_not_found= lambda: ptx_to_cubin(ptx, device))
|
||||
# dump new metadata
|
||||
kernel_name = ptx_get_kernel_name(ptx)
|
||||
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
|
||||
"md5": { "cubin": cubin_md5, "ptx": ptx_md5,
|
||||
"llir": llir_md5,
|
||||
"ttir": ttir_md5, "ttgir": ttgir_md5 }}
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
|
||||
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
|
||||
asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
|
||||
return CompiledKernel(so_path, metadata, asm)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
|
||||
def __init__(self, fn_name, so_path, cache_dir):
|
||||
|
||||
def __init__(self, so_path, metadata, asm):
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||
@@ -1359,18 +1376,11 @@ class CompiledKernel:
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
# initialize asm dict
|
||||
self.asm = dict()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
|
||||
self.asm = asm
|
||||
device = torch.cuda.current_device()
|
||||
global cuda_utils
|
||||
if cuda_utils is None:
|
||||
|
Reference in New Issue
Block a user