From 07786dc932c4abf38c900947550148c466aafd9e Mon Sep 17 00:00:00 2001 From: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Date: Wed, 23 Nov 2022 03:08:23 +0800 Subject: [PATCH] [Triton-MLIR] Add compute capability (#902) add compute capability from python frontend to backend. Co-authored-by: Keren Zhou --- .../Dialect/TritonGPU/Transforms/Passes.h | 2 +- .../Dialect/TritonGPU/Transforms/Passes.td | 8 +- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 56 ++++++-- python/src/triton.cc | 7 +- python/triton/compiler.py | 122 +++++++++--------- python/triton/tools/aot.py | 9 +- test/Target/tritongpu_to_llvmir.mlir | 2 +- 7 files changed, 123 insertions(+), 83 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index e570d60d5..7e02fb2b9 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -13,7 +13,7 @@ std::unique_ptr createTritonGPUCanonicalizeLoopsPass(); std::unique_ptr createTritonGPUCoalescePass(); -std::unique_ptr createTritonGPUCombineOpsPass(); +std::unique_ptr createTritonGPUCombineOpsPass(int computeCapability = 80); std::unique_ptr createTritonGPUVerifier(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index caa85a950..f22a76c55 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -41,7 +41,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let summary = "coalesce"; let description = [{ - TODO + TODO }]; let constructor = "mlir::createTritonGPUCoalescePass()"; @@ -63,6 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; } def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> { diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 6133f9381..fb6124d52 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -15,6 +15,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include @@ -560,11 +561,37 @@ public: // ----------------------------------------------------------------------------- // // ----------------------------------------------------------------------------- +static int computeCapabilityToMMAVersion(int computeCapability) { + if (computeCapability < 80) { + return 1; + } else if (computeCapability < 90) { + return 2; + } else { + assert(false && "computeCapability > 90 not supported"); + return 0; + } +} + +static SmallVector +mmaVersionToShapePerWarp(int version, const ArrayRef &shape, + int numWarps) { + if (version == 1) { + return {16, 16}; + } else if (version == 2) { + return {16, 8}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} class BlockedToMMA : public mlir::RewritePattern { + int computeCapability; + public: - BlockedToMMA(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, 
context) {} + BlockedToMMA(mlir::MLIRContext *context, int computeCapability) + : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} static SmallVector getWarpsPerTile(const ArrayRef &shape, int version, int numWarps) { @@ -572,7 +599,8 @@ public: // TODO: Handle one warp per row for fused matmuls // TODO: unsigned -> int64_t to keep things uniform SmallVector ret = {1, 1}; - SmallVector shapePerWarp = {16, 8}; + SmallVector shapePerWarp = + mmaVersionToShapePerWarp(version, shape, numWarps); bool changed = false; // TODO (@daadaada): double-check. // original logic in @@ -615,11 +643,12 @@ public: auto retShape = oldRetType.getShape(); auto mod = op->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - auto newRetType = - RankedTensorType::get(retShape, oldRetType.getElementType(), - triton::gpu::MmaEncodingAttr::get( - oldRetType.getContext(), 2, - getWarpsPerTile(retShape, 2, numWarps))); + int version = computeCapabilityToMMAVersion(computeCapability); + auto newRetType = RankedTensorType::get( + retShape, oldRetType.getElementType(), + triton::gpu::MmaEncodingAttr::get( + oldRetType.getContext(), version, + getWarpsPerTile(retShape, version, numWarps))); // convert accumulator auto oldAcc = dotOp.getOperand(2); auto newAcc = rewriter.create( @@ -656,6 +685,10 @@ public: class TritonGPUCombineOpsPass : public TritonGPUCombineOpsBase { public: + TritonGPUCombineOpsPass() = default; + TritonGPUCombineOpsPass(int computeCapability) { + this->computeCapability = computeCapability; + } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); @@ -667,7 +700,7 @@ public: patterns.add(context); patterns.add(context); patterns.add(context); - patterns.add(context); + patterns.add(context, computeCapability); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); @@ -675,6 +708,7 @@ public: } }; -std::unique_ptr mlir::createTritonGPUCombineOpsPass() { - return std::make_unique(); +std::unique_ptr +mlir::createTritonGPUCombineOpsPass(int computeCapability) { + return std::make_unique(computeCapability); } diff --git a/python/src/triton.cc b/python/src/triton.cc index f7c29dd4a..f133b6c14 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1279,8 +1279,9 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::createTritonGPUPrefetchPass()); }) .def("add_triton_gpu_combine_pass", - [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPUCombineOpsPass()); + [](mlir::PassManager &self, int computeCapability) { + self.addPass( + mlir::createTritonGPUCombineOpsPass(computeCapability)); }) .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) { @@ -1301,7 +1302,7 @@ void init_triton_translation(py::module &m) { m.def( "translate_triton_gpu_to_llvmir", - [](mlir::ModuleOp op) { + [](mlir::ModuleOp op, int computeCapability) { llvm::LLVMContext llvmContext; auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op); diff --git a/python/triton/compiler.py b/python/triton/compiler.py index d74b1f4fd..82e8168aa 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -15,10 +15,10 @@ import sysconfig import tempfile import warnings from collections import namedtuple +from pathlib import Path from sysconfig import get_paths from typing import Any, Callable, Dict, Tuple, Union -from pathlib import Path import setuptools import torch from filelock import FileLock @@ -828,7 
+828,7 @@ def kernel_suffix(signature, specialization): def build_triton_ir(fn, signature, specialization, constants): # canonicalize signature if isinstance(signature, str): - signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} context = _triton.ir.context() context.load_triton() # create kernel prototype @@ -876,23 +876,23 @@ def ast_to_ttir(fn, signature, specialization, constants): return optimize_triton_ir(mod) -def ttir_to_ttgir(mod, num_warps, num_stages): +def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.enable_debug() # Convert blocked layout to mma layout for dot ops so that pipeline # can get shared memory swizzled correctly. pm.add_coalesce_pass() - pm.add_triton_gpu_combine_pass() + pm.add_triton_gpu_combine_pass(compute_capability) pm.add_tritongpu_pipeline_pass(num_stages) # Prefetch must be done after pipeline pass because pipeline pass # extracts slices from the original tensor. pm.add_tritongpu_prefetch_pass() pm.add_canonicalizer_pass() pm.add_cse_pass() - pm.add_triton_gpu_combine_pass() + pm.add_triton_gpu_combine_pass(compute_capability) pm.add_licm_pass() - pm.add_triton_gpu_combine_pass() + pm.add_triton_gpu_combine_pass(compute_capability) pm.add_cse_pass() pm.run(mod) return mod @@ -905,13 +905,13 @@ def add_external_libs(mod, libs): _triton.add_external_libs(mod, list(libs.keys()), list(libs.values())) -def ttgir_to_llir(mod, extern_libs): +def ttgir_to_llir(mod, extern_libs, compute_capability): if extern_libs: add_external_libs(mod, extern_libs) - return _triton.translate_triton_gpu_to_llvmir(mod) + return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability) -def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]: +def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]: ''' Translate TritonGPU module to PTX code. :param mod: a TritonGPU dialect module @@ -919,26 +919,20 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non - PTX code - shared memory allocation size ''' - if compute_capability is None: - device = torch.cuda.current_device() - compute_capability = torch.cuda.get_device_capability(device) - compute_capability = compute_capability[0] * 10 + compute_capability[1] if ptx_version is None: _, cuda_version = path_to_ptxas() ptx_version = ptx_get_version(cuda_version) return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version) -def ptx_to_cubin(ptx: str, device: int): +def ptx_to_cubin(ptx: str, compute_capability: int): ''' Compile TritonGPU module to cubin. 
:param ptx: ptx code - :param device: CUDA device + :param compute_capability: compute capability :return: str ''' ptxas, _ = path_to_ptxas() - compute_capability = torch.cuda.get_device_capability(device) - compute_capability = compute_capability[0] * 10 + compute_capability[1] return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability) @@ -1190,7 +1184,7 @@ class CacheManager: return binary = isinstance(data, bytes) if not binary: - data = str(data) + data = str(data) assert self.lock_path is not None filepath = self._make_path(filename) with FileLock(self.lock_path): @@ -1292,18 +1286,20 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata, run_if_not_found: Callable = None): suffix = file_name.split(".")[1] if not force_compile and cache_manager.has_file(file_name): - module = run_if_found(cache_manager._make_path(file_name)) - data = module if isinstance(module, bytes) else str(module).encode("utf-8") - md5 = hashlib.md5(data).hexdigest() - has_changed = metadata and md5 != metadata["md5"][suffix] - return module, md5, has_changed, True + module = run_if_found(cache_manager._make_path(file_name)) + data = module if isinstance(module, bytes) else str(module).encode("utf-8") + md5 = hashlib.md5(data).hexdigest() + has_changed = metadata and md5 != metadata["md5"][suffix] + return module, md5, has_changed, True module = run_if_not_found() data = module if isinstance(module, bytes) else str(module).encode("utf-8") md5 = hashlib.md5(data).hexdigest() cache_manager.put(data, file_name, True if isinstance(data, bytes) else data) return module, md5, True, False -# +# + + def make_stub(name, signature, constants): # name of files that are cached so_cache_key = make_so_cache_key(signature, constants) @@ -1325,9 +1321,10 @@ def make_stub(name, signature, constants): def convert_type_repr(x): match = re.search(r'!tt\.ptr<(.*)>', x) if match is not None: - return '*' + convert_type_repr(match.group(1)) + return '*' + convert_type_repr(match.group(1)) return x + def make_hash(fn, **kwargs): if isinstance(fn, triton.runtime.JITFunction): configs = kwargs["configs"] @@ -1344,14 +1341,13 @@ def make_hash(fn, **kwargs): return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest() - # def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None): def compile(fn, **kwargs): # we get the kernel, i.e. 
the first function generated in the module - # if fn is not a JITFunction, then it + # if fn is not a JITFunction, then it # has to be a path to a file context = _triton.ir.context() - asm, md5 = dict(), dict() + asm = dict() constants = kwargs.get("constants", dict()) if isinstance(fn, triton.runtime.JITFunction): configs = kwargs.get("configs", None) @@ -1374,64 +1370,66 @@ def compile(fn, **kwargs): param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()] signature = {k: v for k, v in enumerate(param_tys)} first_stage = 2 - + # cache manager so_path = make_stub(name, signature, constants) # create cache manager fn_cache_manager = CacheManager(make_hash(fn, **kwargs)) # determine name and extension type of provided function if isinstance(fn, triton.runtime.JITFunction): - name, ext = fn.__name__, "ast" + name, ext = fn.__name__, "ast" else: - name, ext = os.path.basename(fn).split(".") + name, ext = os.path.basename(fn).split(".") # initialize compilation params num_warps = kwargs.get("num_warps", 4) num_stages = kwargs.get("num_stages", 3) extern_libs = kwargs.get("extern_libs", dict()) device = kwargs.get("device", torch.cuda.current_device()) + compute_capability = torch.cuda.get_device_capability(device) + compute_capability = compute_capability[0] * 10 + compute_capability[1] # load metadata if any metadata = None if fn_cache_manager.has_file(f'{name}.json'): - with open(fn_cache_manager._make_path(f"{name}.json")) as f: + with open(fn_cache_manager._make_path(f"{name}.json")) as f: metadata = json.load(f) else: - metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()} + metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()} # build compilation stages stages = { - "ast" : (lambda path: fn, None), - "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ast_to_ttir(src, signature, configs[0], constants)), - "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ttir_to_ttgir(src, num_warps, num_stages)), - "llir": (lambda path: Path(path).read_bytes(), - lambda src: ttgir_to_llir(src, extern_libs)), - "ptx": (lambda path: Path(path).read_text(), - llir_to_ptx), - "cubin": (lambda path: Path(path).read_bytes(), - lambda src: ptx_to_cubin(src, device)) + "ast": (lambda path: fn, None), + "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), + lambda src: ast_to_ttir(src, signature, configs[0], constants)), + "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), + lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)), + "llir": (lambda path: Path(path).read_bytes(), + lambda src: ttgir_to_llir(src, extern_libs, compute_capability)), + "ptx": (lambda path: Path(path).read_text(), + lambda src: llir_to_ptx(src, compute_capability)), + "cubin": (lambda path: Path(path).read_bytes(), + lambda src: ptx_to_cubin(src, compute_capability)) } first_stage = list(stages.keys()).index(ext) asm = dict() module = fn # run compilation pipeline and populate metadata for ir, (parse, compile) in list(stages.items())[first_stage:]: - path = fn_cache_manager._make_path(f"{name}.{ir}") - if ir == ext: - next_module = parse(fn) - elif os.path.exists(path) and\ - os.path.getctime(path) == metadata["ctime"][ir]: - next_module = parse(path) - else: - next_module = compile(module) - fn_cache_manager.put(next_module, f"{name}.{ir}") - if os.path.exists(path): - metadata["ctime"][ir] = os.path.getctime(path) - asm[ir] = next_module if ir == 
"cubin" else str(next_module) - if ir == "llir" and "shared" not in metadata: - metadata["shared"] = _triton.get_shared_memory_size(module) - if ir == "ptx": - metadata["name"] = ptx_get_kernel_name(next_module) - module = next_module + path = fn_cache_manager._make_path(f"{name}.{ir}") + if ir == ext: + next_module = parse(fn) + elif os.path.exists(path) and\ + os.path.getctime(path) == metadata["ctime"][ir]: + next_module = parse(path) + else: + next_module = compile(module) + fn_cache_manager.put(next_module, f"{name}.{ir}") + if os.path.exists(path): + metadata["ctime"][ir] = os.path.getctime(path) + asm[ir] = next_module if ir == "cubin" else str(next_module) + if ir == "llir" and "shared" not in metadata: + metadata["shared"] = _triton.get_shared_memory_size(module) + if ir == "ptx": + metadata["name"] = ptx_get_kernel_name(next_module) + module = next_module # write-back metadata fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False) # return handle to compiled kernel @@ -1494,7 +1492,7 @@ class CudaUtils(object): #include #include \"cuda.h\" - #define PY_SSIZE_T_CLEAN + #define PY_SSIZE_T_CLEAN #include static inline void gpuAssert(CUresult code, const char *file, int line) diff --git a/python/triton/tools/aot.py b/python/triton/tools/aot.py index 28d9e3500..7b5a59fe0 100644 --- a/python/triton/tools/aot.py +++ b/python/triton/tools/aot.py @@ -37,20 +37,21 @@ if __name__ == '__main__': print(module.str()) exit(0) + if not args.sm: + raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation") + # triton-ir -> triton-gpu-ir - module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3) + module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm) if args.target == 'triton-gpu-ir': print(module.str()) exit(0) # triton-gpu-ir -> llvm-ir - module = triton.compiler.ttgir_to_llir(module, extern_libs=None) + module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm) if args.target == 'llvm-ir': print(module) exit(0) - if not args.sm: - raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation") if not args.ptx_version: raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation") diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir index b852a1ad6..cafff3ca6 100644 --- a/test/Target/tritongpu_to_llvmir.mlir +++ b/test/Target/tritongpu_to_llvmir.mlir @@ -1,4 +1,4 @@ -// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir | FileCheck %s +// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s // == LLVM IR check begin == // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'