Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
Author:    Philippe Tillet
Date:      2022-12-21 01:30:50 -08:00
Committed: GitHub
Parent:    8650b4d1cb
Commit:    20100a7254
285 changed files with 26312 additions and 50143 deletions


@@ -0,0 +1,521 @@
#include "TritonGPUToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"

using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;

struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel functions never return a value.
// TODO(Superjomn) add support for non-inline device functions
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "only kernel functions that return nothing are supported");
}
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};

struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::BroadcastOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Following the order of indices in the legacy code, a broadcast of:
// [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)]
// =>
// [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)]
//
// logically maps to a broadcast within a thread's scope:
// [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1),
// 1,spt(k+1)..spt(n-1)]
// =>
// [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)]
//
// regardless of the order of the layout
//
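// For example (illustrative shapes): broadcasting tensor<1x64xf32> to
// tensor<32x64xf32> maps every result offset (i, j) back to the source
// value at (0, j); each thread re-reads values it already holds, so no
// new data movement is emitted.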
Location loc = op->getLoc();
Value src = adaptor.src();
Value result = op.result();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto resultTy = result.getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto resultLayout = resultTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
assert(rank == resultTy.getRank());
auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
srcValues[srcOffsets[i]] = srcVals[i];
}
SmallVector<Value> resultVals;
for (size_t i = 0; i < resultOffsets.size(); i++) {
auto offset = resultOffsets[i];
for (size_t j = 0; j < srcShape.size(); j++)
if (srcShape[j] == 1)
offset[j] = 0;
resultVals.push_back(srcValues.lookup(offset));
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};

struct PrintfOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
for (auto elem : sub_operands) {
operands.push_back(elem);
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.prefix();
if (!operands.empty()) {
os << getFormatSubstr(operands[0]);
}
for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}
std::string getFormatSubstr(Value value) const {
Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
} else if (type.isSignedInteger()) {
return "%i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
return "%u";
}
assert(false && "unsupported type");
return "";
}
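// E.g. (illustrative): a printf with prefix "x: " and operands (ptr, f32)
// yields the format string "x: %p, %f"; llPrintf later appends "\n\0".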
// declare vprintf(i8*, i8*) as external function
static LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("vprintf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
auto *context = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
ptr_ty(IntegerType::get(context, 8))};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
funcType);
}
// extend integer to int32, extend float to float64
// this comes from vprintf alignment requirements.
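// E.g. an i16 argument is sign-extended to i32 and an f16/bf16/f32
// argument is extended to f64, mirroring C's default argument promotions
// for variadic calls.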
static std::pair<Type, Value>
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
auto *context = rewriter.getContext();
auto type = value.getType();
Value newOp = value;
Type newType = type;
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
value);
} else {
newType = i32_ty;
newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
value);
}
return {newType, newOp};
}
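// Emits roughly the following LLVM IR (sketch; names and sizes are
// illustrative):
//   @printfFormat_0 = internal constant [N x i8] c"...\0A\00"
//   %args = alloca {i32, double, ...}   ; promoted operands stored here
//   %buf = bitcast {i32, double, ...}* %args to i8*
//   call i32 @vprintf(i8* %fmt, i8* %buf)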
static void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
static const char formatStringPrefix[] = "printfFormat_";
assert(!msg.empty() && "printf with empty string not supported");
Type int8Ptr = ptr_ty(i8_ty);
auto *context = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
Value one = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
llvm::SmallString<64> formatString(msg);
formatString.push_back('\n');
formatString.push_back('\0');
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(context), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), int8Ptr, globalPtr,
SmallVector<Value>({zero, zero}));
Value bufferPtr =
rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
SmallVector<Value, 16> newArgs;
if (args.size() >= 1) {
SmallVector<Type> argTypes;
for (auto arg : args) {
Type newType;
Value newArg;
std::tie(newType, newArg) = promoteValue(rewriter, arg);
argTypes.push_back(newType);
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = rewriter.create<LLVM::ConstantOp>(
UnknownLoc::get(context), i32_ty,
rewriter.getI32IntegerAttr(entry.index()));
auto fieldPtr = rewriter.create<LLVM::GEPOp>(
UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
allocated, ArrayRef<Value>{zero, index});
rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
fieldPtr);
}
bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
int8Ptr, allocated);
}
SmallVector<Value> operands{stringStart, bufferPtr};
rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
}
};

struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
benefit) {}
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs = emitIndices(loc, rewriter, layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
// TODO: slice layouts have more elements than expected.
// This is unexpected behavior for make_range, but it is generally harmless
// when followed by expand_dims + broadcast; other uses may behave
// incorrectly.
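// E.g. (illustrative): make_range {start=0, end=128} emits, for each of
// the thread's elements, an i32 add of the element's computed index and
// the constant start.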
for (const auto multiDim : llvm::enumerate(idxs)) {
assert(multiDim.value().size() == 1);
retVals[multiDim.index()] = add(multiDim.value()[0], start);
}
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
};

struct GetProgramIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(op.axis() < 3);
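// E.g. (illustrative): program_id(axis=0) lowers to gpu.block_id x, which
// NVVM lowering later maps to the PTX %ctaid.x special register.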
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
loc, rewriter.getIndexType(), dims[op.axis()]);
auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
return success();
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};

struct GetNumProgramsOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(op.axis() < 3);
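// Mirrors GetProgramIdOpConversion: num_programs(axis) lowers to
// gpu.grid_dim along the same dimension (PTX %nctaid on NVIDIA targets).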
Value gridDim = rewriter.create<::mlir::gpu::GridDimOp>(
loc, rewriter.getIndexType(), dims[op.axis()]);
auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{gridDim});
return success();
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};

struct AddPtrOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
if (resultTensorTy) {
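// Tensor-of-pointers case (illustrative): each per-thread pointer is
// advanced by its matching per-thread offset with one GEP, e.g.
//   %new_ptr_i = getelementptr f32, f32* %ptr_i, i32 %offset_i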
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), rewriter);
auto offsets = getElementsFromStruct(loc, adaptor.offset(), rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
} else {
assert(resultTy.isa<triton::PointerType>());
Type llResultTy = getTypeConverter()->convertType(resultTy);
Value result = gep(llResultTy, adaptor.ptr(), adaptor.offset());
rewriter.replaceOp(op, result);
}
return success();
}
};

struct AllocTensorOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
// Workaround for 3D tensors
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
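// E.g. a 2-D order {0, 1} becomes {1, 2, 0}: the original dims shift up
// by one and the new leading dim is ordered last (slowest-varying).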
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() == 3)
newOrder = {1 + order[0], 1 + order[1], 0};
else
newOrder = SmallVector<unsigned>(order.begin(), order.end());
auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};

struct ExtractSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = extract_slice %src[%offsets]
Location loc = op->getLoc();
auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by ExtractSliceOpConversion");
// newBase = base + offset
// Triton supports both static and dynamic offsets
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
SmallVector<Value, 4> opOffsetVals;
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (size_t i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
opOffsetVals.emplace_back(adaptor.offsets()[i]);
else
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
}
// Compute the offset based on the original strides of the shared memory
// object
auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides);
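// i.e. offset = sum_i opOffsetVals[i] * smemObj.strides[i], the
// linearized element offset of the slice origin (assuming dot is an
// inner product over the two vectors).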
// newShape = rank_reduce(shape)
// Triton only supports static tensor sizes
SmallVector<Value, 4> strideVals;
SmallVector<Value, 4> rankReducedOffsets;
for (size_t i = 0; i < op.static_sizes().size(); ++i) {
if (op.getStaticSize(i) != 1) {
// Unit dims are rank-reduced away; keep offsets/strides of the rest.
// (Rebuilding the vector avoids the index shift that in-place erase
// would cause once an element has been removed.)
strideVals.emplace_back(smemObj.strides[i]);
rankReducedOffsets.emplace_back(offsetVals[i]);
}
}
offsetVals = rankReducedOffsets;
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
strideVals, offsetVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};

struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));
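// E.g. num = 2 emits the inline PTX "cp.async.wait_group 2", which waits
// until at most two of the committed cp.async groups are still pending.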
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};

void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit);
}
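
// Typical call site (sketch; assumes a surrounding conversion pass and an
// illustrative benefit value):
//   RewritePatternSet patterns(context);
//   populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
//                                   axisInfoAnalysis, allocation, smem,
//                                   /*benefit=*/10);
//   if (failed(applyPartialConversion(mod, target, std::move(patterns))))
//     return signalPassFailure();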