[Triton-MLIR] Add compute capability (#902)

Add compute capability from the Python frontend to the backend.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Authored by ben-zhang-609 on 2022-11-23 03:08:23 +08:00; committed by GitHub
parent 2afebcd79b
commit 07786dc932
7 changed files with 123 additions and 83 deletions
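
The heart of the change on the Python side: compile() now queries the device capability once and passes it explicitly through every lowering helper, instead of letting the helpers (or the C++ pass defaults) decide. A minimal sketch of the resulting flow, using only names that appear in this diff; ttir_mod is a placeholder for a Triton-IR module produced by ast_to_ttir:

import torch
import triton.compiler as compiler

# Derive the integer capability (e.g. 80 for sm_80), exactly as compile() now does.
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
compute_capability = capability[0] * 10 + capability[1]

# Thread it through the whole lowering chain.
ttgir = compiler.ttir_to_ttgir(ttir_mod, num_warps=4, num_stages=3,
                               compute_capability=compute_capability)
llir = compiler.ttgir_to_llir(ttgir, extern_libs=None,
                              compute_capability=compute_capability)
ptx = compiler.llir_to_ptx(llir, compute_capability)
cubin = compiler.ptx_to_cubin(ptx, compute_capability)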

View File

@@ -13,7 +13,7 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
 
 std::unique_ptr<Pass> createTritonGPUCoalescePass();
 
-std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
+std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();

View File

@@ -63,6 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                            "mlir::triton::TritonDialect"];
+
+  let options = [
+    Option<"computeCapability", "compute-capability",
+           "int32_t", /*default*/"80",
+           "device compute capability">
+  ];
 }
 
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {

View File

@@ -15,6 +15,7 @@
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 
 #include <memory>
@@ -560,11 +561,37 @@ public:
 // -----------------------------------------------------------------------------
 //
 // -----------------------------------------------------------------------------
+
+static int computeCapabilityToMMAVersion(int computeCapability) {
+  if (computeCapability < 80) {
+    return 1;
+  } else if (computeCapability < 90) {
+    return 2;
+  } else {
+    assert(false && "computeCapability > 90 not supported");
+    return 0;
+  }
+}
+
+static SmallVector<int64_t, 2>
+mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
+                         int numWarps) {
+  if (version == 1) {
+    return {16, 16};
+  } else if (version == 2) {
+    return {16, 8};
+  } else {
+    assert(false && "version not supported");
+    return {0, 0};
+  }
+}
 
 class BlockedToMMA : public mlir::RewritePattern {
+  int computeCapability;
 public:
-  BlockedToMMA(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
+  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
+        computeCapability(computeCapability) {}
 
   static SmallVector<unsigned, 2>
   getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
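
In words: capabilities below 80 select MMA version 1 with a 16x16 per-warp tile, 80 through 89 select version 2 with a 16x8 tile, and 90 and above are rejected for now. A Python transliteration of the two new helpers, offered purely as a worked example of the mapping, not as part of the change:

def capability_to_mma_version(cc):
    # Mirrors computeCapabilityToMMAVersion in the C++ above.
    if cc < 80:
        return 1  # e.g. sm_70 (V100), sm_75 (T4)
    if cc < 90:
        return 2  # e.g. sm_80 (A100), sm_86 (RTX 30xx)
    raise NotImplementedError("compute capability >= 90 not supported yet")


def mma_version_to_shape_per_warp(version):
    # Mirrors mmaVersionToShapePerWarp: per-warp tile shape for the MMA layout.
    return {1: (16, 16), 2: (16, 8)}[version]


assert capability_to_mma_version(70) == 1
assert mma_version_to_shape_per_warp(capability_to_mma_version(80)) == (16, 8)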
@@ -572,7 +599,8 @@ public:
     // TODO: Handle one warp per row for fused matmuls
     // TODO: unsigned -> int64_t to keep things uniform
     SmallVector<unsigned, 2> ret = {1, 1};
-    SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+    SmallVector<int64_t, 2> shapePerWarp =
+        mmaVersionToShapePerWarp(version, shape, numWarps);
     bool changed = false;
     // TODO (@daadaada): double-check.
     // original logic in
@@ -615,11 +643,12 @@ public:
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    auto newRetType =
-        RankedTensorType::get(retShape, oldRetType.getElementType(),
-                              triton::gpu::MmaEncodingAttr::get(
-                                  oldRetType.getContext(), 2,
-                                  getWarpsPerTile(retShape, 2, numWarps)));
+    int version = computeCapabilityToMMAVersion(computeCapability);
+    auto newRetType = RankedTensorType::get(
+        retShape, oldRetType.getElementType(),
+        triton::gpu::MmaEncodingAttr::get(
+            oldRetType.getContext(), version,
+            getWarpsPerTile(retShape, version, numWarps)));
     // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
     auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -656,6 +685,10 @@ public:
 class TritonGPUCombineOpsPass
     : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
 public:
+  TritonGPUCombineOpsPass() = default;
+  TritonGPUCombineOpsPass(int computeCapability) {
+    this->computeCapability = computeCapability;
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
@@ -667,7 +700,7 @@ public:
     patterns.add<RematerializeBackward>(context);
     patterns.add<RematerializeForward>(context);
     patterns.add<MoveConvertOutOfLoop>(context);
-    patterns.add<BlockedToMMA>(context);
+    patterns.add<BlockedToMMA>(context, computeCapability);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
@@ -675,6 +708,7 @@ public:
   }
 };
 
-std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
-  return std::make_unique<TritonGPUCombineOpsPass>();
+std::unique_ptr<Pass>
+mlir::createTritonGPUCombineOpsPass(int computeCapability) {
+  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
 }

View File

@@ -1279,8 +1279,9 @@ void init_triton_ir(py::module &&m) {
             self.addPass(mlir::createTritonGPUPrefetchPass());
           })
      .def("add_triton_gpu_combine_pass",
-          [](mlir::PassManager &self) {
-            self.addPass(mlir::createTritonGPUCombineOpsPass());
+          [](mlir::PassManager &self, int computeCapability) {
+            self.addPass(
+                mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
      .def("add_triton_gpu_to_llvm",
           [](mlir::PassManager &self) {
@@ -1301,7 +1302,7 @@ void init_triton_translation(py::module &m) {
   m.def(
       "translate_triton_gpu_to_llvmir",
-      [](mlir::ModuleOp op) {
+      [](mlir::ModuleOp op, int computeCapability) {
        llvm::LLVMContext llvmContext;
        auto llvmModule =
            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
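
With these two binding changes, both the combine pass and the TTGIR-to-LLVM-IR translation receive the capability from Python. A sketch of driving them directly, assuming mod is an existing TritonGPU-dialect module and _triton is the compiled extension module aliased as in compiler.py:

pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_triton_gpu_combine_pass(80)  # the capability argument is now required
pm.run(mod)
llvm_ir = _triton.translate_triton_gpu_to_llvmir(mod, 80)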

View File

@@ -15,10 +15,10 @@ import sysconfig
 import tempfile
 import warnings
 from collections import namedtuple
+from pathlib import Path
 from sysconfig import get_paths
 from typing import Any, Callable, Dict, Tuple, Union
-from pathlib import Path
 
 import setuptools
 import torch
 from filelock import FileLock
@@ -876,23 +876,23 @@ def ast_to_ttir(fn, signature, specialization, constants):
     return optimize_triton_ir(mod)
 
 
-def ttir_to_ttgir(mod, num_warps, num_stages):
+def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.enable_debug()
     # Convert blocked layout to mma layout for dot ops so that pipeline
     # can get shared memory swizzled correctly.
     pm.add_coalesce_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
     # extracts slices from the original tensor.
     pm.add_tritongpu_prefetch_pass()
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_licm_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
     pm.add_cse_pass()
     pm.run(mod)
     return mod
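
All three combine invocations receive the same capability because the BlockedToMMA pattern registered by that pass (see Combine.cpp above) uses it to pick the MMA layout version. From the caller's side the only visible difference is the extra argument, for example (ttir_mod is a placeholder):

# 80 targets an sm_80 (A100-class) device.
ttgir_mod = ttir_to_ttgir(ttir_mod, num_warps=4, num_stages=3,
                          compute_capability=80)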
@@ -905,13 +905,13 @@ def add_external_libs(mod, libs):
     _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
 
 
-def ttgir_to_llir(mod, extern_libs):
+def ttgir_to_llir(mod, extern_libs, compute_capability):
     if extern_libs:
         add_external_libs(mod, extern_libs)
-    return _triton.translate_triton_gpu_to_llvmir(mod)
+    return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)
 
 
-def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]:
     '''
     Translate TritonGPU module to PTX code.
     :param mod: a TritonGPU dialect module
@@ -919,26 +919,20 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
     - PTX code
     - shared memory allocation size
     '''
-    if compute_capability is None:
-        device = torch.cuda.current_device()
-        compute_capability = torch.cuda.get_device_capability(device)
-        compute_capability = compute_capability[0] * 10 + compute_capability[1]
     if ptx_version is None:
         _, cuda_version = path_to_ptxas()
         ptx_version = ptx_get_version(cuda_version)
     return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
 
 
-def ptx_to_cubin(ptx: str, device: int):
+def ptx_to_cubin(ptx: str, compute_capability: int):
     '''
     Compile TritonGPU module to cubin.
     :param ptx: ptx code
-    :param device: CUDA device
+    :param compute_capability: compute capability
     :return: str
     '''
     ptxas, _ = path_to_ptxas()
-    compute_capability = torch.cuda.get_device_capability(device)
-    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
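
llir_to_ptx and ptx_to_cubin no longer consult torch.cuda at all; the caller supplies the target capability, which also suits the AOT tool below, where the target comes from a command-line flag rather than a local device. A short caller-side sketch with placeholder inputs:

# llvm_ir comes from ttgir_to_llir; ptxas is located via path_to_ptxas().
ptx = llir_to_ptx(llvm_ir, compute_capability=86)  # explicitly target sm_86
cubin = ptx_to_cubin(ptx, compute_capability=86)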
@@ -1304,6 +1298,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
     return module, md5, True, False
 
+
 #
+
 def make_stub(name, signature, constants):
     # name of files that are cached
     so_cache_key = make_so_cache_key(signature, constants)
@@ -1328,6 +1324,7 @@ def convert_type_repr(x):
         return '*' + convert_type_repr(match.group(1))
     return x
 
+
 def make_hash(fn, **kwargs):
     if isinstance(fn, triton.runtime.JITFunction):
         configs = kwargs["configs"]
@@ -1344,14 +1341,13 @@ def make_hash(fn, **kwargs):
     return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
 
 
-
 # def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
 def compile(fn, **kwargs):
     # we get the kernel, i.e. the first function generated in the module
     # if fn is not a JITFunction, then it
     # has to be a path to a file
     context = _triton.ir.context()
-    asm, md5 = dict(), dict()
+    asm = dict()
     constants = kwargs.get("constants", dict())
     if isinstance(fn, triton.runtime.JITFunction):
         configs = kwargs.get("configs", None)
@@ -1389,6 +1385,8 @@ def compile(fn, **kwargs):
     num_stages = kwargs.get("num_stages", 3)
     extern_libs = kwargs.get("extern_libs", dict())
     device = kwargs.get("device", torch.cuda.current_device())
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
     # load metadata if any
     metadata = None
     if fn_cache_manager.has_file(f'{name}.json'):
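
compile() is now where the capability originates: it is derived from the device kwarg (defaulting to the current CUDA device) and captured by the stage lambdas below. A hedged usage sketch; my_kernel and the signature string are placeholders, and the kwargs mirror the commented-out signature above:

import torch
import triton

# Builds for whatever capability the chosen device reports, e.g. 80 on an A100.
binary = triton.compile(my_kernel, signature="*fp32, *fp32, i32",
                        device=torch.cuda.current_device(),
                        num_warps=4, num_stages=3)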
@@ -1398,17 +1396,17 @@ def compile(fn, **kwargs):
         metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
     # build compilation stages
     stages = {
-        "ast" : (lambda path: fn, None),
+        "ast": (lambda path: fn, None),
         "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
                  lambda src: ast_to_ttir(src, signature, configs[0], constants)),
         "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
         "llir": (lambda path: Path(path).read_bytes(),
-                 lambda src: ttgir_to_llir(src, extern_libs)),
+                 lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
         "ptx": (lambda path: Path(path).read_text(),
-                llir_to_ptx),
+                lambda src: llir_to_ptx(src, compute_capability)),
         "cubin": (lambda path: Path(path).read_bytes(),
-                  lambda src: ptx_to_cubin(src, device))
+                  lambda src: ptx_to_cubin(src, compute_capability))
     }
     first_stage = list(stages.keys()).index(ext)
     asm = dict()

View File

@@ -37,20 +37,21 @@ if __name__ == '__main__':
         print(module.str())
         exit(0)
 
+    if not args.sm:
+        raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
+
     # triton-ir -> triton-gpu-ir
-    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
     if args.target == 'triton-gpu-ir':
         print(module.str())
         exit(0)
 
     # triton-gpu-ir -> llvm-ir
-    module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
+    module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
     if args.target == 'llvm-ir':
         print(module)
         exit(0)
 
-    if not args.sm:
-        raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
     if not args.ptx_version:
         raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")

View File

@@ -1,4 +1,4 @@
-// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir | FileCheck %s
+// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
 
 // == LLVM IR check begin ==
 // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'