[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)

Also deprecates a couple of tests
Author: Philippe Tillet
Committed via GitHub: 2022-11-04 10:57:29 -07:00
Parent: 4218e68d74
Commit: b6dbe959f0
5 changed files with 93 additions and 197 deletions

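The reworked cache writes every intermediate stage (.ttir, .ttgir, .llir, .ptx, .cubin) next to a <name>.json metadata file that records an md5 per stage. On the next compile() call each stage is read back from the cache, and if a file's md5 no longer matches the recorded one, every downstream stage is rebuilt from the edited artifact, which is what lets users hand-patch the IR, PTX, or cubin. A minimal sketch of that workflow, assuming the default cache location (~/.triton/cache, or $TRITON_CACHE_DIR if set) and a kernel that has already been compiled once:

import glob, os

# Paths below are illustrative; each compiled kernel lives in its own sub-directory
# named after its cache key.
cache_dir = os.path.expanduser(os.environ.get("TRITON_CACHE_DIR", "~/.triton/cache"))
ptx_files = glob.glob(os.path.join(cache_dir, "*", "*.ptx"))
print(ptx_files)

# Hand-edit one of these .ptx files (e.g. tweak a scheduling directive), then re-run
# the program: compile() notices the md5 mismatch against the "ptx" entry in the
# kernel's .json metadata and regenerates only the cubin, using the patched PTX.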

@@ -16,8 +16,9 @@ import tempfile
import warnings
from collections import namedtuple
from sysconfig import get_paths
from typing import Any, Dict, Tuple, Union
from typing import Any, Callable, Dict, Tuple, Union
from pathlib import Path
import setuptools
import torch
from filelock import FileLock
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
# ------------------------------------------------------------------------------
def make_triton_ir(fn, signature, specialization, constants):
def build_triton_ir(fn, signature, specialization, constants):
# canonicalize signature
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
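The signature canonicalization added above turns a comma-separated signature string into a dict keyed by argument position; a small illustration (the type strings are arbitrary):

signature = "*fp32, *fp32, i32"
canonical = {k: v.strip() for k, v in enumerate(signature.split(","))}
assert canonical == {0: "*fp32", 1: "*fp32", 2: "i32"}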
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
ret.context = context
return ret, generator
def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
pm.run(mod)
return mod
def ast_to_ttir(fn, signature, specialization, constants):
mod, _ = build_triton_ir(fn, signature, specialization, constants)
return optimize_triton_ir(mod)
def make_tritongpu_ir(mod, num_warps):
def ttir_to_ttgir(mod, num_warps, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
# Disabled for now: the backend errors out due to a wrong conversion when expanding async-related instructions.
# TODO[Superjomn]: re-enable when fixed.
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
def make_llvm_ir(mod):
def ttgir_to_llir(mod, extern_libs):
if extern_libs:
add_external_libs(mod, extern_libs)
return _triton.translate_triton_gpu_to_llvmir(mod)
def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
'''
Translate an LLVM IR module to PTX code.
:param mod: an LLVM IR module
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
- PTX code
- shared memory allocation size
'''
if compute_capability is None:
device = torch.cuda.current_device()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
if ptx_version is None:
_, cuda_version = path_to_ptxas()
ptx_version = ptx_get_version(cuda_version)
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
def ptx_to_cubin(ptx: str, device: int):
'''
Compile PTX code to cubin.
:param ptx: ptx code
:param device: CUDA device
:return: str
'''
ptxas, _ = path_to_ptxas()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
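Both llir_to_ptx and ptx_to_cubin now derive the compute capability from a CUDA device (the current one for llir_to_ptx, the one passed in for ptx_to_cubin) instead of requiring the caller to supply it; the (major, minor) tuple returned by torch is folded into the usual two-digit integer. A small sketch of that mapping (the device index is arbitrary):

import torch

major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 6) on an RTX 3090
compute_capability = major * 10 + minor              # -> 86, as handed to the PTX backend / ptxas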
@@ -978,46 +991,6 @@ def path_to_ptxas():
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
# triton-ir
module, _ = make_triton_ir(fn, signature, specialization, constants)
module = optimize_triton_ir(module)
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
if extern_libs:
add_external_libs(module, extern_libs)
# llvm-ir
llvm_ir = make_llvm_ir(module)
assert device >= 0, "device should be provided."
ptxas, cuda_version = path_to_ptxas()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
ptx_version = ptx_get_version(cuda_version)
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
shem_size = _triton.get_shared_memory_size(module)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx, shem_size, kernel_name
cubin = make_cubin(ptx, ptxas, compute_capability)
if output == "cubin":
return cubin, ptx, shem_size, kernel_name
assert False
# ------------------------------------------------------------------------------
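The monolithic _compile removed above is superseded by chaining the stage-named helpers introduced earlier, so each intermediate result can be cached and re-read on its own. A condensed sketch of the equivalent chain (no caching; fn, signature, specialization, constants, num_warps, num_stages, extern_libs and device mirror the parameters of compile() below):

ttir  = ast_to_ttir(fn, signature, specialization, constants)  # Python AST -> optimized Triton IR
ttgir = ttir_to_ttgir(ttir, num_warps, num_stages)             # Triton IR -> TritonGPU IR
llir  = ttgir_to_llir(ttgir, extern_libs)                      # TritonGPU IR -> LLVM IR
ptx   = llir_to_ptx(llir)                                      # LLVM IR -> PTX (capability / PTX version auto-detected)
cubin = ptx_to_cubin(ptx, device)                              # PTX -> cubin via ptxas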
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
return key
def read_or_execute(cache_manager, force_compile, file_name, metadata,
run_if_found: Callable[[str], bytes] = None,
run_if_not_found: Callable = None):
if not force_compile and cache_manager.has_file(file_name):
module = run_if_found(cache_manager._make_path(file_name))
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
md5 = hashlib.md5(data).hexdigest()
suffix = file_name.split(".")[1]
has_changed = metadata and md5 != metadata["md5"][suffix]
return module, md5, has_changed, True
module = run_if_not_found()
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
md5 = hashlib.md5(data).hexdigest()
cache_manager.put(data, file_name, True)
return module, md5, True, False
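read_or_execute is what makes the hand-patching workflow safe: when a cached file exists it is read back, hashed, and compared against the md5 recorded in the metadata JSON, and the returned has_changed flag is threaded into the following stage as force_compile, so everything downstream of an edited file is regenerated. The staleness check amounts to the following (a simplified re-statement, not new behaviour):

import hashlib
from pathlib import Path

def stage_is_stale(cache_path, metadata, suffix):
    # Hash the cached artifact the same way read_or_execute does and compare it
    # with the md5 stored when the metadata JSON was last written.
    data = Path(cache_path).read_bytes()
    return bool(metadata) and hashlib.md5(data).hexdigest() != metadata["md5"][suffix]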
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
so = _build(fn.__name__, src_path, tmpdir)
with open(so, "rb") as f:
so_cache_manager.put(f.read(), so_name, binary=True)
# retrieve cached shared object if it exists
so_path = so_cache_manager._make_path(so_name)
# create cache manager
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
fn_cache_manager = CacheManager(fn_cache_key)
ptx_name = f"{name}.ptx"
cubin_name = f"{name}.cubin"
data_name = f"{name}.json"
if not fn_cache_manager.has_file(cubin_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name):
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(cubin, cubin_name)
fn_cache_manager.put(ptx, ptx_name, binary=False)
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
# load metadata if any
metadata = None
if fn_cache_manager.has_file(f'{name}.json'):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
context = _triton.ir.context()
force_compile = False
# ast -> triton-ir (or read from cache)
ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
# triton-ir -> triton-gpu-ir (or read from cache)
ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
# triton-gpu-ir -> llvm-ir (or read from cache)
llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
if llvm_cached:
shmem_size = metadata["shared"]
else:
shmem_size = _triton.get_shared_memory_size(ttgir)
# llvm-ir -> ptx (or read from cache)
ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
run_if_found = lambda path: Path(path).read_text(),
run_if_not_found = lambda: llir_to_ptx(llir))
# ptx -> cubin (or read from cache)
cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
run_if_found = lambda path: Path(path).read_bytes(),
run_if_not_found = lambda: ptx_to_cubin(ptx, device))
# dump new metadata
kernel_name = ptx_get_kernel_name(ptx)
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
"md5": { "cubin": cubin_md5, "ptx": ptx_md5,
"llir": llir_md5,
"ttir": ttir_md5, "ttgir": ttgir_md5 }}
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
return CompiledKernel(so_path, metadata, asm)
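For reference, the metadata JSON written just above now records one md5 per cached stage; a representative (made-up) payload, shown as the Python dict that gets dumped:

metadata = {
    "name": "copy_kernel_0d1d2",   # kernel name parsed from the generated PTX (hypothetical)
    "shared": 0,                   # shared-memory size in bytes
    "num_warps": 4,
    "num_stages": 3,
    "md5": {"ttir": "...", "ttgir": "...", "llir": "...", "ptx": "...", "cubin": "..."},
}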
class CompiledKernel:
def __init__(self, fn_name, so_path, cache_dir):
def __init__(self, so_path, metadata, asm):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1359,18 +1376,11 @@ class CompiledKernel:
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
metadata = json.load(f)
self.shared = metadata["shared"]
self.num_warps = metadata["num_warps"]
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = dict()
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
self.asm["cubin"] = f.read()
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
self.asm = asm
device = torch.cuda.current_device()
global cuda_utils
if cuda_utils is None:
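Because the compiled artifacts now travel with the kernel object instead of being re-read from the cache directory, every intermediate form can be inspected straight from Python. A rough usage sketch, assuming compile() is re-exported at the package level as triton.compile (the kernel below is a throwaway example):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n):
    offs = tl.program_id(0) * 128 + tl.arange(0, 128)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

compiled = triton.compile(copy_kernel, signature="*fp32, *fp32, i32",
                          device=torch.cuda.current_device())
print(sorted(compiled.asm.keys()))   # ['cubin', 'llir', 'ptx', 'ttgir', 'ttir']
print(compiled.asm["ptx"][:200])     # same text as the cached <name>.ptx file
print(compiled.shared, compiled.num_warps, compiled.num_stages)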