More on SCF conversion

This commit is contained in:
Yan Da
2022-05-04 21:50:32 +08:00
parent a96fe07e1c
commit 26c59e4718
2 changed files with 101 additions and 46 deletions

View File

@@ -5,6 +5,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
using namespace mlir::triton;
@@ -12,7 +13,7 @@ using namespace mlir::triton;
namespace {
template<class Op>
class ArithBinaryPattern : public OpConversionPattern<Op> {
class ArithGenericPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
@@ -48,8 +49,10 @@ public:
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retType, adaptor.getValue()
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
);
return success();
}
@@ -90,40 +93,42 @@ void populateArithmeticPatternsAndLegality(
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithConstantPattern,
ArithBinaryPattern<arith::AddIOp>,
ArithBinaryPattern<arith::SubIOp>,
ArithBinaryPattern<arith::MulIOp>,
ArithBinaryPattern<arith::DivUIOp>,
ArithBinaryPattern<arith::DivSIOp>,
ArithBinaryPattern<arith::CeilDivUIOp>,
ArithBinaryPattern<arith::CeilDivSIOp>,
ArithBinaryPattern<arith::FloorDivSIOp>,
ArithBinaryPattern<arith::RemUIOp>,
ArithBinaryPattern<arith::RemSIOp>,
ArithBinaryPattern<arith::AndIOp>,
ArithBinaryPattern<arith::OrIOp>,
ArithBinaryPattern<arith::XOrIOp>,
ArithBinaryPattern<arith::ShLIOp>,
ArithBinaryPattern<arith::ShRUIOp>,
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp
ArithGenericPattern<arith::AddIOp>,
ArithGenericPattern<arith::SubIOp>,
ArithGenericPattern<arith::MulIOp>,
ArithGenericPattern<arith::DivUIOp>,
ArithGenericPattern<arith::DivSIOp>,
ArithGenericPattern<arith::CeilDivUIOp>,
ArithGenericPattern<arith::CeilDivSIOp>,
ArithGenericPattern<arith::FloorDivSIOp>,
ArithGenericPattern<arith::RemUIOp>,
ArithGenericPattern<arith::RemSIOp>,
ArithGenericPattern<arith::AndIOp>,
ArithGenericPattern<arith::OrIOp>,
ArithGenericPattern<arith::XOrIOp>,
ArithGenericPattern<arith::ShLIOp>,
ArithGenericPattern<arith::ShRUIOp>,
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
// Floating point
ArithBinaryPattern<arith::AddFOp>,
ArithBinaryPattern<arith::SubFOp>,
ArithGenericPattern<arith::AddFOp>,
ArithGenericPattern<arith::SubFOp>,
// MaxMin
ArithBinaryPattern<arith::MaxFOp>,
ArithBinaryPattern<arith::MaxSIOp>,
ArithBinaryPattern<arith::MaxUIOp>,
ArithBinaryPattern<arith::MinFOp>,
ArithBinaryPattern<arith::MinSIOp>,
ArithBinaryPattern<arith::MinUIOp>,
ArithGenericPattern<arith::MaxFOp>,
ArithGenericPattern<arith::MaxSIOp>,
ArithGenericPattern<arith::MaxUIOp>,
ArithGenericPattern<arith::MinFOp>,
ArithGenericPattern<arith::MinSIOp>,
ArithGenericPattern<arith::MinUIOp>,
// Floating point
ArithBinaryPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>,
ArithGenericPattern<arith::MulFOp>,
ArithGenericPattern<arith::DivFOp>,
ArithGenericPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>
>(typeConverter, context);
}
@@ -212,6 +217,46 @@ void populateTritonPatterns(
>(typeConverter, context);
}
//
// SCF patterns
//
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFor = rewriter.replaceOpWithNewOp<scf::ForOp>(
op, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(),
adaptor.getInitArgs()
);
// TODO: we need to copy (?) the body of ForOp
llvm_unreachable("Not implemented");
// newFor.getRegion().takeBody(adaptor.getRegion());
return success();
}
};
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<scf::YieldOp>(
op, adaptor.getResults()
);
return success();
}
};
void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFForPattern,
SCFYieldPattern
>(typeConverter, context);
}
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
@@ -229,7 +274,7 @@ public:
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target,
std::move(patterns))))

View File

@@ -2,6 +2,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
@@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
context, threadTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding);
});
// materailizations
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
llvm_unreachable("Not implemented");
return llvm::None;
});
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
return llvm::None;
});
}
//
@@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
triton::gpu::TritonGPUDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;