[RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)

Also deprecates a couple of tests
Philippe Tillet
2022-11-04 10:57:29 -07:00
committed by GitHub
parent 4218e68d74
commit b6dbe959f0
5 changed files with 93 additions and 197 deletions
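The point of the revamp is that every compilation stage now leaves an editable artifact behind: the per-kernel cache directory holds {name}.ttir, {name}.ttgir, {name}.llir, {name}.ptx and {name}.cubin next to a {name}.json file with per-stage MD5 digests, and compile() reuses whatever it finds there. A rough sketch of the intended workflow, assuming Triton's default user cache location (the path and the no-op edit below are illustrative, not part of this diff):

# Sketch only: hand-patch a cached PTX file; the next compile() call for the same
# kernel rebuilds just the cubin from it. The cache root is an assumed default.
from pathlib import Path

cache_root = Path.home() / ".triton" / "cache"
ptx_path = next(cache_root.rglob("*.ptx"))             # a previously compiled kernel
ptx_path.write_text(ptx_path.read_text() + "\n// hand-patched\n")
# read_or_execute() re-hashes the edited .ptx, sees the MD5 no longer matches the
# stored metadata, and recompiles the downstream cubin from the patched text.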

View File

@@ -261,7 +261,7 @@ void init_triton_ir(py::module &&m) {
       },
       ret::reference)
       .def("dump", [](mlir::OpState &self) { self->dump(); })
-      .def("str",
+      .def("__str__",
            [](mlir::OpState &self) -> std::string {
              std::string str;
              llvm::raw_string_ostream os(str);
@@ -1280,8 +1280,8 @@ void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;
   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
-        .getInt();
+    auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
+    return shared.getInt();
   });
   m.def(
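Renaming the binding from str to __str__ makes an MLIR op/module handle stringifiable through Python's built-in str(), which is what the reworked cache relies on when it serializes and hashes IR (str(module).encode("utf-8") in read_or_execute below). A minimal sketch, assuming mod is a module handle obtained from these bindings:

import hashlib

text = str(mod)                                          # goes through the __str__ binding added above
digest = hashlib.md5(text.encode("utf-8")).hexdigest()   # same hashing scheme the new cache uses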

View File

@@ -1,33 +0,0 @@
import triton
import triton.language as tl


@triton.jit
def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
    x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
    x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
    x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
    y1 = tl.sin(x1)
    y2 = tl.libdevice.sin(x2)
    y3 = tl.libdevice.div_rn(x3, x3)
    y4 = tl.libdevice.fma_rd(x4, x4, x4)
    tl.store(x1_ptr + offsets, y1, mask=offsets < n)
    tl.store(x2_ptr + offsets, y2, mask=offsets < n)
    tl.store(x3_ptr + offsets, y3, mask=offsets < n)
    tl.store(x4_ptr + offsets, y4, mask=offsets < n)


def test_empty_kernel_cubin_compile():
    kernel = triton.compiler._compile(math_kernel,
                                      "*fp32,*fp32,*fp32,*fp32,i32",
                                      device=0,
                                      constants={"BLOCK_SIZE": 256},
                                      output="ttgir")  # "cubin"
    assert kernel
    # TODO: Check if the values are correct.
    # TODO: Cover all the math operators
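With _compile() and its output= switch removed by this commit, a check along these lines would go through the new compile() entry point and inspect the asm dict on the returned CompiledKernel. A hedged sketch, not part of the diff, assuming compile() falls back to a default specialization when configs is omitted and that a CUDA device is available (it now builds all the way to cubin):

import triton

kernel = triton.compiler.compile(math_kernel,
                                 signature="*fp32,*fp32,*fp32,*fp32,i32",
                                 device=0,
                                 constants={"BLOCK_SIZE": 256})
assert kernel.asm["ttgir"]    # every intermediate stage is kept on the CompiledKernel
assert kernel.asm["cubin"]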

View File

@@ -1,80 +0,0 @@
import triton
import triton.language as tl


# TODO: function with no arguments don't work
@triton.jit
def binop_type_check(X):
    # 0d-tensor is not allowed.
    # zero_0d = tl.zeros([], dtype=tl.float32)
    zero_1d = tl.zeros([2], dtype=tl.float32)
    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)

    # scalar + scalar -> scalar
    a0 = 0.0 + 0.0
    # # scalar + 0D -> 0D
    # a1 = 0.0 + zero_0d
    # a2 = zero_0d + 0.0
    # scalar + 1D -> 1D
    a3 = 0.0 + zero_1d
    a4 = zero_1d + 0.0
    # scalar + 2D -> 2D
    a5 = 0.0 + zero_2d_22
    a6 = zero_2d_22 + 0.0
    # # 0D + 0D -> 0D
    # b1 = zero_0d + zero_0d
    # # 0D + 1D -> 1D
    # b2 = zero_0d + zero_1d
    # b3 = zero_1d + zero_0d
    # # 0D + 2D -> 2D
    # b4 = zero_0d + zero_2d_22
    # b5 = zero_2d_22 + zero_0d
    # 1D + 1D -> 1D
    c1 = zero_1d + zero_1d
    # 1D + 2D -> 2D
    c2 = zero_1d + zero_2d_21
    c3 = zero_1d + zero_2d_22
    c4 = zero_2d_21 + zero_1d
    c5 = zero_2d_22 + zero_1d
    # 2D + 2D -> 2D
    d1 = zero_2d_21 + zero_2d_21
    d2 = zero_2d_22 + zero_2d_22
    d3 = zero_2d_21 + zero_2d_22
    d4 = zero_2d_22 + zero_2d_21

    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4


def test_binop_type_check():
    kernel = triton.compiler._compile(binop_type_check,
                                      signature="*fp32",
                                      device=0,
                                      output="ttir")
    assert (kernel)
    # TODO: Check types of the results


@triton.jit
def reduce_type_check(ptr):
    v_32 = tl.load(ptr + tl.arange(0, 32))
    v_scalar = tl.min(v_32, axis=0)
    tl.store(ptr, v_scalar)
    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
    v_64 = tl.max(v_64x128, axis=1)
    tl.store(ptr + tl.arange(0, 64), v_64)
    v_128 = tl.max(v_64x128, axis=0)
    tl.store(ptr + tl.arange(0, 128), v_128)


def test_reduce_type_check():
    kernel = triton.compiler._compile(reduce_type_check,
                                      signature="*fp32",
                                      device=0,
                                      output="ttir")
    assert (kernel)
    # TODO: Check types of the results

View File

@@ -16,8 +16,9 @@ import tempfile
 import warnings
 from collections import namedtuple
 from sysconfig import get_paths
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Tuple, Union
+from pathlib import Path
 
 import setuptools
 import torch
 from filelock import FileLock
@@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization):
 # ------------------------------------------------------------------------------
 
 
-def make_triton_ir(fn, signature, specialization, constants):
+def build_triton_ir(fn, signature, specialization, constants):
+    # canonicalize signature
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
     context = _triton.ir.context()
     context.load_triton()
     # create kernel prototype
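The canonicalization added here turns a comma-separated signature string into the index-keyed dict the rest of the pipeline expects, for example:

signature = "*fp32, *fp32, i32"
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
# -> {0: '*fp32', 1: '*fp32', 2: 'i32'}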
@@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants):
     ret.context = context
     return ret, generator
 
-
 def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
@@ -864,16 +867,13 @@ def optimize_triton_ir(mod):
     pm.run(mod)
     return mod
 
 
+def ast_to_ttir(fn, signature, specialization, constants):
+    mod, _ = build_triton_ir(fn, signature, specialization, constants)
+    return optimize_triton_ir(mod)
+
+
-def make_tritongpu_ir(mod, num_warps):
+def ttir_to_ttgir(mod, num_warps, num_stages):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
-    pm.run(mod)
-    return mod
-
-
-def optimize_tritongpu_ir(mod, num_stages):
-    pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
     _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
 
 
-def make_llvm_ir(mod):
+def ttgir_to_llir(mod, extern_libs):
+    if extern_libs:
+        add_external_libs(mod, extern_libs)
     return _triton.translate_triton_gpu_to_llvmir(mod)
 
 
-def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
     - PTX code
     - shared memory alloaction size
     '''
+    if compute_capability is None:
+        device = torch.cuda.current_device()
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = compute_capability[0] * 10 + compute_capability[1]
+    if ptx_version is None:
+        _, cuda_version = path_to_ptxas()
+        ptx_version = ptx_get_version(cuda_version)
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
 
 
-def make_cubin(ptx: str, ptxas: str, compute_capability: int):
+def ptx_to_cubin(ptx: str, device: int):
     '''
     Compile TritonGPU module to cubin.
     :param ptx: ptx code
     :param device: CUDA device
     :return: str
     '''
+    ptxas, _ = path_to_ptxas()
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
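Both new default paths derive the target from the active GPU: torch.cuda.get_device_capability returns a (major, minor) pair that is folded into the single integer ptxas and the LLVM backend expect. Worked out for one device as an example:

import torch

major, minor = torch.cuda.get_device_capability(0)    # e.g. (8, 6) on an Ampere part
compute_capability = major * 10 + minor                # -> 86, the value passed to the backends above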
@@ -978,46 +991,6 @@ def path_to_ptxas():
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
 
 
-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
-    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
-    # triton-ir
-    module, _ = make_triton_ir(fn, signature, specialization, constants)
-    module = optimize_triton_ir(module)
-    if output == "ttir":
-        return module.str()
-    # tritongpu-ir
-    module = make_tritongpu_ir(module, num_warps)
-    module = optimize_tritongpu_ir(module, num_stages)
-    if output == "ttgir":
-        return module.str()
-    if extern_libs:
-        add_external_libs(module, extern_libs)
-    # llvm-ir
-    llvm_ir = make_llvm_ir(module)
-    assert device >= 0, "device should be provided."
-    ptxas, cuda_version = path_to_ptxas()
-    compute_capability = torch.cuda.get_device_capability(device)
-    compute_capability = compute_capability[0] * 10 + compute_capability[1]
-    ptx_version = ptx_get_version(cuda_version)
-    ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
-    shem_size = _triton.get_shared_memory_size(module)
-    kernel_name = ptx_get_kernel_name(ptx)
-    if output == "ptx":
-        return ptx, shem_size, kernel_name
-    cubin = make_cubin(ptx, ptxas, compute_capability)
-    if output == "cubin":
-        return cubin, ptx, shem_size, kernel_name
-    assert False
 
 # ------------------------------------------------------------------------------
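With _compile gone, its single output= switch is replaced by the explicitly named stage helpers defined earlier in this file. A hedged sketch of how the old monolithic call maps onto the new ones (fn, signature, specialization and constants are placeholders, not values from the diff):

import triton

ttir = triton.compiler.ast_to_ttir(fn, signature, specialization, constants)   # ast -> triton-ir
ttgir = triton.compiler.ttir_to_ttgir(ttir, num_warps=4, num_stages=3)         # triton-ir -> triton-gpu-ir
llir = triton.compiler.ttgir_to_llir(ttgir, extern_libs=None)                  # triton-gpu-ir -> llvm-ir
ptx = triton.compiler.llir_to_ptx(llir)       # defaults pick the current device's capability / PTX version
cubin = triton.compiler.ptx_to_cubin(ptx, device=0)                            # PTX -> cubin via ptxas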
@@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key
 
 
+def read_or_execute(cache_manager, force_compile, file_name, metadata,
+                    run_if_found: Callable[[str], bytes] = None,
+                    run_if_not_found: Callable = None):
+    if not force_compile and cache_manager.has_file(file_name):
+        module = run_if_found(cache_manager._make_path(file_name))
+        data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+        md5 = hashlib.md5(data).hexdigest()
+        suffix = file_name.split(".")[1]
+        has_changed = metadata and md5 != metadata["md5"][suffix]
+        return module, md5, has_changed, True
+    module = run_if_not_found()
+    data = module if isinstance(module, bytes) else str(module).encode("utf-8")
+    md5 = hashlib.md5(data).hexdigest()
+    cache_manager.put(data, file_name, True)
+    return module, md5, True, False
+
+
 def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
     if isinstance(signature, str):
         signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
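The four-tuple returned by read_or_execute is what lets hand-edited artifacts propagate: its third element feeds force_compile in compile() below, so any stage that is missing, freshly built, or whose digest no longer matches the stored metadata forces every later stage to rebuild from it instead of reading its own cache. Annotated against the .ttgir stage as used below (a sketch, not additional code from the diff):

ttgir, ttgir_md5, force_compile, found = read_or_execute(
    fn_cache_manager,      # per-kernel cache directory
    force_compile,         # already True if an earlier stage was (re)built or hand-edited
    f"{name}.ttgir",       # artifact to look up / produce
    metadata,              # previously stored per-stage MD5 digests, or None
    run_if_found=lambda path: _triton.ir.parse_mlir_module(path, context),
    run_if_not_found=lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
# found         -> the cached .ttgir file existed
# ttgir_md5     -> digest of what was read or produced, written back into {name}.json
# force_compile -> True if the file was missing or its digest differs from
#                  metadata["md5"]["ttgir"], making every later stage recompile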
@@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
             so = _build(fn.__name__, src_path, tmpdir)
             with open(so, "rb") as f:
                 so_cache_manager.put(f.read(), so_name, binary=True)
+    so_path = so_cache_manager._make_path(so_name)
 
-    # retrieve cached shared object if it exists
+    # create cache manager
     fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
     fn_cache_manager = CacheManager(fn_cache_key)
-    ptx_name = f"{name}.ptx"
-    cubin_name = f"{name}.cubin"
-    data_name = f"{name}.json"
-    if not fn_cache_manager.has_file(cubin_name) or \
-       not fn_cache_manager.has_file(data_name) or \
-       not fn_cache_manager.has_file(ptx_name):
-        cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
-        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
-        fn_cache_manager.put(cubin, cubin_name)
-        fn_cache_manager.put(ptx, ptx_name, binary=False)
-        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    # load metadata if any
+    metadata = None
+    if fn_cache_manager.has_file(f'{name}.json'):
+        with open(fn_cache_manager._make_path(f"{name}.json")) as f:
+            metadata = json.load(f)
+    context = _triton.ir.context()
+    force_compile = False
+    # ast -> triton-ir (or read from cache)
+    ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata,
+                                                       run_if_found=lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                       run_if_not_found=lambda: ast_to_ttir(fn, signature, configs[0], constants))
+    # triton-ir -> triton-gpu-ir (or read from cache)
+    ttgir, ttgir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata,
+                                                         run_if_found=lambda path: _triton.ir.parse_mlir_module(path, context),
+                                                         run_if_not_found=lambda: ttir_to_ttgir(ttir, num_warps, num_stages))
+    # triton-gpu-ir -> llvm-ir (or read from cache)
+    llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata,
+                                                                 run_if_found=lambda path: Path(path).read_bytes(),
+                                                                 run_if_not_found=lambda: ttgir_to_llir(ttgir, extern_libs))
+    if llvm_cached:
+        shmem_size = metadata["shared"]
+    else:
+        shmem_size = _triton.get_shared_memory_size(ttgir)
+    # llvm-ir -> ptx (or read from cache)
+    ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata,
+                                                     run_if_found=lambda path: Path(path).read_text(),
+                                                     run_if_not_found=lambda: llir_to_ptx(llir))
+    # ptx -> cubin (or read from cache)
+    cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata,
+                                                         run_if_found=lambda path: Path(path).read_bytes(),
+                                                         run_if_not_found=lambda: ptx_to_cubin(ptx, device))
+    # dump new metadata
+    kernel_name = ptx_get_kernel_name(ptx)
+    metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages,
+                "md5": {"cubin": cubin_md5, "ptx": ptx_md5,
+                        "llir": llir_md5,
+                        "ttir": ttir_md5, "ttgir": ttgir_md5}}
+    fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
+    asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin}
+    return CompiledKernel(so_path, metadata, asm)
 
 
 class CompiledKernel:
 
-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, so_path, metadata, asm):
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1359,18 +1376,11 @@ class CompiledKernel:
         spec.loader.exec_module(mod)
         self.c_wrapper = getattr(mod, "launch")
         # initialize metadata
-        with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
-            metadata = json.load(f)
         self.shared = metadata["shared"]
         self.num_warps = metadata["num_warps"]
         self.num_stages = metadata["num_stages"]
         # initialize asm dict
-        self.asm = dict()
-        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
-            self.asm["cubin"] = f.read()
-        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
-            self.asm["ptx"] = f.read()
+        self.asm = asm
         device = torch.cuda.current_device()
         global cuda_utils
         if cuda_utils is None:
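CompiledKernel no longer re-reads its artifacts from the cache directory; compile() hands it the metadata dict and the in-memory asm dict directly. A short usage sketch (add_kernel is a placeholder @triton.jit function, not part of the diff):

import triton

kernel = triton.compiler.compile(add_kernel, signature="*fp32,*fp32,i32", device=0)
print(kernel.shared, kernel.num_warps, kernel.num_stages)   # copied from the metadata dict
ptx_text = kernel.asm["ptx"]       # every stage is now available in memory:
binary = kernel.asm["cubin"]       # "ttir", "ttgir", "llir", "ptx", "cubin"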

View File

@@ -38,14 +38,13 @@ if __name__ == '__main__':
         exit(0)
 
     # triton-ir -> triton-gpu-ir
-    module = triton.compiler.make_tritongpu_ir(module, num_warps=4)
-    module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
     if args.target == 'triton-gpu-ir':
         print(module.str())
         exit(0)
 
     # triton-gpu-ir -> llvm-ir
-    module = triton.compiler.make_llvm_ir(module)
+    module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
     if args.target == 'llvm-ir':
         print(module)
         exit(0)
@@ -56,6 +55,6 @@ if __name__ == '__main__':
         raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
 
     # llvm-ir -> ptx
-    module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
+    module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
     assert args.target == 'ptx'
     print(module)