Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
commit 20100a7254 (parent 8650b4d1cb)
Author: Philippe Tillet
Date:   2022-12-21 01:30:50 -08:00 (committed by GitHub)
285 changed files with 26312 additions and 50143 deletions


@@ -0,0 +1,229 @@
#include "ViewOpToLLVM.h"
#include "DotOpHelpers.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
// Convert a SplatOp or an arith::ConstantOp with a SplatElementsAttr to an
// LLVM::StructType value.
//
// @elemType: the element type of the operand.
// @resType: the return type of the splat-like op.
// @constVal: an LLVM::ConstantOp or other scalar value.
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
} else if (auto mmaLayout =
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
return convertSplatLikeOpWithMmaLayout(
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
return {};
}
static Value convertSplatLikeOpWithMmaLayout(
const MmaEncodingAttr &layout, Type resType, Type elemType,
Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
if (layout.isAmpere()) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
if (layout.isVolta()) {
DotOpMmaV1ConversionHelper helper(layout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
// According to the v1 mma layout, each thread processes 8 elements.
int elems = 8 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(elems, elemType));
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
return {};
}
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = adaptor.src();
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
return success();
}
};
// This pattern converts an arith::ConstantOp with a SplatElementsAttr. The
// logic is the same as for triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto value = op.getValue();
auto values = value.dyn_cast<SplatElementsAttr>();
if (!values)
return failure();
auto loc = op->getLoc();
auto elemType = values.getElementType();
Attribute val;
if (elemType.isBF16() || type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion got an unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, llStruct);
return success();
}
};
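// Lower triton::CatOp by concatenating the per-thread values of both
// operands into a single LLVM struct.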
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
retVals.push_back(v);
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
return success();
}
};
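// Lower view-like ops (e.g. triton::ViewOp and triton::ExpandDimsOp) by
// repacking the source elements into a struct that matches the result's
// per-thread element count and element type.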
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly run `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
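// Lower triton::TransOp on a shared-memory operand by swapping the strides
// and offsets of the shared memory descriptor; no data is moved.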
struct TransOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto srcSmemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
srcSmemObj.strides[0]};
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
srcSmemObj.offsets[0]};
auto dstSmemObj =
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
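// Register all view-like lowering patterns defined in this file.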
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
}
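
For context, the following is a minimal sketch, not part of this diff, of how these patterns might be plugged into a standard MLIR dialect-conversion driver. The `AxisInfoAnalysis`, `Allocation`, and shared-memory `Value` arguments are assumed to come from the surrounding TritonGPU-to-LLVM pass introduced elsewhere in this PR; the header paths and the function name `lowerViewOps` are illustrative and depend on the MLIR version in use.

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical driver: registers the patterns above and runs a partial
// conversion over the module.
static mlir::LogicalResult lowerViewOps(mlir::ModuleOp module,
                                        mlir::LLVMTypeConverter &typeConverter,
                                        AxisInfoAnalysis &axisInfoAnalysis,
                                        const Allocation *allocation,
                                        mlir::Value smem, int numWarps) {
  mlir::RewritePatternSet patterns(module.getContext());
  populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
                               axisInfoAnalysis, allocation, smem,
                               /*benefit=*/1);
  // Ops already in the LLVM dialect are legal; the patterns above rewrite the
  // view-like Triton ops they match.
  mlir::LLVMConversionTarget target(*module.getContext());
  return mlir::applyPartialConversion(module, target, std::move(patterns));
}

In the full backend, these patterns would presumably be registered alongside the other TritonGPU-to-LLVM pattern sets rather than run in isolation, so a single conversion pass lowers the whole module.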