More on SCF conversion

This commit is contained in:
Yan Da
2022-05-04 21:50:32 +08:00
parent a96fe07e1c
commit 26c59e4718
2 changed files with 101 additions and 46 deletions

View File

@@ -5,6 +5,7 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir; using namespace mlir;
using namespace mlir::triton; using namespace mlir::triton;
@@ -12,7 +13,7 @@ using namespace mlir::triton;
namespace { namespace {
template<class Op> template<class Op>
class ArithBinaryPattern : public OpConversionPattern<Op> { class ArithGenericPattern : public OpConversionPattern<Op> {
public: public:
using OpConversionPattern<Op>::OpConversionPattern; using OpConversionPattern<Op>::OpConversionPattern;
@@ -48,8 +49,10 @@ public:
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
rewriter.replaceOpWithNewOp<arith::ConstantOp>( rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retType, adaptor.getValue() op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
); );
return success(); return success();
} }
@@ -90,40 +93,42 @@ void populateArithmeticPatternsAndLegality(
// Rewrite rule // Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context); // patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithConstantPattern, patterns.add<ArithConstantPattern,
ArithBinaryPattern<arith::AddIOp>, ArithGenericPattern<arith::AddIOp>,
ArithBinaryPattern<arith::SubIOp>, ArithGenericPattern<arith::SubIOp>,
ArithBinaryPattern<arith::MulIOp>, ArithGenericPattern<arith::MulIOp>,
ArithBinaryPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivUIOp>,
ArithBinaryPattern<arith::DivSIOp>, ArithGenericPattern<arith::DivSIOp>,
ArithBinaryPattern<arith::CeilDivUIOp>, ArithGenericPattern<arith::CeilDivUIOp>,
ArithBinaryPattern<arith::CeilDivSIOp>, ArithGenericPattern<arith::CeilDivSIOp>,
ArithBinaryPattern<arith::FloorDivSIOp>, ArithGenericPattern<arith::FloorDivSIOp>,
ArithBinaryPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemUIOp>,
ArithBinaryPattern<arith::RemSIOp>, ArithGenericPattern<arith::RemSIOp>,
ArithBinaryPattern<arith::AndIOp>, ArithGenericPattern<arith::AndIOp>,
ArithBinaryPattern<arith::OrIOp>, ArithGenericPattern<arith::OrIOp>,
ArithBinaryPattern<arith::XOrIOp>, ArithGenericPattern<arith::XOrIOp>,
ArithBinaryPattern<arith::ShLIOp>, ArithGenericPattern<arith::ShLIOp>,
ArithBinaryPattern<arith::ShRUIOp>, ArithGenericPattern<arith::ShRUIOp>,
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp ArithGenericPattern<arith::ShRSIOp>, // NegFOp
// Floating point // Floating point
ArithBinaryPattern<arith::AddFOp>, ArithGenericPattern<arith::AddFOp>,
ArithBinaryPattern<arith::SubFOp>, ArithGenericPattern<arith::SubFOp>,
// MaxMin // MaxMin
ArithBinaryPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxFOp>,
ArithBinaryPattern<arith::MaxSIOp>, ArithGenericPattern<arith::MaxSIOp>,
ArithBinaryPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MaxUIOp>,
ArithBinaryPattern<arith::MinFOp>, ArithGenericPattern<arith::MinFOp>,
ArithBinaryPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinSIOp>,
ArithBinaryPattern<arith::MinUIOp>, ArithGenericPattern<arith::MinUIOp>,
// Floating point // Floating point
ArithBinaryPattern<arith::MulFOp>, ArithGenericPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>, ArithGenericPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>, ArithGenericPattern<arith::RemFOp>,
// Cmp // Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>, ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp> ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops // Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>
>(typeConverter, context); >(typeConverter, context);
} }
@@ -212,6 +217,46 @@ void populateTritonPatterns(
>(typeConverter, context); >(typeConverter, context);
} }
//
// SCF patterns
//
// Rewrites an scf.for so that its operands, loop-body block arguments and
// results use the types produced by this pattern's type converter, while
// keeping the original loop body.
//
// The approach follows MLIR's structural type conversion for scf.for:
//   * The op is cloned WITHOUT its region instead of being updated in place,
//     so the rewriter sees a fresh op whose results replace the old ones.
//   * The body is moved (not copied) into the clone, so ops nested in the loop
//     remain the same operations and can still be visited by the conversion
//     driver.
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern<scf::ForOp>::OpConversionPattern;

  // Returns success() once `op` has been replaced by the retyped clone, or a
  // match failure if a result type or the body signature cannot be converted.
  LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Work out the converted result types first, so an unconvertible type
    // fails the match before any IR has been created.
    SmallVector<Type, 4> newResultTypes;
    for (Type type : op->getResultTypes()) {
      Type newType = getTypeConverter()->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    // Clone only the op "shell" and take over the original body.
    auto newFor =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getRegion(), newFor.getRegion(),
                                newFor.getRegion().end());

    // Convert the block-argument types of the loop body.
    if (failed(rewriter.convertRegionTypes(&newFor.getRegion(),
                                           *getTypeConverter())))
      return rewriter.notifyMatchFailure(op, "could not convert body types");

    // The clone still refers to the original operands; switch it over to the
    // already-converted ones supplied by the adaptor.
    newFor->setOperands(adaptor.getOperands());

    // Retype the results of the clone to the converted types.
    for (unsigned i = 0, e = newResultTypes.size(); i != e; ++i)
      newFor->getResult(i).setType(newResultTypes[i]);

    rewriter.replaceOp(op, newFor.getResults());
    return success();
  }
};
// Replaces an scf.yield with a new scf.yield that forwards the operands
// supplied by the adaptor.
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
  using OpConversionPattern<scf::YieldOp>::OpConversionPattern;

  // Always succeeds: `op` is swapped for an scf.yield built from the
  // adaptor's result operands.
  LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    ValueRange yieldedValues = adaptor.getResults();
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, yieldedValues);
    return success();
  }
};
// Registers the SCF conversion patterns (scf.for and scf.yield) in `patterns`,
// each constructed with `typeConverter` and the context owned by `patterns`.
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                         RewritePatternSet &patterns) {
  patterns.add<SCFForPattern, SCFYieldPattern>(typeConverter,
                                               patterns.getContext());
}
class ConvertTritonToTritonGPU : class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> { public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
@@ -229,7 +274,7 @@ public:
// add rules // add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns); populateTritonPatterns(typeConverter, patterns);
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target, if(failed(applyPartialConversion(mod, target,
std::move(patterns)))) std::move(patterns))))

View File

@@ -2,6 +2,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm> #include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir; using namespace mlir;
@@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
context, threadTileSize, blockTileSize, order); context, threadTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding); return RankedTensorType::get(shape, elementType, encoding);
}); });
// materializations
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
llvm_unreachable("Not implemented");
return llvm::None;
});
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
return llvm::None;
});
} }
// //
@@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget( TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter) MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) { : ConversionTarget(context), typeConverter(typeConverter) {
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
// Some ops from SCF are illegal // Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>(); scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) { addDynamicallyLegalDialect<arith::ArithmeticDialect,
if (typeConverter.isLegal(op)) triton::TritonDialect,
return true; triton::gpu::TritonGPUDialect,
return false; StandardOpsDialect,
}); scf::SCFDialect>([&](Operation *op) {
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
if (typeConverter.isLegal(op)) if (typeConverter.isLegal(op))
return true; return true;
return false; return false;