[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
@@ -88,15 +89,7 @@ void populateArithmeticPatternsAndLegality(
|
||||
// non-null encoding
|
||||
// --------------
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// // Legality rule
|
||||
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// // TODO: check above rule here
|
||||
// [](Operation *op){
|
||||
// return true;
|
||||
// }
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
// TODO: there's probably a better way to avoid adding all ops one-by-one
|
||||
patterns.add<
|
||||
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
|
||||
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
|
||||
@@ -121,8 +114,35 @@ void populateArithmeticPatternsAndLegality(
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
|
||||
typeConverter, context);
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||
GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
||||
class StdSelectPattern : public OpConversionPattern<SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern<SelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
triton::gpu::SelectOp res =
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
||||
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers the rewrite pattern and the legality rule for this "Std" group:
// StdSelectPattern is added to `patterns`, and ReturnOp is marked legal on
// `target`.
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
                                    TritonGPUConversionTarget &target) {
  // Rewrite rule
  patterns.add<StdSelectPattern>(typeConverter, patterns.getContext());

  // Legality rule: leaving ReturnOp untouched is ok because all functions are
  // inlined by the frontend.
  target.addLegalOp<ReturnOp>();
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
@@ -231,7 +251,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
|
||||
adaptor.transB());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -418,6 +439,7 @@ public:
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateStdPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
|
Reference in New Issue
Block a user