[Triton-MLIR] Add compute capability (#902)
add compute capability from python frontend to backend. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -13,7 +13,7 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
|
@@ -41,7 +41,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
TODO
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCoalescePass()";
|
||||
@@ -63,6 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
|
@@ -15,6 +15,7 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
@@ -560,11 +561,37 @@ public:
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||
if (computeCapability < 80) {
|
||||
return 1;
|
||||
} else if (computeCapability < 90) {
|
||||
return 2;
|
||||
} else {
|
||||
assert(false && "computeCapability > 90 not supported");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static SmallVector<int64_t, 2>
|
||||
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||
int numWarps) {
|
||||
if (version == 1) {
|
||||
return {16, 16};
|
||||
} else if (version == 2) {
|
||||
return {16, 8};
|
||||
} else {
|
||||
assert(false && "version not supported");
|
||||
return {0, 0};
|
||||
}
|
||||
}
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
int computeCapability;
|
||||
|
||||
public:
|
||||
BlockedToMMA(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
||||
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||
computeCapability(computeCapability) {}
|
||||
|
||||
static SmallVector<unsigned, 2>
|
||||
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
|
||||
@@ -572,7 +599,8 @@ public:
|
||||
// TODO: Handle one warp per row for fused matmuls
|
||||
// TODO: unsigned -> int64_t to keep things uniform
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
SmallVector<int64_t, 2> shapePerWarp =
|
||||
mmaVersionToShapePerWarp(version, shape, numWarps);
|
||||
bool changed = false;
|
||||
// TODO (@daadaada): double-check.
|
||||
// original logic in
|
||||
@@ -615,11 +643,12 @@ public:
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(),
|
||||
triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), 2,
|
||||
getWarpsPerTile(retShape, 2, numWarps)));
|
||||
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||
auto newRetType = RankedTensorType::get(
|
||||
retShape, oldRetType.getElementType(),
|
||||
triton::gpu::MmaEncodingAttr::get(
|
||||
oldRetType.getContext(), version,
|
||||
getWarpsPerTile(retShape, version, numWarps)));
|
||||
// convert accumulator
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -656,6 +685,10 @@ public:
|
||||
class TritonGPUCombineOpsPass
|
||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||
public:
|
||||
TritonGPUCombineOpsPass() = default;
|
||||
TritonGPUCombineOpsPass(int computeCapability) {
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp m = getOperation();
|
||||
@@ -667,7 +700,7 @@ public:
|
||||
patterns.add<RematerializeBackward>(context);
|
||||
patterns.add<RematerializeForward>(context);
|
||||
patterns.add<MoveConvertOutOfLoop>(context);
|
||||
patterns.add<BlockedToMMA>(context);
|
||||
patterns.add<BlockedToMMA>(context, computeCapability);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
@@ -675,6 +708,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
|
||||
}
|
||||
|
@@ -1279,8 +1279,9 @@ void init_triton_ir(py::module &&m) {
|
||||
self.addPass(mlir::createTritonGPUPrefetchPass());
|
||||
})
|
||||
.def("add_triton_gpu_combine_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
||||
[](mlir::PassManager &self, int computeCapability) {
|
||||
self.addPass(
|
||||
mlir::createTritonGPUCombineOpsPass(computeCapability));
|
||||
})
|
||||
.def("add_triton_gpu_to_llvm",
|
||||
[](mlir::PassManager &self) {
|
||||
@@ -1301,7 +1302,7 @@ void init_triton_translation(py::module &m) {
|
||||
|
||||
m.def(
|
||||
"translate_triton_gpu_to_llvmir",
|
||||
[](mlir::ModuleOp op) {
|
||||
[](mlir::ModuleOp op, int computeCapability) {
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule =
|
||||
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
|
||||
|
@@ -15,10 +15,10 @@ import sysconfig
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from sysconfig import get_paths
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from pathlib import Path
|
||||
import setuptools
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -828,7 +828,7 @@ def kernel_suffix(signature, specialization):
|
||||
def build_triton_ir(fn, signature, specialization, constants):
|
||||
# canonicalize signature
|
||||
if isinstance(signature, str):
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
@@ -876,23 +876,23 @@ def ast_to_ttir(fn, signature, specialization, constants):
|
||||
return optimize_triton_ir(mod)
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps, num_stages):
|
||||
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.enable_debug()
|
||||
# Convert blocked layout to mma layout for dot ops so that pipeline
|
||||
# can get shared memory swizzled correctly.
|
||||
pm.add_coalesce_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
# Prefetch must be done after pipeline pass because pipeline pass
|
||||
# extracts slices from the original tensor.
|
||||
pm.add_tritongpu_prefetch_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||
pm.add_licm_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||
pm.add_cse_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
@@ -905,13 +905,13 @@ def add_external_libs(mod, libs):
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def ttgir_to_llir(mod, extern_libs):
|
||||
def ttgir_to_llir(mod, extern_libs, compute_capability):
|
||||
if extern_libs:
|
||||
add_external_libs(mod, extern_libs)
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod)
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)
|
||||
|
||||
|
||||
def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
|
||||
def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]:
|
||||
'''
|
||||
Translate TritonGPU module to PTX code.
|
||||
:param mod: a TritonGPU dialect module
|
||||
@@ -919,26 +919,20 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
|
||||
- PTX code
|
||||
- shared memory allocation size
|
||||
'''
|
||||
if compute_capability is None:
|
||||
device = torch.cuda.current_device()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, device: int):
|
||||
def ptx_to_cubin(ptx: str, compute_capability: int):
|
||||
'''
|
||||
Compile TritonGPU module to cubin.
|
||||
:param ptx: ptx code
|
||||
:param device: CUDA device
|
||||
:param compute_capability: compute capability
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
|
||||
|
||||
|
||||
@@ -1190,7 +1184,7 @@ class CacheManager:
|
||||
return
|
||||
binary = isinstance(data, bytes)
|
||||
if not binary:
|
||||
data = str(data)
|
||||
data = str(data)
|
||||
assert self.lock_path is not None
|
||||
filepath = self._make_path(filename)
|
||||
with FileLock(self.lock_path):
|
||||
@@ -1292,18 +1286,20 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
|
||||
run_if_not_found: Callable = None):
|
||||
suffix = file_name.split(".")[1]
|
||||
if not force_compile and cache_manager.has_file(file_name):
|
||||
module = run_if_found(cache_manager._make_path(file_name))
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
has_changed = metadata and md5 != metadata["md5"][suffix]
|
||||
return module, md5, has_changed, True
|
||||
module = run_if_found(cache_manager._make_path(file_name))
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
has_changed = metadata and md5 != metadata["md5"][suffix]
|
||||
return module, md5, has_changed, True
|
||||
module = run_if_not_found()
|
||||
data = module if isinstance(module, bytes) else str(module).encode("utf-8")
|
||||
md5 = hashlib.md5(data).hexdigest()
|
||||
cache_manager.put(data, file_name, True if isinstance(data, bytes) else data)
|
||||
return module, md5, True, False
|
||||
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
def make_stub(name, signature, constants):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(signature, constants)
|
||||
@@ -1325,9 +1321,10 @@ def make_stub(name, signature, constants):
|
||||
def convert_type_repr(x):
|
||||
match = re.search(r'!tt\.ptr<(.*)>', x)
|
||||
if match is not None:
|
||||
return '*' + convert_type_repr(match.group(1))
|
||||
return '*' + convert_type_repr(match.group(1))
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, **kwargs):
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
@@ -1344,14 +1341,13 @@ def make_hash(fn, **kwargs):
|
||||
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
|
||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||
def compile(fn, **kwargs):
|
||||
# we get the kernel, i.e. the first function generated in the module
|
||||
# if fn is not a JITFunction, then it
|
||||
# if fn is not a JITFunction, then it
|
||||
# has to be a path to a file
|
||||
context = _triton.ir.context()
|
||||
asm, md5 = dict(), dict()
|
||||
asm = dict()
|
||||
constants = kwargs.get("constants", dict())
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
@@ -1374,64 +1370,66 @@ def compile(fn, **kwargs):
|
||||
param_tys = [convert_type_repr(str(ty)) for ty in function.type.param_types()]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = 2
|
||||
|
||||
|
||||
# cache manager
|
||||
so_path = make_stub(name, signature, constants)
|
||||
# create cache manager
|
||||
fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
name, ext = fn.__name__, "ast"
|
||||
else:
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
# initialize compilation params
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
device = kwargs.get("device", torch.cuda.current_device())
|
||||
compute_capability = torch.cuda.get_device_capability(device)
|
||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||
# load metadata if any
|
||||
metadata = None
|
||||
if fn_cache_manager.has_file(f'{name}.json'):
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||
# build compilation stages
|
||||
stages = {
|
||||
"ast" : (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
llir_to_ptx),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, device))
|
||||
"ast": (lambda path: fn, None),
|
||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
|
||||
"llir": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
|
||||
"ptx": (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_ptx(src, compute_capability)),
|
||||
"cubin": (lambda path: Path(path).read_bytes(),
|
||||
lambda src: ptx_to_cubin(src, compute_capability))
|
||||
}
|
||||
first_stage = list(stages.keys()).index(ext)
|
||||
asm = dict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir, (parse, compile) in list(stages.items())[first_stage:]:
|
||||
path = fn_cache_manager._make_path(f"{name}.{ir}")
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
if os.path.exists(path):
|
||||
metadata["ctime"][ir] = os.path.getctime(path)
|
||||
asm[ir] = next_module if ir == "cubin" else str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = ptx_get_kernel_name(next_module)
|
||||
module = next_module
|
||||
path = fn_cache_manager._make_path(f"{name}.{ir}")
|
||||
if ir == ext:
|
||||
next_module = parse(fn)
|
||||
elif os.path.exists(path) and\
|
||||
os.path.getctime(path) == metadata["ctime"][ir]:
|
||||
next_module = parse(path)
|
||||
else:
|
||||
next_module = compile(module)
|
||||
fn_cache_manager.put(next_module, f"{name}.{ir}")
|
||||
if os.path.exists(path):
|
||||
metadata["ctime"][ir] = os.path.getctime(path)
|
||||
asm[ir] = next_module if ir == "cubin" else str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
metadata["name"] = ptx_get_kernel_name(next_module)
|
||||
module = next_module
|
||||
# write-back metadata
|
||||
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
|
||||
# return handle to compiled kernel
|
||||
@@ -1494,7 +1492,7 @@ class CudaUtils(object):
|
||||
#include <cuda.h>
|
||||
|
||||
#include \"cuda.h\"
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
|
@@ -37,20 +37,21 @@ if __name__ == '__main__':
|
||||
print(module.str())
|
||||
exit(0)
|
||||
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
|
||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
exit(0)
|
||||
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir | FileCheck %s
|
||||
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
|
||||
|
||||
// == LLVM IR check begin ==
|
||||
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||
|
Reference in New Issue
Block a user