[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)

Philippe Tillet
2022-10-21 16:52:15 -07:00
committed by GitHub
parent c4726333bf
commit bb0f9235d1
26 changed files with 683 additions and 229 deletions

@@ -1,6 +1,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -88,15 +89,7 @@ void populateArithmeticPatternsAndLegality(
  // non-null encoding
  // --------------
  MLIRContext *context = patterns.getContext();
  // // Legality rule
  // target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
  //   // TODO: check above rule here
  //   [](Operation *op){
  //     return true;
  //   }
  // );
  // Rewrite rule
  // patterns.add<ConvertArithmeticOp>(typeConverter, context);
  // TODO: there's probably a better way to avoid adding all ops one-by-one
  patterns.add<
      ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
      GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
@@ -121,8 +114,35 @@ void populateArithmeticPatternsAndLegality(
      ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
      ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
      // Cast Ops
      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
      typeConverter, context);
      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
      GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
}
// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    triton::gpu::SelectOp res =
        rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
            op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
            adaptor.getFalseValue());
    return success();
  }
};
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
                                    TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<StdSelectPattern>(typeConverter, context);
  target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
                                 // by the frontend
}
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
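
The pattern classes registered above (GenericOpPattern, ArithCmpPattern, StdSelectPattern) all follow MLIR's OpConversionPattern shape: convert the result type through the TritonGPU type converter, then rebuild the op from the adaptor's already-converted operands. A minimal sketch of that shape, assuming MLIR's dialect-conversion API; the class name IllustrativeGenericOpPattern is made up for illustration and this is not the committed GenericOpPattern implementation:

#include "mlir/Transforms/DialectConversion.h"

// Illustrative sketch only: rebuilds a single-result op with converted
// operands and a converted result type (assumes Op has a matching builder).
template <class Op>
struct IllustrativeGenericOpPattern : public mlir::OpConversionPattern<Op> {
  using mlir::OpConversionPattern<Op>::OpConversionPattern;
  mlir::LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return mlir::success();
  }
};

Registering such a pattern is then a one-liner of the same form used throughout this file, e.g. patterns.add<IllustrativeGenericOpPattern<arith::AddIOp>>(typeConverter, context).
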
@@ -231,7 +251,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
    }
    auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
        op, retType, a, b, adaptor.c(), adaptor.allowTF32());
        op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
        adaptor.transB());
    return success();
  }
};
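
The context lines above show the other half of TritonDotPattern: when a dot operand does not already carry the expected encoding, the pattern re-wraps it in a triton::gpu::ConvertLayoutOp before rebuilding the tt.dot, which now also forwards transA/transB from the adaptor. A short sketch of that normalization step; the helper name normalizeDotOperand and the surrounding control flow are assumptions, not code from the commit:

// Assumed helper, not from the commit: give a dot operand the encoding the
// rewritten tt.dot expects, inserting a layout conversion only when needed.
static mlir::Value normalizeDotOperand(mlir::ConversionPatternRewriter &rewriter,
                                       mlir::Value v,
                                       mlir::Attribute wantedEncoding) {
  auto tensorTy = v.getType().cast<mlir::RankedTensorType>();
  if (tensorTy.getEncoding() == wantedEncoding)
    return v; // already laid out correctly, nothing to convert
  auto dstType = mlir::RankedTensorType::get(
      tensorTy.getShape(), tensorTy.getElementType(), wantedEncoding);
  // Same call shape as in the hunk above: convert_layout re-encodes the tensor.
  return rewriter.create<triton::gpu::ConvertLayoutOp>(v.getLoc(), dstType, v);
}
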
@@ -418,6 +439,7 @@ public:
    // rewrite patterns
    RewritePatternSet patterns(context);
    // add rules
    populateStdPatternsAndLegality(typeConverter, patterns, target);
    populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
    populateMathPatternsAndLegality(typeConverter, patterns, target);
    populateTritonPatterns(typeConverter, patterns);
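
These populate* calls only fill the pattern set; the rest of the pass (not shown in this hunk) hands the patterns and the conversion target to MLIR's dialect-conversion driver. A sketch of that driver step, assuming the standard applyPartialConversion entry point; the wrapper function runTritonGPUConversion is illustrative and not part of the commit:

// Illustrative only: how a pattern set populated as above is typically applied.
static mlir::LogicalResult
runTritonGPUConversion(mlir::ModuleOp mod, TritonGPUTypeConverter &typeConverter,
                       TritonGPUConversionTarget &target) {
  mlir::MLIRContext *context = mod.getContext();
  mlir::RewritePatternSet patterns(context);
  populateStdPatternsAndLegality(typeConverter, patterns, target);
  populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
  populateMathPatternsAndLegality(typeConverter, patterns, target);
  populateTritonPatterns(typeConverter, patterns);
  // Partial conversion: ops the target still marks legal (e.g. ReturnOp above)
  // are kept; everything else must be rewritten by one of the patterns.
  return mlir::applyPartialConversion(mod, target, std::move(patterns));
}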