This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to improve maintainability, stability, and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but it can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
|
#include "llvm/ADT/APSInt.h"
|
|
#include <numeric>

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"

namespace {

template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return success();
  }
};
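
// Illustrative note (not part of the pass logic): for an elementwise op such
// as arith.addi, this pattern only swaps the result type for the converted
// one, so something like
//   %0 = arith.addi %a, %b : tensor<128xi32>
// becomes, after conversion by TritonGPUTypeConverter (the encoding syntax
// shown here is only indicative),
//   %0 = arith.addi %a, %b : tensor<128xi32, #triton_gpu.blocked<...>>
// The operands are the already-converted values provided by the adaptor.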

template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                       adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};

class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
    assert(value);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, retType,
        value.reshape(retType.cast<ShapedType>()) // This is a hack. We just
                                                  // want to add encoding
    );
    return success();
  }
};

// Note: this pattern is declared here but is not added to any pattern set in
// this file.
class ConvertArithmeticOp : public ConversionPattern {
public:
  ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
                      MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Dialect *dialect = op->getDialect();
    if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
      return failure();
    return success();
  }
};

void populateArithmeticPatternsAndLegality(
    TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
    TritonGPUConversionTarget &target) {
  // --------------
  // Add legality and rewrite pattern rules for operations
  // from the Arithmetic dialect. The basic premise is that
  // arithmetic operations require both inputs to have the same
  // non-null encoding
  // --------------
  MLIRContext *context = patterns.getContext();
  // TODO: there's probably a better way to avoid adding all ops one-by-one
  patterns.add<
      ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
      GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
      GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
      GenericOpPattern<arith::CeilDivUIOp>,
      GenericOpPattern<arith::CeilDivSIOp>,
      GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
      GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
      GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
      GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
      GenericOpPattern<arith::ShRSIOp>, // NegFOp
      // Floating point
      GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
      // MaxMin
      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
      GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
      // Floating point
      GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
      GenericOpPattern<arith::RemFOp>,
      // Cmp
      ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
      ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
      // Cast Ops
      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
      GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
      GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
      GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
      GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}

// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
        adaptor.getFalseValue());
    return success();
  }
};

void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
                                    TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<StdSelectPattern>(typeConverter, context);
  target.addLegalOp<ReturnOp>(); // this is ok because all functions are
                                 // inlined by the frontend
}

void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
               GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
               GenericOpPattern<math::SqrtOp>>(typeConverter, context);
}

//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern
    : public OpConversionPattern<triton::MakeRangeOp> {
  using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
        op, retType, adaptor.start(), adaptor.end());
    return success();
  }
};

struct TritonExpandDimsPattern
    : public OpConversionPattern<triton::ExpandDimsOp> {
  using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Type retType = op.getType());
    RankedTensorType argType =
        adaptor.src().getType().cast<RankedTensorType>();
    Attribute _argEncoding = argType.getEncoding();
    if (!_argEncoding)
      return failure();
    auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
    // return shape
    auto retShape = argType.getShape().vec();
    retShape.insert(retShape.begin() + op.axis(), 1);
    // return encoding
    auto retSizePerThread = argEncoding.getSizePerThread().vec();
    retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1);
    auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
    retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1);
    auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
    retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
    SmallVector<unsigned, 4> retOrder(retShape.size());
    std::iota(retOrder.begin(), retOrder.end(), 0);
    triton::gpu::BlockedEncodingAttr retEncoding =
        triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
                                              retThreadsPerWarp, retWarpsPerCTA,
                                              retOrder);
    // convert operand to slice of return type
    Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
        getContext(), op.axis(), retEncoding);
    RankedTensorType newArgType = RankedTensorType::get(
        argType.getShape(), argType.getElementType(), newArgEncoding);
    // construct new op
    auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op.getLoc(), newArgType, adaptor.src());
    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
                                                      adaptor.axis());
    return success();
  }
};
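
// Worked example (illustrative only): expanding axis = 1 of a 1-D tensor whose
// blocked encoding has sizePerThread = [4], threadsPerWarp = [32],
// warpsPerCTA = [4] yields a 2-D result encoding with sizePerThread = [4, 1],
// threadsPerWarp = [32, 1], warpsPerCTA = [4, 1] and order = [0, 1], and the
// source operand is first converted to the slice (dim = 1) of that result
// encoding. The concrete numbers are made up; only the "insert 1 at `axis`"
// rule comes from the code above.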

struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
  using OpConversionPattern<triton::DotOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType origType = op.getType().cast<RankedTensorType>();
    auto origShape = origType.getShape();
    auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
    int numWarps = typeConverter->getNumWarps();

    SmallVector<unsigned> retSizePerThread = {1, 1};
    if (origShape[0] * origShape[1] / (numWarps * 32) >= 4)
      retSizePerThread = {2, 2};
    if (origShape[0] * origShape[1] / (numWarps * 32) >= 16)
      retSizePerThread = {4, 4};
    SmallVector<unsigned> retOrder = {1, 0};
    Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
        getContext(), origShape, retSizePerThread, retOrder, numWarps);
    RankedTensorType retType =
        RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
    // a & b must be of smem layout
    auto aType = adaptor.a().getType().cast<RankedTensorType>();
    auto bType = adaptor.b().getType().cast<RankedTensorType>();
    Attribute aEncoding = aType.getEncoding();
    Attribute bEncoding = bType.getEncoding();
    if (!aEncoding || !bEncoding)
      return failure();
    Value a = adaptor.a();
    Value b = adaptor.b();
    Value c = adaptor.c();
    if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
      Attribute encoding =
          triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
      auto dstType = RankedTensorType::get(aType.getShape(),
                                           aType.getElementType(), encoding);
      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
    }
    if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
      Attribute encoding =
          triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
      auto dstType = RankedTensorType::get(bType.getShape(),
                                           bType.getElementType(), encoding);
      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
    }
    c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, c,
                                               adaptor.allowTF32());
    return success();
  }
};
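
// Worked example of the sizePerThread heuristic above (numbers are only
// illustrative): for a 128x128 result with numWarps = 4, the ratio is
// 128 * 128 / (4 * 32) = 128, which is >= 16, so retSizePerThread becomes
// {4, 4}; for a 32x32 result with the same numWarps, the ratio is
// 32 * 32 / (4 * 32) = 8, which only clears the >= 4 threshold, so
// retSizePerThread is {2, 2}.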

struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
  using OpConversionPattern<triton::CatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, this behaves like generic, but this will evolve when
    // we add support for `can_reorder=False`
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
                                               adaptor.getOperands());
    return success();
  }
};

struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
  using OpConversionPattern<triton::TransOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.src();
    auto srcType = src.getType().cast<RankedTensorType>();
    Attribute srcEncoding = srcType.getEncoding();
    if (!srcEncoding)
      return failure();
    if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
      // TODO: end-to-end correctness is broken if
      // the input is blocked and the output is shared
      // with different order. Maybe a backend issue in BlockedToShared?
      SmallVector<unsigned> order = {1, 0};
      if (auto srcBlockedEncoding =
              srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
        llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
      srcEncoding =
          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
      srcType = RankedTensorType::get(srcType.getShape(),
                                      srcType.getElementType(), srcEncoding);
      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
                                                          src);
    }
    auto srcSharedEncoding =
        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
                                   srcSharedEncoding.getOrder().end());
    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
                                   srcType.getShape().end());
    std::reverse(retOrder.begin(), retOrder.end());
    std::reverse(retShapes.begin(), retShapes.end());
    auto retEncoding =
        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
    auto retType =
        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);

    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
    return success();
  }
};

struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
  using OpConversionPattern<triton::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
        adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
        adaptor.isVolatile());
    return success();
  }
};

struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
  using OpConversionPattern<triton::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::StoreOp>(
        op, adaptor.ptr(), adaptor.value(), adaptor.mask());
    return success();
  }
};

struct TritonAtomicCASPattern
    : public OpConversionPattern<triton::AtomicCASOp> {
  using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
        adaptor.cmp(), adaptor.val());
    return success();
  }
};

struct TritonAtomicRMWPattern
    : public OpConversionPattern<triton::AtomicRMWOp> {
  using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
        op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(),
        adaptor.ptr(), adaptor.val(), adaptor.mask());
    return success();
  }
};

struct TritonExtElemwisePattern
    : public OpConversionPattern<triton::ExtElemwiseOp> {
  using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
        op, typeConverter->convertType(op.getType()), adaptor.args(),
        adaptor.libname(), adaptor.libpath(), adaptor.symbol());
    return success();
  }
};

template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return success();
  }
};

struct TritonBroadcastPattern
    : public OpConversionPattern<triton::BroadcastOp> {
  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;

  // This creates a tensor with the new shape but the argument's layout
  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = adaptor.src().getType().cast<RankedTensorType>();
    auto srcEncoding = srcType.getEncoding();
    if (!srcEncoding)
      return failure();
    auto opType = op.getType().cast<RankedTensorType>();
    Type retType = RankedTensorType::get(opType.getShape(),
                                         opType.getElementType(), srcEncoding);
    // Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
                                                     adaptor.getOperands());
    return success();
  }
};
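
// Illustrative note: given a source of type tensor<1x128xf32, #blocked> being
// broadcast to shape 64x128, the pattern above produces a result of type
// tensor<64x128xf32, #blocked> carrying the *source's* encoding, rather than
// asking the type converter for a fresh layout. The shapes and the #blocked
// shorthand here are made up for illustration.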

struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ReduceOp>(
        op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
    return success();
  }
};

struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
  using OpConversionPattern<PrintfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                  adaptor.getOperands());
    return success();
  }
};

void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add< // TODO: view should have custom pattern that views the layout
      TritonGenericPattern<triton::ViewOp>,
      TritonGenericPattern<triton::BitcastOp>,
      TritonGenericPattern<triton::FpToFpOp>,
      TritonGenericPattern<triton::IntToPtrOp>,
      TritonGenericPattern<triton::PtrToIntOp>,
      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
      TritonAtomicRMWPattern>(typeConverter, context);
}

//
// SCF patterns
//
// This is borrowed from ConvertForOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern<scf::ForOp>::OpConversionPattern;
  // Ref: ConvertForOpTypes
  LogicalResult
  matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newOp =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                newOp.getLoopBody().end());

    // Now, update all the types.

    // Convert the types of block arguments within the given region. This
    // replaces each block with a new block containing the updated signature.
    // The entry block may have a special conversion if `entryConversion` is
    // provided. On success, the new entry block to the region is returned for
    // convenience. Otherwise, failure is returned.
    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                           *getTypeConverter()))) {
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    }
    // Change the clone to use the updated operands. We could have cloned with
    // a BlockAndValueMapping, but this seems a bit more direct.
    newOp->setOperands(adaptor.getOperands());
    // Update the result types to the new converted types.
    SmallVector<Type> newResultTypes;
    for (Type type : op.getResultTypes()) {
      Type newType = typeConverter->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));

    rewriter.replaceOp(op, newOp.getResults());

    return success();
  }
};

struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
  using OpConversionPattern<scf::YieldOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
    // rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
    // op.erase();
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
    return success();
  }
};

// This is borrowed from ConvertIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
  using OpConversionPattern<scf::IfOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: Generalize this to any type conversion, not just 1:1.
    //
    // We need to implement something more sophisticated here that tracks which
    // types convert to which other types and does the appropriate
    // materialization logic.
    // For example, it's possible that one result type converts to 0 types and
    // another to 2 types, so newResultTypes would at least be the right size
    // to not crash in the llvm::zip call below, but then we would set the
    // wrong type on the SSA values! These edge cases are also why we cannot
    // safely use the TypeConverter::convertTypes helper here.
    SmallVector<Type> newResultTypes;
    for (auto type : op.getResultTypes()) {
      Type newType = typeConverter->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    // See comments in the ForOp pattern for why we clone without regions and
    // then inline.
    scf::IfOp newOp =
        cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
                                newOp.getThenRegion().end());
    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
                                newOp.getElseRegion().end());

    // Update the operands and types.
    newOp->setOperands(adaptor.getOperands());
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                         RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
                                                             context);
}

class ConvertTritonToTritonGPU
    : public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
  ConvertTritonToTritonGPU() = default;
  // constructor with some parameters set explicitly.
  ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    // type converter
    TritonGPUTypeConverter typeConverter(context, numWarps);
    TritonGPUConversionTarget target(*context, typeConverter);
    // rewrite patterns
    RewritePatternSet patterns(context);
    // add rules
    populateStdPatternsAndLegality(typeConverter, patterns, target);
    populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
    populateMathPatternsAndLegality(typeConverter, patterns, target);
    populateTritonPatterns(typeConverter, patterns);
    // TODO: can we use
    // mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
    populateSCFPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();

    // Record the number of warps as a module attribute.
    auto i32_ty = IntegerType::get(mod->getContext(), 32);
    mod->setAttr(
        AttrNumWarpsName,
        IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
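
    // For reference (illustrative; the exact attribute name comes from
    // AttrNumWarpsName): after this pass the module carries something like
    //   module attributes {"triton_gpu.num-warps" = 4 : i32} { ... }
    // in its textual IR.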

    // update layouts
    // broadcast src => multicast, dst => broadcasted
    // if (failed(target.refineLayouts(mod, numWarps)))
    //   return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
  return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
  return std::make_unique<::ConvertTritonToTritonGPU>();
}
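
// Usage sketch (illustrative only, not part of this file): the pass would
// typically be scheduled through a standard mlir::PassManager, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(/*numWarps=*/4));
//   if (failed(pm.run(module)))
//     /* handle the failure */;
//
// The numWarps value of 4 here is only an example.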