[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)
Also deprecates a couple of tests
This commit is contained in:
@@ -261,7 +261,7 @@ void init_triton_ir(py::module &&m) {
|
||||
},
|
||||
ret::reference)
|
||||
.def("dump", [](mlir::OpState &self) { self->dump(); })
|
||||
.def("str",
|
||||
.def("__str__",
|
||||
[](mlir::OpState &self) -> std::string {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
@@ -1280,8 +1280,8 @@ void init_triton_translation(py::module &m) {
|
||||
using ret = py::return_value_policy;
|
||||
|
||||
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
|
||||
return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
|
||||
.getInt();
|
||||
auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
|
||||
return shared.getInt();
|
||||
});
|
||||
|
||||
m.def(
|
||||
|
@@ -1,33 +0,0 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
|
||||
x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
|
||||
x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
|
||||
x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
|
||||
|
||||
y1 = tl.sin(x1)
|
||||
y2 = tl.libdevice.sin(x2)
|
||||
y3 = tl.libdevice.div_rn(x3, x3)
|
||||
y4 = tl.libdevice.fma_rd(x4, x4, x4)
|
||||
|
||||
tl.store(x1_ptr + offsets, y1, mask=offsets < n)
|
||||
tl.store(x2_ptr + offsets, y2, mask=offsets < n)
|
||||
tl.store(x3_ptr + offsets, y3, mask=offsets < n)
|
||||
tl.store(x4_ptr + offsets, y4, mask=offsets < n)
|
||||
|
||||
|
||||
def test_empty_kernel_cubin_compile():
|
||||
kernel = triton.compiler._compile(math_kernel,
|
||||
"*fp32,*fp32,*fp32,*fp32,i32",
|
||||
device=0,
|
||||
constants={"BLOCK_SIZE": 256},
|
||||
output="ttgir") # "cubin"
|
||||
assert kernel
|
||||
# TODO: Check if the values are correct.
|
||||
# TODO: Cover all the math operators
|
@@ -1,80 +0,0 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# TODO: function with no arguments don't work
|
||||
@triton.jit
|
||||
def binop_type_check(X):
|
||||
# 0d-tensor is not allowed.
|
||||
# zero_0d = tl.zeros([], dtype=tl.float32)
|
||||
zero_1d = tl.zeros([2], dtype=tl.float32)
|
||||
zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
|
||||
zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)
|
||||
|
||||
# scalar + scalar -> scalar
|
||||
a0 = 0.0 + 0.0
|
||||
# # scalar + 0D -> 0D
|
||||
# a1 = 0.0 + zero_0d
|
||||
# a2 = zero_0d + 0.0
|
||||
# scalar + 1D -> 1D
|
||||
a3 = 0.0 + zero_1d
|
||||
a4 = zero_1d + 0.0
|
||||
# scalar + 2D -> 2D
|
||||
a5 = 0.0 + zero_2d_22
|
||||
a6 = zero_2d_22 + 0.0
|
||||
|
||||
# # 0D + 0D -> 0D
|
||||
# b1 = zero_0d + zero_0d
|
||||
# # 0D + 1D -> 1D
|
||||
# b2 = zero_0d + zero_1d
|
||||
# b3 = zero_1d + zero_0d
|
||||
# # 0D + 2D -> 2D
|
||||
# b4 = zero_0d + zero_2d_22
|
||||
# b5 = zero_2d_22 + zero_0d
|
||||
|
||||
# 1D + 1D -> 1D
|
||||
c1 = zero_1d + zero_1d
|
||||
# 1D + 2D -> 2D
|
||||
c2 = zero_1d + zero_2d_21
|
||||
c3 = zero_1d + zero_2d_22
|
||||
c4 = zero_2d_21 + zero_1d
|
||||
c5 = zero_2d_22 + zero_1d
|
||||
|
||||
# 2D + 2D -> 2D
|
||||
d1 = zero_2d_21 + zero_2d_21
|
||||
d2 = zero_2d_22 + zero_2d_22
|
||||
d3 = zero_2d_21 + zero_2d_22
|
||||
d4 = zero_2d_22 + zero_2d_21
|
||||
|
||||
# return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
|
||||
return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
|
||||
|
||||
|
||||
def test_binop_type_check():
|
||||
kernel = triton.compiler._compile(binop_type_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reduce_type_check(ptr):
|
||||
v_32 = tl.load(ptr + tl.arange(0, 32))
|
||||
v_scalar = tl.min(v_32, axis=0)
|
||||
tl.store(ptr, v_scalar)
|
||||
v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
|
||||
v_64 = tl.max(v_64x128, axis=1)
|
||||
tl.store(ptr + tl.arange(0, 64), v_64)
|
||||
v_128 = tl.max(v_64x128, axis=0)
|
||||
tl.store(ptr + tl.arange(0, 128), v_128)
|
||||
|
||||
|
||||
def test_reduce_type_check():
|
||||
kernel = triton.compiler._compile(reduce_type_check,
|
||||
signature="*fp32",
|
||||
device=0,
|
||||
output="ttir")
|
||||
assert (kernel)
|
||||
# TODO: Check types of the results
|
@@ -16,8 +16,9 @@ import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from sysconfig import get_paths
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from pathlib import Path
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, specialization, constants):
|
||||
def build_triton_ir(fn, signature, specialization, constants):
|
||||
# canonicalize signature
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
|
||||
ret.context = context
|
||||
return ret, generator
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
def ast_to_ttir(fn, signature, specialization, constants):
|
||||
mod, _ = build_triton_ir(fn, signature, specialization, constants)
|
||||
return optimize_triton_ir(mod)
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
def ttir_to_ttgir(mod, num_warps, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
# Get error in backend due to wrong conversion in expanding async-related instruction.
|
||||
# TODO[Superjomn]: Open it when fixed.
|
||||
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def make_llvm_ir(mod):
|
||||
def ttgir_to_llir(mod, extern_libs):
|
||||
if extern_libs:
|
||||
add_external_libs(mod, extern_libs)
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod)
|
||||
|
||||
|
||||
def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
|
||||
def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
|
||||
- PTX code
|
||||
- shared memory alloaction size
|
||||
'''
|
||||
if compute_capability is None:
|
||||
device = torch.cuda.current_device()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
|
||||
|
||||
|
||||
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
|
||||
|
||||
def ptx_to_cubin(ptx: str, device: int):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
:param device: CUDA device
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
|
||||
|
||||
|
||||
@@ -978,46 +991,6 @@ def path_to_ptxas():
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
|
||||
|
||||
|
||||
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
|
||||
# triton-ir
|
||||
module, _ = make_triton_ir(fn, signature, specialization, constants)
|
||||
module = optimize_triton_ir(module)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
|
||||
if extern_libs:
|
||||
add_external_libs(module, extern_libs)
|
||||
|
||||
# llvm-ir
|
||||
llvm_ir = make_llvm_ir(module)
|
||||
|
||||
assert device >= 0, "device should be provided."
|
||||
ptxas, cuda_version = path_to_ptxas()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
|
||||
shem_size = _triton.get_shared_memory_size(module)
|
||||
kernel_name = ptx_get_kernel_name(ptx)
|
||||
if output == "ptx":
|
||||
return ptx, shem_size, kernel_name
|
||||
|
||||
cubin = make_cubin(ptx, ptxas, compute_capability)
|
||||
if output == "cubin":
|
||||
return cubin, ptx, shem_size, kernel_name
|
||||
|
||||
assert False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
|
||||
return key
|
||||
|
||||
|
||||
def read_or_execute(cache_manager, force_compile, file_name, metadata,
|
||||
run_if_found: Callable[[str], bytes] = None,
|
||||
run_if_not_found: Callable = None):
|
||||
if not force_compile and cache_manager.has_file(file_name):
|
||||
module = run_if_found(cache_manager._make_path(file_name))
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
suffix = file_name.split(".")[1]
|
||||
has_changed = metadata and md5 != metadata["md5"][suffix]
|
||||
return module, md5, has_changed, True
|
||||
module = run_if_not_found()
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
cache_manager.put(data, file_name, True)
|
||||
return module, md5, True, False
|
||||
|
||||
|
||||
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
||||
so = _build(fn.__name__, src_path, tmpdir)
|
||||
with open(so, "rb") as f:
|
||||
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||
|
||||
# retrieve cached shared object if it exists
|
||||
so_path = so_cache_manager._make_path(so_name)
|
||||
# create cache manager
|
||||
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||
fn_cache_manager = CacheManager(fn_cache_key)
|
||||
ptx_name = f"{name}.ptx"
|
||||
cubin_name = f"{name}.cubin"
|
||||
data_name = f"{name}.json"
|
||||
if not fn_cache_manager.has_file(cubin_name) or \
|
||||
not fn_cache_manager.has_file(data_name) or \
|
||||
not fn_cache_manager.has_file(ptx_name):
|
||||
cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
|
||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||
fn_cache_manager.put(cubin, cubin_name)
|
||||
fn_cache_manager.put(ptx, ptx_name, binary=False)
|
||||
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||
# load metadata if any
|
||||
metadata = None
|
||||
if fn_cache_manager.has_file(f'{name}.json'):
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
context = _triton.ir.context()
|
||||
force_compile = False
|
||||
# ast -> triton-ir (or read from cache)
|
||||
ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
|
||||
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants))
|
||||
# triton-ir -> triton-gpu-ir (or read from cache)
|
||||
ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
|
||||
run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
|
||||
# triton-gpu-ir -> llvm-ir (or read from cache)
|
||||
llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
|
||||
run_if_found = lambda path: Path(path).read_bytes(),
|
||||
run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs))
|
||||
if llvm_cached:
|
||||
shmem_size = metadata["shared"]
|
||||
else:
|
||||
shmem_size = _triton.get_shared_memory_size(ttgir)
|
||||
# llvm-ir -> ptx (or read from cache)
|
||||
ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
|
||||
run_if_found = lambda path: Path(path).read_text(),
|
||||
run_if_not_found = lambda: llir_to_ptx(llir))
|
||||
# ptx -> cubin (or read from cache)
|
||||
cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
|
||||
run_if_found = lambda path: Path(path).read_bytes(),
|
||||
run_if_not_found= lambda: ptx_to_cubin(ptx, device))
|
||||
# dump new metadata
|
||||
kernel_name = ptx_get_kernel_name(ptx)
|
||||
metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
|
||||
"md5": { "cubin": cubin_md5, "ptx": ptx_md5,
|
||||
"llir": llir_md5,
|
||||
"ttir": ttir_md5, "ttgir": ttgir_md5 }}
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
|
||||
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
|
||||
asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
|
||||
return CompiledKernel(so_path, metadata, asm)
|
||||
|
||||
|
||||
class CompiledKernel:
|
||||
|
||||
def __init__(self, fn_name, so_path, cache_dir):
|
||||
|
||||
def __init__(self, so_path, metadata, asm):
|
||||
# initialize launcher
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||
@@ -1359,18 +1376,11 @@ class CompiledKernel:
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
self.shared = metadata["shared"]
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
# initialize asm dict
|
||||
self.asm = dict()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||
self.asm["cubin"] = f.read()
|
||||
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||
self.asm["ptx"] = f.read()
|
||||
|
||||
self.asm = asm
|
||||
device = torch.cuda.current_device()
|
||||
global cuda_utils
|
||||
if cuda_utils is None:
|
||||
|
@@ -38,14 +38,13 @@ if __name__ == '__main__':
|
||||
exit(0)
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
|
||||
module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
module = triton.compiler.make_llvm_ir(module)
|
||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
exit(0)
|
||||
@@ -56,6 +55,6 @@ if __name__ == '__main__':
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
|
||||
# llvm-ir -> ptx
|
||||
module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
|
||||
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
|
||||
assert args.target == 'ptx'
|
||||
print(module)
|
||||
|
Reference in New Issue
Block a user