[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)

Also deprecates a couple of tests
Philippe Tillet, 2022-11-04 10:57:29 -07:00 (committed by GitHub)
parent 4218e68d74, commit b6dbe959f0
5 changed files with 93 additions and 197 deletions
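The revamp splits the cache into one file per compilation stage (.ttir, .ttgir, .llir, .ptx, .cubin, plus a .json carrying per-stage md5 hashes), so an artifact edited on disk is detected and everything downstream of it is rebuilt. A minimal sketch of that workflow — the cache root, the kernel file stem "kernel", and the edit itself are illustrative assumptions, not part of this commit:

from pathlib import Path

# Assumed default cache root; the real location is governed by CacheManager.
cache_root = Path.home() / ".triton" / "cache"
ptx_file = next(cache_root.rglob("kernel.ptx"))   # one file per pipeline stage
# Hand-patch the PTX (illustrative edit; any textual change works).
ptx_file.write_text(ptx_file.read_text() + "\n// patched by hand\n")
# On the next compile() of the same kernel, the md5 recorded in kernel.json no
# longer matches the .ptx on disk, so the cubin is rebuilt from the patched PTX
# instead of being read from the cache.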


@@ -261,7 +261,7 @@ void init_triton_ir(py::module &&m) {
           },
           ret::reference)
       .def("dump", [](mlir::OpState &self) { self->dump(); })
-      .def("str",
+      .def("__str__",
            [](mlir::OpState &self) -> std::string {
              std::string str;
              llvm::raw_string_ostream os(str);
@@ -1280,8 +1280,8 @@ void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;

   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
-        .getInt();
+    auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
+    return shared.getInt();
   });
   m.def(
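A side effect of binding __str__ instead of a bespoke str method is that plain Python str() now serializes an MLIR module directly, which the cache code below relies on when hashing IR artifacts:

# module is any MLIR module handle returned by the compiler, e.g. ast_to_ttir(...)
ttir_text = str(module)   # previously spelled module.str()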


@@ -1,33 +0,0 @@
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
-    offsets = tl.arange(0, BLOCK_SIZE)
-    x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
-    x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
-    x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
-    x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
-    y1 = tl.sin(x1)
-    y2 = tl.libdevice.sin(x2)
-    y3 = tl.libdevice.div_rn(x3, x3)
-    y4 = tl.libdevice.fma_rd(x4, x4, x4)
-    tl.store(x1_ptr + offsets, y1, mask=offsets < n)
-    tl.store(x2_ptr + offsets, y2, mask=offsets < n)
-    tl.store(x3_ptr + offsets, y3, mask=offsets < n)
-    tl.store(x4_ptr + offsets, y4, mask=offsets < n)
-
-
-def test_empty_kernel_cubin_compile():
-    kernel = triton.compiler._compile(math_kernel,
-                                      "*fp32,*fp32,*fp32,*fp32,i32",
-                                      device=0,
-                                      constants={"BLOCK_SIZE": 256},
-                                      output="ttgir")  # "cubin"
-    assert kernel
-    # TODO: Check if the values are correct.
-    # TODO: Cover all the math operators


@@ -1,80 +0,0 @@
-import triton
-import triton.language as tl
-
-
-# TODO: function with no arguments don't work
-@triton.jit
-def binop_type_check(X):
-    # 0d-tensor is not allowed.
-    # zero_0d = tl.zeros([], dtype=tl.float32)
-    zero_1d = tl.zeros([2], dtype=tl.float32)
-    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
-    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
-
-    # scalar + scalar -> scalar
-    a0 = 0.0 + 0.0
-    # # scalar + 0D -> 0D
-    # a1 = 0.0 + zero_0d
-    # a2 = zero_0d + 0.0
-    # scalar + 1D -> 1D
-    a3 = 0.0 + zero_1d
-    a4 = zero_1d + 0.0
-    # scalar + 2D -> 2D
-    a5 = 0.0 + zero_2d_22
-    a6 = zero_2d_22 + 0.0
-
-    # # 0D + 0D -> 0D
-    # b1 = zero_0d + zero_0d
-    # # 0D + 1D -> 1D
-    # b2 = zero_0d + zero_1d
-    # b3 = zero_1d + zero_0d
-    # # 0D + 2D -> 2D
-    # b4 = zero_0d + zero_2d_22
-    # b5 = zero_2d_22 + zero_0d
-
-    # 1D + 1D -> 1D
-    c1 = zero_1d + zero_1d
-    # 1D + 2D -> 2D
-    c2 = zero_1d + zero_2d_21
-    c3 = zero_1d + zero_2d_22
-    c4 = zero_2d_21 + zero_1d
-    c5 = zero_2d_22 + zero_1d
-    # 2D + 2D -> 2D
-    d1 = zero_2d_21 + zero_2d_21
-    d2 = zero_2d_22 + zero_2d_22
-    d3 = zero_2d_21 + zero_2d_22
-    d4 = zero_2d_22 + zero_2d_21
-
-    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
-    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
-
-
-def test_binop_type_check():
-    kernel = triton.compiler._compile(binop_type_check,
-                                      signature="*fp32",
-                                      device=0,
-                                      output="ttir")
-    assert (kernel)
-    # TODO: Check types of the results
-
-
-@triton.jit
-def reduce_type_check(ptr):
-    v_32 = tl.load(ptr + tl.arange(0, 32))
-    v_scalar = tl.min(v_32, axis=0)
-    tl.store(ptr, v_scalar)
-    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
-    v_64 = tl.max(v_64x128, axis=1)
-    tl.store(ptr + tl.arange(0, 64), v_64)
-    v_128 = tl.max(v_64x128, axis=0)
-    tl.store(ptr + tl.arange(0, 128), v_128)
-
-
-def test_reduce_type_check():
-    kernel = triton.compiler._compile(reduce_type_check,
-                                      signature="*fp32",
-                                      device=0,
-                                      output="ttir")
-    assert (kernel)
-    # TODO: Check types of the results


@@ -16,8 +16,9 @@ import tempfile
 import warnings
 from collections import namedtuple
 from sysconfig import get_paths
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple, Union
+from pathlib import Path

 import setuptools
 import torch
 from filelock import FileLock
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
 # ------------------------------------------------------------------------------

-def make_triton_ir(fn, signature, specialization, constants):
+def build_triton_ir(fn, signature, specialization, constants):
+    # canonicalize signature
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
     context = _triton.ir.context()
     context.load_triton()
     # create kernel prototype
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
     ret.context = context
     return ret, generator


 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
     pm.run(mod)
     return mod


+def ast_to_ttir(fn, signature, specialization, constants):
+    mod, _ = build_triton_ir(fn, signature, specialization, constants)
+    return optimize_triton_ir(mod)


-def make_tritongpu_ir(mod, num_warps):
+def ttir_to_ttgir(mod, num_warps, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.run(mod)
-    return mod
-
-
-def optimize_tritongpu_ir(mod, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
     _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))


-def make_llvm_ir(mod):
+def ttgir_to_llir(mod, extern_libs):
+    if extern_libs:
+        add_external_libs(mod, extern_libs)
     return _triton.translate_triton_gpu_to_llvmir(mod)


-def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
     '''
     Translate LLVM-IR module to PTX code.
     :param mod: a TritonGPU dialect module
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
         - PTX code
         - shared memory allocation size
     '''
+    if compute_capability is None:
+        device = torch.cuda.current_device()
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = compute_capability[0] * 10 + compute_capability[1]
+    if ptx_version is None:
+        _, cuda_version = path_to_ptxas()
+        ptx_version = ptx_get_version(cuda_version)
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)


-def make_cubin(ptx: str, ptxas: str, compute_capability: int):
+def ptx_to_cubin(ptx: str, device: int):
     '''
     Compile PTX to cubin.
     :param ptx: ptx code
+    :param device: CUDA device
     :return: str
     '''
+    ptxas, _ = path_to_ptxas()
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
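The renamed helpers now read source-to-target and chain into the full lowering pipeline that the reworked compile() drives below. A rough sketch of invoking the stages by hand, under the assumption of a placeholder @triton.jit kernel my_kernel(ptr, n) and device 0:

import triton.compiler as tc

ttir = tc.ast_to_ttir(my_kernel, "*fp32,i32", tc.instance_descriptor(), constants={})
ttgir = tc.ttir_to_ttgir(ttir, num_warps=4, num_stages=3)
llir = tc.ttgir_to_llir(ttgir, extern_libs=None)
ptx = tc.llir_to_ptx(llir)               # capability / PTX version inferred from the current device
cubin = tc.ptx_to_cubin(ptx, device=0)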
@@ -978,46 +991,6 @@ def path_to_ptxas():
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])


-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-
-    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
-    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
-
-    # triton-ir
-    module, _ = make_triton_ir(fn, signature, specialization, constants)
-    module = optimize_triton_ir(module)
-    if output == "ttir":
-        return module.str()
-
-    # tritongpu-ir
-    module = make_tritongpu_ir(module, num_warps)
-    module = optimize_tritongpu_ir(module, num_stages)
-    if output == "ttgir":
-        return module.str()
-
-    if extern_libs:
-        add_external_libs(module, extern_libs)
-
-    # llvm-ir
-    llvm_ir = make_llvm_ir(module)
-
-    assert device >= 0, "device should be provided."
-    ptxas, cuda_version = path_to_ptxas()
-    compute_capability = torch.cuda.get_device_capability(device)
-    compute_capability = compute_capability[0] * 10 + compute_capability[1]
-    ptx_version = ptx_get_version(cuda_version)
-    ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
-    shem_size = _triton.get_shared_memory_size(module)
-    kernel_name = ptx_get_kernel_name(ptx)
-    if output == "ptx":
-        return ptx, shem_size, kernel_name
-
-    cubin = make_cubin(ptx, ptxas, compute_capability)
-    if output == "cubin":
-        return cubin, ptx, shem_size, kernel_name
-
-    assert False
# ------------------------------------------------------------------------------
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key


+def read_or_execute(cache_manager, force_compile, file_name, metadata,
+                    run_if_found: Callable[[str], bytes] = None,
+                    run_if_not_found: Callable = None):
+    if not force_compile and cache_manager.has_file(file_name):
+        module = run_if_found(cache_manager._make_path(file_name))
+        data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+        md5 = hashlib.md5(data).hexdigest()
+        suffix = file_name.split(".")[1]
+        has_changed = metadata and md5 != metadata["md5"][suffix]
+        return module, md5, has_changed, True
+    module = run_if_not_found()
+    data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+    md5 = hashlib.md5(data).hexdigest()
+    cache_manager.put(data, file_name, True)
+    return module, md5, True, False


 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
     if isinstance(signature, str):
         signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
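read_or_execute returns (artifact, md5, has_changed, was_cached), and compile() below feeds has_changed back in as the next stage's force_compile, so a single stale or hand-patched artifact invalidates every stage after it. A condensed sketch of the contract for one stage, with names as in the surrounding code:

ptx, ptx_md5, force_compile, _ = read_or_execute(
    fn_cache_manager, force_compile, f"{name}.ptx", metadata,
    run_if_found=lambda path: Path(path).read_text(),   # cache hit: hash what is on disk
    run_if_not_found=lambda: llir_to_ptx(llir))         # miss or stale: recompile and store
# A patched on-disk .ptx hashes differently from metadata["md5"]["ptx"], so
# force_compile flips to True and the cubin stage recompiles from it.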
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
         so = _build(fn.__name__, src_path, tmpdir)
         with open(so, "rb") as f:
             so_cache_manager.put(f.read(), so_name, binary=True)
+    # retrieve cached shared object if it exists
+    so_path = so_cache_manager._make_path(so_name)

     # create cache manager
     fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
     fn_cache_manager = CacheManager(fn_cache_key)
-    ptx_name = f"{name}.ptx"
-    cubin_name = f"{name}.cubin"
-    data_name = f"{name}.json"
-    if not fn_cache_manager.has_file(cubin_name) or \
-       not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name):
-        cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
-        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
-        fn_cache_manager.put(cubin, cubin_name)
-        fn_cache_manager.put(ptx, ptx_name, binary=False)
-        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
+    # load metadata if any
+    metadata = None
+    if fn_cache_manager.has_file(f'{name}.json'):
+        with open(fn_cache_manager._make_path(f"{name}.json")) as f:
+            metadata = json.load(f)
+    context = _triton.ir.context()
+    force_compile = False
+    # ast -> triton-ir (or read from cache)
+    ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
+                                                       run_if_found=lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                       run_if_not_found=lambda: ast_to_ttir(fn, signature, configs[0], constants))
+    # triton-ir -> triton-gpu-ir (or read from cache)
+    ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
+                                                         run_if_found=lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                         run_if_not_found=lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
+    # triton-gpu-ir -> llvm-ir (or read from cache)
+    llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
+                                                                 run_if_found=lambda path: Path(path).read_bytes(),
+                                                                 run_if_not_found=lambda: ttgir_to_llir(ttgir, extern_libs))
+    if llvm_cached:
+        shmem_size = metadata["shared"]
+    else:
+        shmem_size = _triton.get_shared_memory_size(ttgir)
+    # llvm-ir -> ptx (or read from cache)
+    ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
+                                                     run_if_found=lambda path: Path(path).read_text(),
+                                                     run_if_not_found=lambda: llir_to_ptx(llir))
+    # ptx -> cubin (or read from cache)
+    cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
+                                                         run_if_found=lambda path: Path(path).read_bytes(),
+                                                         run_if_not_found=lambda: ptx_to_cubin(ptx, device))
+    # dump new metadata
+    kernel_name = ptx_get_kernel_name(ptx)
+    metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
+                "md5": {"cubin": cubin_md5, "ptx": ptx_md5,
+                        "llir": llir_md5,
+                        "ttir": ttir_md5, "ttgir": ttgir_md5}}
+    fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
+    return CompiledKernel(so_path, metadata, asm)
 class CompiledKernel:

-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, so_path, metadata, asm):
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1359,18 +1376,11 @@ class CompiledKernel:
         spec.loader.exec_module(mod)
         self.c_wrapper = getattr(mod, "launch")
         # initialize metadata
-        with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
-            metadata = json.load(f)
         self.shared = metadata["shared"]
         self.num_warps = metadata["num_warps"]
         self.num_stages = metadata["num_stages"]
         # initialize asm dict
-        self.asm = dict()
-        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
-            self.asm["cubin"] = f.read()
-        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
-            self.asm["ptx"] = f.read()
+        self.asm = asm
         device = torch.cuda.current_device()
         global cuda_utils
         if cuda_utils is None:
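Since the handle now receives the asm dict directly, every intermediate is inspectable without touching the cache directory. A small usage sketch — the kernel and signature are placeholders, and it assumes compile() supplies a default config when configs is None:

import triton

kern = triton.compiler.compile(my_kernel, signature="*fp32,i32", device=0)
print(kern.asm["ttgir"])          # also: "ttir", "llir", "ptx", "cubin"
print(kern.shared, kern.num_warps, kern.num_stages)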


@@ -38,14 +38,13 @@ if __name__ == '__main__':
         exit(0)

     # triton-ir -> triton-gpu-ir
-    module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
-    module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
     if args.target == 'triton-gpu-ir':
         print(module.str())
         exit(0)

     # triton-gpu-ir -> llvm-ir
-    module = triton.compiler.make_llvm_ir(module)
+    module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
     if args.target == 'llvm-ir':
         print(module)
         exit(0)
@@ -56,6 +55,6 @@ if __name__ == '__main__':
         raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")

     # llvm-ir -> ptx
-    module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
+    module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
     assert args.target == 'ptx'
     print(module)