From b6dbe959f053929e275ba9b1cd5b904dd10b6430 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 4 Nov 2022 10:57:29 -0700 Subject: [PATCH] [RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845) Also deprecates a couple of tests --- python/src/triton.cc | 6 +- python/tests/test_math_ops.py | 33 ------- python/tests/test_type.py | 80 ----------------- python/triton/compiler.py | 164 ++++++++++++++++++---------------- python/triton/tools/aot.py | 7 +- 5 files changed, 93 insertions(+), 197 deletions(-) delete mode 100644 python/tests/test_math_ops.py delete mode 100644 python/tests/test_type.py diff --git a/python/src/triton.cc b/python/src/triton.cc index 4cbc3f1e7..b3b745c43 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -261,7 +261,7 @@ void init_triton_ir(py::module &&m) { }, ret::reference) .def("dump", [](mlir::OpState &self) { self->dump(); }) - .def("str", + .def("__str__", [](mlir::OpState &self) -> std::string { std::string str; llvm::raw_string_ostream os(str); @@ -1280,8 +1280,8 @@ void init_triton_translation(py::module &m) { using ret = py::return_value_policy; m.def("get_shared_memory_size", [](mlir::ModuleOp module) { - return module->getAttrOfType("triton_gpu.shared") - .getInt(); + auto shared = module->getAttrOfType("triton_gpu.shared"); + return shared.getInt(); }); m.def( diff --git a/python/tests/test_math_ops.py b/python/tests/test_math_ops.py deleted file mode 100644 index 6b3463490..000000000 --- a/python/tests/test_math_ops.py +++ /dev/null @@ -1,33 +0,0 @@ - -import triton -import triton.language as tl - - -@triton.jit -def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr): - offsets = tl.arange(0, BLOCK_SIZE) - x1 = tl.load(x1_ptr + offsets, mask=offsets < n) - x2 = tl.load(x2_ptr + offsets, mask=offsets < n) - x3 = tl.load(x3_ptr + offsets, mask=offsets < n) - x4 = tl.load(x4_ptr + offsets, mask=offsets < n) - - y1 = tl.sin(x1) - y2 = tl.libdevice.sin(x2) - y3 = tl.libdevice.div_rn(x3, x3) - y4 = tl.libdevice.fma_rd(x4, x4, x4) - - tl.store(x1_ptr + offsets, y1, mask=offsets < n) - tl.store(x2_ptr + offsets, y2, mask=offsets < n) - tl.store(x3_ptr + offsets, y3, mask=offsets < n) - tl.store(x4_ptr + offsets, y4, mask=offsets < n) - - -def test_empty_kernel_cubin_compile(): - kernel = triton.compiler._compile(math_kernel, - "*fp32,*fp32,*fp32,*fp32,i32", - device=0, - constants={"BLOCK_SIZE": 256}, - output="ttgir") # "cubin" - assert kernel - # TODO: Check if the values are correct. - # TODO: Cover all the math operators diff --git a/python/tests/test_type.py b/python/tests/test_type.py deleted file mode 100644 index 8580b967a..000000000 --- a/python/tests/test_type.py +++ /dev/null @@ -1,80 +0,0 @@ -import triton -import triton.language as tl - - -# TODO: function with no arguments don't work -@triton.jit -def binop_type_check(X): - # 0d-tensor is not allowed. 
- # zero_0d = tl.zeros([], dtype=tl.float32) - zero_1d = tl.zeros([2], dtype=tl.float32) - zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32) - zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32) - - # scalar + scalar -> scalar - a0 = 0.0 + 0.0 - # # scalar + 0D -> 0D - # a1 = 0.0 + zero_0d - # a2 = zero_0d + 0.0 - # scalar + 1D -> 1D - a3 = 0.0 + zero_1d - a4 = zero_1d + 0.0 - # scalar + 2D -> 2D - a5 = 0.0 + zero_2d_22 - a6 = zero_2d_22 + 0.0 - - # # 0D + 0D -> 0D - # b1 = zero_0d + zero_0d - # # 0D + 1D -> 1D - # b2 = zero_0d + zero_1d - # b3 = zero_1d + zero_0d - # # 0D + 2D -> 2D - # b4 = zero_0d + zero_2d_22 - # b5 = zero_2d_22 + zero_0d - - # 1D + 1D -> 1D - c1 = zero_1d + zero_1d - # 1D + 2D -> 2D - c2 = zero_1d + zero_2d_21 - c3 = zero_1d + zero_2d_22 - c4 = zero_2d_21 + zero_1d - c5 = zero_2d_22 + zero_1d - - # 2D + 2D -> 2D - d1 = zero_2d_21 + zero_2d_21 - d2 = zero_2d_22 + zero_2d_22 - d3 = zero_2d_21 + zero_2d_22 - d4 = zero_2d_22 + zero_2d_21 - - # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4 - return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4 - - -def test_binop_type_check(): - kernel = triton.compiler._compile(binop_type_check, - signature="*fp32", - device=0, - output="ttir") - assert (kernel) - # TODO: Check types of the results - - -@triton.jit -def reduce_type_check(ptr): - v_32 = tl.load(ptr + tl.arange(0, 32)) - v_scalar = tl.min(v_32, axis=0) - tl.store(ptr, v_scalar) - v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :]) - v_64 = tl.max(v_64x128, axis=1) - tl.store(ptr + tl.arange(0, 64), v_64) - v_128 = tl.max(v_64x128, axis=0) - tl.store(ptr + tl.arange(0, 128), v_128) - - -def test_reduce_type_check(): - kernel = triton.compiler._compile(reduce_type_check, - signature="*fp32", - device=0, - output="ttir") - assert (kernel) - # TODO: Check types of the results diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 2558cd322..38aeb70a4 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -16,8 +16,9 @@ import tempfile import warnings from collections import namedtuple from sysconfig import get_paths -from typing import Any, Dict, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union +from pathlib import Path import setuptools import torch from filelock import FileLock @@ -822,7 +823,10 @@ def kernel_suffix(signature, specialization): # ------------------------------------------------------------------------------ -def make_triton_ir(fn, signature, specialization, constants): +def build_triton_ir(fn, signature, specialization, constants): + # canonicalize signature + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} context = _triton.ir.context() context.load_triton() # create kernel prototype @@ -852,7 +856,6 @@ def make_triton_ir(fn, signature, specialization, constants): ret.context = context return ret, generator - def optimize_triton_ir(mod): pm = _triton.ir.pass_manager(mod.context) pm.enable_debug() @@ -864,16 +867,13 @@ def optimize_triton_ir(mod): pm.run(mod) return mod +def ast_to_ttir(fn, signature, specialization, constants): + mod, _ = build_triton_ir(fn, signature, specialization, constants) + return optimize_triton_ir(mod) -def make_tritongpu_ir(mod, num_warps): +def ttir_to_ttgir(mod, num_warps, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) - pm.run(mod) - return mod - - -def optimize_tritongpu_ir(mod, 
num_stages):
-    pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
@@ -897,11 +897,13 @@ def add_external_libs(mod, libs):
     _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
 
 
-def make_llvm_ir(mod):
+def ttgir_to_llir(mod, extern_libs):
+    if extern_libs:
+        add_external_libs(mod, extern_libs)
     return _triton.translate_triton_gpu_to_llvmir(mod)
 
 
-def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
@@ -909,16 +911,27 @@ def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str,
     - PTX code
     - shared memory allocation size
     '''
+    if compute_capability is None:
+        device = torch.cuda.current_device()
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = compute_capability[0] * 10 + compute_capability[1]
+    if ptx_version is None:
+        _, cuda_version = path_to_ptxas()
+        ptx_version = ptx_get_version(cuda_version)
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
 
 
-def make_cubin(ptx: str, ptxas: str, compute_capability: int):
+
+def ptx_to_cubin(ptx: str, device: int):
     '''
     Compile PTX to cubin.
     :param ptx: ptx code
     :param device: CUDA device
     :return: str
     '''
+    ptxas, _ = path_to_ptxas()
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
 
 
@@ -978,46 +991,6 @@ def path_to_ptxas():
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
 
 
-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
-    if isinstance(signature, str):
-        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
-    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
-    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
-
-    # triton-ir
-    module, _ = make_triton_ir(fn, signature, specialization, constants)
-    module = optimize_triton_ir(module)
-    if output == "ttir":
-        return module.str()
-
-    # tritongpu-ir
-    module = make_tritongpu_ir(module, num_warps)
-    module = optimize_tritongpu_ir(module, num_stages)
-    if output == "ttgir":
-        return module.str()
-
-    if extern_libs:
-        add_external_libs(module, extern_libs)
-
-    # llvm-ir
-    llvm_ir = make_llvm_ir(module)
-
-    assert device >= 0, "device should be provided."
- ptxas, cuda_version = path_to_ptxas() - compute_capability = torch.cuda.get_device_capability(device) - compute_capability = compute_capability[0] * 10 + compute_capability[1] - ptx_version = ptx_get_version(cuda_version) - ptx = make_ptx(llvm_ir, compute_capability, ptx_version) - shem_size = _triton.get_shared_memory_size(module) - kernel_name = ptx_get_kernel_name(ptx) - if output == "ptx": - return ptx, shem_size, kernel_name - - cubin = make_cubin(ptx, ptxas, compute_capability) - if output == "cubin": - return cubin, ptx, shem_size, kernel_name - - assert False # ------------------------------------------------------------------------------ @@ -1306,6 +1279,23 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta return key +def read_or_execute(cache_manager, force_compile, file_name, metadata, + run_if_found: Callable[[str], bytes] = None, + run_if_not_found: Callable = None): + if not force_compile and cache_manager.has_file(file_name): + module = run_if_found(cache_manager._make_path(file_name)) + data = module if isinstance(module, bytes) else str(module).encode("utf-8") + md5 = hashlib.md5(data).hexdigest() + suffix = file_name.split(".")[1] + has_changed = metadata and md5 != metadata["md5"][suffix] + return module, md5, has_changed, True + module = run_if_not_found() + data = module if isinstance(module, bytes) else str(module).encode("utf-8") + md5 = hashlib.md5(data).hexdigest() + cache_manager.put(data, file_name, True) + return module, md5, True, False + + def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} @@ -1329,29 +1319,56 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i so = _build(fn.__name__, src_path, tmpdir) with open(so, "rb") as f: so_cache_manager.put(f.read(), so_name, binary=True) - - # retrieve cached shared object if it exists + so_path = so_cache_manager._make_path(so_name) + # create cache manager fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages) fn_cache_manager = CacheManager(fn_cache_key) - ptx_name = f"{name}.ptx" - cubin_name = f"{name}.cubin" - data_name = f"{name}.json" - if not fn_cache_manager.has_file(cubin_name) or \ - not fn_cache_manager.has_file(data_name) or \ - not fn_cache_manager.has_file(ptx_name): - cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin") - metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages} - fn_cache_manager.put(cubin, cubin_name) - fn_cache_manager.put(ptx, ptx_name, binary=False) - fn_cache_manager.put(json.dumps(metadata), data_name, binary=False) + # load metadata if any + metadata = None + if fn_cache_manager.has_file(f'{name}.json'): + with open(fn_cache_manager._make_path(f"{name}.json")) as f: + metadata = json.load(f) + context = _triton.ir.context() + force_compile = False + # ast -> triton-ir (or read from cache) + ttir, ttir_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ttir", metadata, + run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context), + run_if_not_found = lambda: ast_to_ttir(fn, signature, configs[0], constants)) + # triton-ir -> triton-gpu-ir (or read from cache) + ttgir, ttgir_md5, force_compile, _ = 
read_or_execute(fn_cache_manager, force_compile, f"{name}.ttgir", metadata, + run_if_found = lambda path: _triton.ir.parse_mlir_module(path, context), + run_if_not_found = lambda: ttir_to_ttgir(ttir, num_warps, num_stages)) + # triton-gpu-ir -> llvm-ir (or read from cache) + llir, llir_md5, force_compile, llvm_cached = read_or_execute(fn_cache_manager, force_compile, f"{name}.llir", metadata, + run_if_found = lambda path: Path(path).read_bytes(), + run_if_not_found = lambda: ttgir_to_llir(ttgir, extern_libs)) + if llvm_cached: + shmem_size = metadata["shared"] + else: + shmem_size = _triton.get_shared_memory_size(ttgir) + # llvm-ir -> ptx (or read from cache) + ptx, ptx_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.ptx", metadata, + run_if_found = lambda path: Path(path).read_text(), + run_if_not_found = lambda: llir_to_ptx(llir)) + # ptx -> cubin (or read from cache) + cubin, cubin_md5, force_compile, _ = read_or_execute(fn_cache_manager, force_compile, f"{name}.cubin", metadata, + run_if_found = lambda path: Path(path).read_bytes(), + run_if_not_found= lambda: ptx_to_cubin(ptx, device)) + # dump new metadata + kernel_name = ptx_get_kernel_name(ptx) + metadata = {"name": kernel_name, "shared": shmem_size, "num_warps": num_warps, "num_stages": num_stages, + "md5": { "cubin": cubin_md5, "ptx": ptx_md5, + "llir": llir_md5, + "ttir": ttir_md5, "ttgir": ttgir_md5 }} + fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False) - return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir) + asm = {"ttir": ttir, "ttgir": ttgir, "llir": llir, "ptx": ptx, "cubin": cubin} + return CompiledKernel(so_path, metadata, asm) class CompiledKernel: - def __init__(self, fn_name, so_path, cache_dir): - + def __init__(self, so_path, metadata, asm): # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location("launcher", so_path) @@ -1359,18 +1376,11 @@ class CompiledKernel: spec.loader.exec_module(mod) self.c_wrapper = getattr(mod, "launch") # initialize metadata - with open(os.path.join(cache_dir, f"{fn_name}.json")) as f: - metadata = json.load(f) self.shared = metadata["shared"] self.num_warps = metadata["num_warps"] self.num_stages = metadata["num_stages"] # initialize asm dict - self.asm = dict() - with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f: - self.asm["cubin"] = f.read() - with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f: - self.asm["ptx"] = f.read() - + self.asm = asm device = torch.cuda.current_device() global cuda_utils if cuda_utils is None: diff --git a/python/triton/tools/aot.py b/python/triton/tools/aot.py index 72df49d4c..28d9e3500 100644 --- a/python/triton/tools/aot.py +++ b/python/triton/tools/aot.py @@ -38,14 +38,13 @@ if __name__ == '__main__': exit(0) # triton-ir -> triton-gpu-ir - module = triton.compiler.make_tritongpu_ir(module, num_warps=4) - module = triton.compiler.optimize_tritongpu_ir(module, num_stages=3) + module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3) if args.target == 'triton-gpu-ir': print(module.str()) exit(0) # triton-gpu-ir -> llvm-ir - module = triton.compiler.make_llvm_ir(module) + module = triton.compiler.ttgir_to_llir(module, extern_libs=None) if args.target == 'llvm-ir': print(module) exit(0) @@ -56,6 +55,6 @@ if __name__ == '__main__': raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation") # llvm-ir -> ptx - module = triton.compiler.make_ptx(module, 
compute_capability=args.sm, ptx_version=args.ptx_version) + module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version) assert args.target == 'ptx' print(module)
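
Example of the workflow this patch enables: a minimal sketch, assuming the
calling convention of the deleted tests and the default on-disk cache location
(typically under ~/.triton/cache). The `add_one` kernel is hypothetical and
exists only to exercise the cache; everything else uses the APIs added above.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_one(x_ptr, n, BLOCK_SIZE: tl.constexpr):
        # hypothetical toy kernel, not part of this patch
        offsets = tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offsets, mask=offsets < n)
        tl.store(x_ptr + offsets, x + 1, mask=offsets < n)

    # First call: each stage is compiled and written to the per-kernel cache
    # directory as add_one.ttir, add_one.ttgir, add_one.llir, add_one.ptx and
    # add_one.cubin, with per-stage md5 sums recorded in add_one.json.
    kernel = triton.compiler.compile(add_one, signature="*fp32,i32",
                                     device=torch.cuda.current_device(),
                                     constants={"BLOCK_SIZE": 256})
    print(kernel.asm["ptx"])  # asm now also carries "ttir", "ttgir", "llir"

    # Hand-edit add_one.ptx in the cache directory, then call compile() again:
    # read_or_execute() finds that the file's md5 no longer matches the one in
    # add_one.json, sets force_compile, and so only the downstream
    # ptx -> cubin step is redone, from the patched PTX.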