[Triton-MLIR] Add compute capability (#902)

Add the compute capability from the Python frontend to the backend: compile() now queries the device once and threads the value through ttir_to_ttgir, ttgir_to_llir, llir_to_ptx, and ptx_to_cubin, instead of individual stages querying the current CUDA device themselves.
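For reference, the capability travels as a single integer encoded major * 10 + minor, exactly as the code hoisted into compile() below computes it. A minimal sketch of the frontend side (the helper name is ours, for illustration only):

    import torch

    def query_compute_capability(device=None):
        # An A100 reports (8, 0) -> 80, i.e. sm_80; an RTX 3090 reports (8, 6) -> 86.
        if device is None:
            device = torch.cuda.current_device()
        major, minor = torch.cuda.get_device_capability(device)
        return major * 10 + minor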

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
ben-zhang-609
2022-11-23 03:08:23 +08:00
committed by GitHub
parent 2afebcd79b
commit 07786dc932
7 changed files with 123 additions and 83 deletions


@@ -13,7 +13,7 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUCoalescePass();
-std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
+std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUVerifier();
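The default of 80 (Ampere) keeps existing C++ call sites compiling unchanged; the Python bindings updated below always pass the capability explicitly.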


@@ -41,7 +41,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
  let summary = "coalesce";
  let description = [{
    TODO
  }];
  let constructor = "mlir::createTritonGPUCoalescePass()";
@@ -63,6 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];

+  let options = [
+    Option<"computeCapability", "compute-capability",
+           "int32_t", /*default*/"80",
+           "device compute capability">
+  ];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
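With the option registered in TableGen, the generated pass base class gains a computeCapability field (used by TritonGPUCombineOpsPass below), and the capability should also be settable from the repo's opt-style driver along the lines of triton-opt -tritongpu-combine='compute-capability=90' kernel.mlir. That invocation is a sketch based on generic MLIR pass-option syntax; the driver itself is not part of this diff.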


@@ -15,6 +15,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include <memory>
@@ -560,11 +561,37 @@ public:
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
+static int computeCapabilityToMMAVersion(int computeCapability) {
+  if (computeCapability < 80) {
+    return 1;
+  } else if (computeCapability < 90) {
+    return 2;
+  } else {
+    assert(false && "computeCapability > 90 not supported");
+    return 0;
+  }
+}
+
+static SmallVector<int64_t, 2>
+mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
+                         int numWarps) {
+  if (version == 1) {
+    return {16, 16};
+  } else if (version == 2) {
+    return {16, 8};
+  } else {
+    assert(false && "version not supported");
+    return {0, 0};
+  }
+}
+
class BlockedToMMA : public mlir::RewritePattern {
+  int computeCapability;

public:
-  BlockedToMMA(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
+  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
+        computeCapability(computeCapability) {}

  static SmallVector<unsigned, 2>
  getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
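Worked through the new mapping: capabilities below 80 (sm_70 Volta, sm_75 Turing) select MMA v1 and a {16, 16} per-warp tile; 80 through 89 (sm_80/sm_86 Ampere) select v2 and {16, 8}; anything else trips the assert (which fires for capabilities >= 90, despite the "> 90" wording in its message).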
@@ -572,7 +599,8 @@ public:
    // TODO: Handle one warp per row for fused matmuls
    // TODO: unsigned -> int64_t to keep things uniform
    SmallVector<unsigned, 2> ret = {1, 1};
-    SmallVector<int64_t, 2> shapePerWarp = {16, 8};
+    SmallVector<int64_t, 2> shapePerWarp =
+        mmaVersionToShapePerWarp(version, shape, numWarps);
    bool changed = false;
    // TODO (@daadaada): double-check.
    // original logic in
@@ -615,11 +643,12 @@ public:
    auto retShape = oldRetType.getShape();
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    auto newRetType =
-        RankedTensorType::get(retShape, oldRetType.getElementType(),
-                              triton::gpu::MmaEncodingAttr::get(
-                                  oldRetType.getContext(), 2,
-                                  getWarpsPerTile(retShape, 2, numWarps)));
+    int version = computeCapabilityToMMAVersion(computeCapability);
+    auto newRetType = RankedTensorType::get(
+        retShape, oldRetType.getElementType(),
+        triton::gpu::MmaEncodingAttr::get(
+            oldRetType.getContext(), version,
+            getWarpsPerTile(retShape, version, numWarps)));
    // convert accumulator
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -656,6 +685,10 @@ public:
class TritonGPUCombineOpsPass
    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
public:
+  TritonGPUCombineOpsPass() = default;
+  TritonGPUCombineOpsPass(int computeCapability) {
+    this->computeCapability = computeCapability;
+  }
+
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
@@ -667,7 +700,7 @@ public:
    patterns.add<RematerializeBackward>(context);
    patterns.add<RematerializeForward>(context);
    patterns.add<MoveConvertOutOfLoop>(context);
-    patterns.add<BlockedToMMA>(context);
+    patterns.add<BlockedToMMA>(context, computeCapability);

    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
@@ -675,6 +708,7 @@ public:
  }
};

-std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
-  return std::make_unique<TritonGPUCombineOpsPass>();
+std::unique_ptr<Pass>
+mlir::createTritonGPUCombineOpsPass(int computeCapability) {
+  return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
}


@@ -1279,8 +1279,9 @@ void init_triton_ir(py::module &&m) {
             self.addPass(mlir::createTritonGPUPrefetchPass());
           })
      .def("add_triton_gpu_combine_pass",
-           [](mlir::PassManager &self) {
-             self.addPass(mlir::createTritonGPUCombineOpsPass());
+           [](mlir::PassManager &self, int computeCapability) {
+             self.addPass(
+                 mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
      .def("add_triton_gpu_to_llvm",
           [](mlir::PassManager &self) {
@@ -1301,7 +1302,7 @@ void init_triton_translation(py::module &m) {
  m.def(
      "translate_triton_gpu_to_llvmir",
-      [](mlir::ModuleOp op) {
+      [](mlir::ModuleOp op, int computeCapability) {
        llvm::LLVMContext llvmContext;
        auto llvmModule =
            ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
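On the Python side both bindings now require the capability. A sketch of the resulting call shapes, assuming mod is a TTGIR module and pm a pass manager from _triton.ir:

    # The combine pass needs the capability to pick the MMA version.
    pm.add_triton_gpu_combine_pass(compute_capability)
    # The LLVM IR translation now takes it as well.
    llir = _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)

Note that the C++ lambda above accepts computeCapability but, as shown, still calls translateTritonGPUToLLVMIR without it; only the binding signature changes in this hunk.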


@@ -15,10 +15,10 @@ import sysconfig
import tempfile
import warnings
from collections import namedtuple
-from pathlib import Path
from sysconfig import get_paths
from typing import Any, Callable, Dict, Tuple, Union
+from pathlib import Path

import setuptools
import torch
from filelock import FileLock
@@ -828,7 +828,7 @@ def kernel_suffix(signature, specialization):
def build_triton_ir(fn, signature, specialization, constants):
    # canonicalize signature
    if isinstance(signature, str):
        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
    context = _triton.ir.context()
    context.load_triton()
    # create kernel prototype
@@ -876,23 +876,23 @@ def ast_to_ttir(fn, signature, specialization, constants):
    return optimize_triton_ir(mod)


-def ttir_to_ttgir(mod, num_warps, num_stages):
+def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_convert_triton_to_tritongpu_pass(num_warps)
    pm.enable_debug()
    # Convert blocked layout to mma layout for dot ops so that pipeline
    # can get shared memory swizzled correctly.
    pm.add_coalesce_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
    pm.add_tritongpu_pipeline_pass(num_stages)
    # Prefetch must be done after pipeline pass because pipeline pass
    # extracts slices from the original tensor.
    pm.add_tritongpu_prefetch_pass()
    pm.add_canonicalizer_pass()
    pm.add_cse_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
    pm.add_licm_pass()
-    pm.add_triton_gpu_combine_pass()
+    pm.add_triton_gpu_combine_pass(compute_capability)
    pm.add_cse_pass()
    pm.run(mod)
    return mod
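Since the combine pass runs three times in this pipeline (before pipelining, after the canonicalizer and CSE, and again after LICM), the capability must be threaded through every invocation, not just the first.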
@@ -905,13 +905,13 @@ def add_external_libs(mod, libs):
    _triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))


-def ttgir_to_llir(mod, extern_libs):
+def ttgir_to_llir(mod, extern_libs, compute_capability):
    if extern_libs:
        add_external_libs(mod, extern_libs)
-    return _triton.translate_triton_gpu_to_llvmir(mod)
+    return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)


-def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
+def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]:
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
@@ -919,26 +919,20 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
        - PTX code
        - shared memory allocation size
    '''
-    if compute_capability is None:
-        device = torch.cuda.current_device()
-        compute_capability = torch.cuda.get_device_capability(device)
-        compute_capability = compute_capability[0] * 10 + compute_capability[1]
    if ptx_version is None:
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)


-def ptx_to_cubin(ptx: str, device: int):
+def ptx_to_cubin(ptx: str, compute_capability: int):
    '''
    Compile TritonGPU module to cubin.
    :param ptx: ptx code
-    :param device: CUDA device
+    :param compute_capability: compute capability
    :return: str
    '''
    ptxas, _ = path_to_ptxas()
-    compute_capability = torch.cuda.get_device_capability(device)
-    compute_capability = compute_capability[0] * 10 + compute_capability[1]
    return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
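llir_to_ptx and ptx_to_cubin are now device-agnostic: they take the pre-encoded integer (e.g. (8, 6) -> 86) rather than querying torch themselves, which is what lets the AOT tool below drive them from a plain --sm flag.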
@@ -1190,7 +1184,7 @@ class CacheManager:
            return
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        with FileLock(self.lock_path):
@@ -1292,18 +1286,20 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
                       run_if_not_found: Callable = None):
    suffix = file_name.split(".")[1]
    if not force_compile and cache_manager.has_file(file_name):
        module = run_if_found(cache_manager._make_path(file_name))
        data = module if isinstance(module, bytes) else str(module).encode("utf-8")
        md5 = hashlib.md5(data).hexdigest()
        has_changed = metadata and md5 != metadata["md5"][suffix]
        return module, md5, has_changed, True
    module = run_if_not_found()
    data = module if isinstance(module, bytes) else str(module).encode("utf-8")
    md5 = hashlib.md5(data).hexdigest()
    cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
    return module, md5, True, False
#
#
def make_stub(name, signature, constants):
    # name of files that are cached
    so_cache_key = make_so_cache_key(signature, constants)
@@ -1325,9 +1321,10 @@ def make_stub(name, signature, constants):
def convert_type_repr(x):
    match = re.search(r'!tt\.ptr<(.*)>', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x

def make_hash(fn, **kwargs):
    if isinstance(fn, triton.runtime.JITFunction):
        configs = kwargs["configs"]
@@ -1344,14 +1341,13 @@ def make_hash(fn, **kwargs):
    return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()


# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
    # we get the kernel, i.e. the first function generated in the module
    # if fn is not a JITFunction, then it
    # has to be a path to a file
    context = _triton.ir.context()
-    asm, md5 = dict(), dict()
+    asm = dict()
    constants = kwargs.get("constants", dict())
    if isinstance(fn, triton.runtime.JITFunction):
        configs = kwargs.get("configs", None)
@@ -1374,64 +1370,66 @@ def compile(fn, **kwargs):
        param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
        signature = {k: v for k, v in enumerate(param_tys)}
        first_stage = 2
    # cache manager
    so_path = make_stub(name, signature, constants)
    # create cache manager
    fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
    # determine name and extension type of provided function
    if isinstance(fn, triton.runtime.JITFunction):
        name, ext = fn.__name__, "ast"
    else:
        name, ext = os.path.basename(fn).split(".")
    # initialize compilation params
    num_warps = kwargs.get("num_warps", 4)
    num_stages = kwargs.get("num_stages", 3)
    extern_libs = kwargs.get("extern_libs", dict())
    device = kwargs.get("device", torch.cuda.current_device())
+    compute_capability = torch.cuda.get_device_capability(device)
+    compute_capability = compute_capability[0] * 10 + compute_capability[1]
    # load metadata if any
    metadata = None
    if fn_cache_manager.has_file(f'{name}.json'):
        with open(fn_cache_manager._make_path(f"{name}.json")) as f:
            metadata = json.load(f)
    else:
        metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
    # build compilation stages
    stages = {
-        "ast" : (lambda path: fn, None),
-        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
-        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
-        "llir": (lambda path: Path(path).read_bytes(),
-                 lambda src: ttgir_to_llir(src, extern_libs)),
-        "ptx": (lambda path: Path(path).read_text(),
-                llir_to_ptx),
-        "cubin": (lambda path: Path(path).read_bytes(),
-                  lambda src: ptx_to_cubin(src, device))
+        "ast": (lambda path: fn, None),
+        "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
+        "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
+                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
+        "llir": (lambda path: Path(path).read_bytes(),
+                 lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
+        "ptx": (lambda path: Path(path).read_text(),
+                lambda src: llir_to_ptx(src, compute_capability)),
+        "cubin": (lambda path: Path(path).read_bytes(),
+                  lambda src: ptx_to_cubin(src, compute_capability))
    }
    first_stage = list(stages.keys()).index(ext)
    asm = dict()
    module = fn
    # run compilation pipeline and populate metadata
    for ir, (parse, compile) in list(stages.items())[first_stage:]:
        path = fn_cache_manager._make_path(f"{name}.{ir}")
        if ir == ext:
            next_module = parse(fn)
        elif os.path.exists(path) and \
                os.path.getctime(path) == metadata["ctime"][ir]:
            next_module = parse(path)
        else:
            next_module = compile(module)
            fn_cache_manager.put(next_module, f"{name}.{ir}")
        if os.path.exists(path):
            metadata["ctime"][ir] = os.path.getctime(path)
        asm[ir] = next_module if ir == "cubin" else str(next_module)
        if ir == "llir" and "shared" not in metadata:
            metadata["shared"] = _triton.get_shared_memory_size(module)
        if ir == "ptx":
            metadata["name"] = ptx_get_kernel_name(next_module)
        module = next_module
    # write-back metadata
    fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
    # return handle to compiled kernel
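A hedged usage sketch of the updated entry point; the kernel, signature string, and keyword values are illustrative, with the keyword names taken from the commented-out signature above:

    import triton

    # add_kernel is assumed to be a triton.runtime.JITFunction defined elsewhere.
    binary = triton.compiler.compile(add_kernel,
                                     signature="*fp32, *fp32, i32",
                                     device=0,  # capability is derived from this, once
                                     num_warps=4,
                                     num_stages=3)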
@@ -1494,7 +1492,7 @@ class CudaUtils(object):
-#include <cuda.h>
+#include \"cuda.h\"
#define PY_SSIZE_T_CLEAN
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)


@@ -37,20 +37,21 @@ if __name__ == '__main__':
        print(module.str())
        exit(0)

+    if not args.sm:
+        raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
+
    # triton-ir -> triton-gpu-ir
-    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
    if args.target == 'triton-gpu-ir':
        print(module.str())
        exit(0)

    # triton-gpu-ir -> llvm-ir
-    module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
+    module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
    if args.target == 'llvm-ir':
        print(module)
        exit(0)

-    if not args.sm:
-        raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
    if not args.ptx_version:
        raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
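End to end, the AOT flow is now driven as in the updated test below, e.g. python -m triton.tools.aot kernel.ttir --target=llvm-ir --sm=80 (filename illustrative); emitting PTX additionally requires --ptx-version.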


@@ -1,4 +1,4 @@
-// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir | FileCheck %s
+// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'