This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
152 lines
5.2 KiB
C++
152 lines
5.2 KiB
C++
#include "triton/Analysis/Utility.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
namespace mlir {
|
|
|
|
bool ReduceOpHelper::isFastReduction() {
|
|
auto srcLayout = srcTy.getEncoding();
|
|
auto axis = op.axis();
|
|
return axis == triton::gpu::getOrder(srcLayout)[0];
|
|
}
|
|
|
|
unsigned ReduceOpHelper::getInterWarpSize() {
|
|
auto srcLayout = srcTy.getEncoding();
|
|
auto srcShape = srcTy.getShape();
|
|
auto axis = op.axis();
|
|
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
|
unsigned sizeIntraWarps = getIntraWarpSize();
|
|
return std::min(srcReduceDimSize / sizeIntraWarps,
|
|
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
|
|
}
|
|
|
|
unsigned ReduceOpHelper::getIntraWarpSize() {
|
|
auto srcLayout = srcTy.getEncoding();
|
|
auto srcShape = srcTy.getShape();
|
|
auto axis = op.axis();
|
|
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
|
return std::min(srcReduceDimSize,
|
|
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
|
|
}
|
|
|
|
unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
|
auto srcLayout = srcTy.getEncoding();
|
|
auto axis = op.axis();
|
|
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
|
|
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
|
|
}
|
|
|
|
/// Builds the scratch shape used by the non-fast reduction path.
///
/// \returns the source shape converted to `unsigned`, with the entry for the
/// reduction axis clamped to at most `getThreadsReductionAxis()`.
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  auto scratchShape = convertType<unsigned>(getSrcShape());
  unsigned &reduceExtent = scratchShape[op.axis()];
  reduceExtent = std::min(reduceExtent, getThreadsReductionAxis());
  return scratchShape;
}
|
|
|
|
/// Builds the scratch shapes used by the fast reduction path.
///
/// \returns a vector of three shapes. Entry 0 is the source shape with the
/// reduction axis replaced by `getInterWarpSize()`; entry 1 is the
/// one-dimensional shape `{numWarps * 32}`; entry 2 is left empty by this
/// function.
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  SmallVector<SmallVector<unsigned>> configs(3);

  /// shared memory block0
  auto &block0 = configs[0];
  block0 = convertType<unsigned>(getSrcShape());
  block0[op.axis()] = getInterWarpSize();

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1:
  // NOTE(review): 32 is presumably the number of threads per warp -- confirm.
  auto parentModule = op.getOperation()->getParentOfType<ModuleOp>();
  const unsigned warpCount =
      triton::gpu::TritonGPUDialect::getNumWarps(parentModule);
  configs[1].push_back(warpCount * 32);

  return configs;
}
|
|
|
|
unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
|
unsigned elems = 0;
|
|
if (isFastReduction()) {
|
|
auto smemShapes = getScratchConfigsFast();
|
|
for (const auto &smemShape : smemShapes)
|
|
elems = std::max(elems, product<unsigned>(smemShape));
|
|
} else {
|
|
auto smemShape = getScratchConfigBasic();
|
|
elems = product<unsigned>(smemShape);
|
|
}
|
|
|
|
auto tensorType = op.operand().getType().cast<RankedTensorType>();
|
|
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
|
|
|
|
if (triton::ReduceOp::withIndex(op.redOp()))
|
|
bytes += elems * sizeof(int32_t);
|
|
|
|
return bytes;
|
|
}
|
|
|
|
/// Checks whether \p value is a ranked tensor carrying a shared encoding.
///
/// \returns true only if the value's type is a `RankedTensorType` whose
/// encoding attribute is non-null and is a
/// `triton::gpu::SharedEncodingAttr`; false for every other type.
bool isSharedEncoding(Value value) {
  auto tensorType = value.getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return false;

  auto encoding = tensorType.getEncoding();
  return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
}
|
|
|
|
/// Filters operations by the dialect they belong to.
///
/// \returns true when \p op has a registered dialect and that dialect is one
/// of TritonGPU, Triton, Arithmetic or Tensor; false otherwise (including
/// when `op->getDialect()` is null).
bool maybeSharedAllocationOp(Operation *op) {
  // TODO(Keren): This function can be replaced by adding
  // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
  // query the memory effects of the op.
  auto *dialect = op->getDialect();
  if (!dialect)
    return false;

  const auto dialectID = dialect->getTypeID();
  return dialectID == mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
         dialectID == mlir::TypeID::get<triton::TritonDialect>() ||
         dialectID == mlir::TypeID::get<arith::ArithmeticDialect>() ||
         dialectID == mlir::TypeID::get<tensor::TensorDialect>();
}
|
|
|
|
/// Checks whether \p op is one of the operation kinds this file treats as
/// possibly aliasing: `tensor::ExtractSliceOp`, `triton::TransOp`,
/// `triton::gpu::InsertSliceAsyncOp` or `tensor::InsertSliceOp`.
bool maybeAliasOp(Operation *op) {
  return isa<tensor::ExtractSliceOp, triton::TransOp,
             triton::gpu::InsertSliceAsyncOp, tensor::InsertSliceOp>(op);
}
|
|
|
|
/// Decides whether a dot operation can use MMA for the given layout version.
///
/// When both operands have f32 elements, MMA is supported only if the op
/// allows TF32 and \p version is at least 2. Otherwise the decision is
/// delegated to the per-operand `supportMMA(Value, int)` overload, and both
/// operands must pass.
bool supportMMA(triton::DotOp op, int version) {
  // Refer to mma section for the data type supported by Volta and Hopper
  // Tensor Core in
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
  auto lhsElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
  auto rhsElemTy = op.b().getType().cast<RankedTensorType>().getElementType();

  const bool bothF32 = lhsElemTy.isF32() && rhsElemTy.isF32();
  if (!bothF32)
    return supportMMA(op.a(), version) && supportMMA(op.b(), version);

  return op.allowTF32() && version >= 2;
}
|
|
|
|
/// Decides, from a single operand, whether a dot operation can use MMA.
///
/// Only one operand type is inspected (either $a or $b); the caller is
/// assumed to have operands of identical element type, because both types are
/// not available together (in TypeConverter).
///
/// \param value   a dot operand; its type must be a `RankedTensorType`.
/// \param version MMA layout version; asserted to be 1 or 2.
/// \returns true for f16 and bf16 elements in any version, and for f32 or
/// 8-bit integer elements when \p version is at least 2.
bool supportMMA(Value value, int version) {
  assert((version == 1 || version == 2) &&
         "Unexpected MMA layout version found");

  auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
  if (elemTy.isF16() || elemTy.isBF16())
    return true;

  // f32 and 8-bit integers are accepted only from version 2 on.
  return version >= 2 && (elemTy.isF32() || elemTy.isInteger(8));
}
|
|
|
|
/// Returns the element type of \p value when its type is a
/// `RankedTensorType`, and the value's own type otherwise.
Type getElementType(Value value) {
  Type valueType = value.getType();
  auto tensorType = valueType.dyn_cast<RankedTensorType>();
  return tensorType ? tensorType.getElementType() : valueType;
}
|
|
|
|
/// Renders \p value the way it is printed as an operand and returns that
/// text.
///
/// \param value the value to print.
/// \param state the `AsmState` handed to `Value::printAsOperand`.
/// \returns the text produced by `printAsOperand`.
std::string getValueOperandName(Value value, AsmState &state) {
  std::string opName;
  llvm::raw_string_ostream ss(opName);
  value.printAsOperand(ss, state);
  // Push any buffered output into `opName` before returning it; on LLVM
  // versions where raw_string_ostream is buffered, skipping this can return a
  // truncated string.
  ss.flush();
  return opName;
}
|
|
|
|
} // namespace mlir
|