More on SCF conversion
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
@@ -12,7 +13,7 @@ using namespace mlir::triton;
|
||||
namespace {
|
||||
|
||||
template<class Op>
|
||||
class ArithBinaryPattern : public OpConversionPattern<Op> {
|
||||
class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
@@ -48,8 +49,10 @@ public:
|
||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
assert(value);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType, adaptor.getValue()
|
||||
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -90,40 +93,42 @@ void populateArithmeticPatternsAndLegality(
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithConstantPattern,
|
||||
ArithBinaryPattern<arith::AddIOp>,
|
||||
ArithBinaryPattern<arith::SubIOp>,
|
||||
ArithBinaryPattern<arith::MulIOp>,
|
||||
ArithBinaryPattern<arith::DivUIOp>,
|
||||
ArithBinaryPattern<arith::DivSIOp>,
|
||||
ArithBinaryPattern<arith::CeilDivUIOp>,
|
||||
ArithBinaryPattern<arith::CeilDivSIOp>,
|
||||
ArithBinaryPattern<arith::FloorDivSIOp>,
|
||||
ArithBinaryPattern<arith::RemUIOp>,
|
||||
ArithBinaryPattern<arith::RemSIOp>,
|
||||
ArithBinaryPattern<arith::AndIOp>,
|
||||
ArithBinaryPattern<arith::OrIOp>,
|
||||
ArithBinaryPattern<arith::XOrIOp>,
|
||||
ArithBinaryPattern<arith::ShLIOp>,
|
||||
ArithBinaryPattern<arith::ShRUIOp>,
|
||||
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp
|
||||
ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>,
|
||||
ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>,
|
||||
ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>,
|
||||
ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>,
|
||||
ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>,
|
||||
ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithBinaryPattern<arith::AddFOp>,
|
||||
ArithBinaryPattern<arith::SubFOp>,
|
||||
ArithGenericPattern<arith::AddFOp>,
|
||||
ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithBinaryPattern<arith::MaxFOp>,
|
||||
ArithBinaryPattern<arith::MaxSIOp>,
|
||||
ArithBinaryPattern<arith::MaxUIOp>,
|
||||
ArithBinaryPattern<arith::MinFOp>,
|
||||
ArithBinaryPattern<arith::MinSIOp>,
|
||||
ArithBinaryPattern<arith::MinUIOp>,
|
||||
ArithGenericPattern<arith::MaxFOp>,
|
||||
ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>,
|
||||
ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>,
|
||||
ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithBinaryPattern<arith::MulFOp>,
|
||||
ArithBinaryPattern<arith::DivFOp>,
|
||||
ArithBinaryPattern<arith::RemFOp>,
|
||||
ArithGenericPattern<arith::MulFOp>,
|
||||
ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -212,6 +217,46 @@ void populateTritonPatterns(
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// SCF patterns
|
||||
//
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newFor = rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(),
|
||||
adaptor.getInitArgs()
|
||||
);
|
||||
// TODO: we need to copy (?) the body of ForOp
|
||||
llvm_unreachable("Not implemented");
|
||||
// newFor.getRegion().takeBody(adaptor.getRegion());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||
op, adaptor.getResults()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFForPattern,
|
||||
SCFYieldPattern
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
@@ -229,7 +274,7 @@ public:
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target,
|
||||
std::move(patterns))))
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
context, threadTileSize, blockTileSize, order);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
|
||||
// materailizations
|
||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
return llvm::None;
|
||||
});
|
||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||
ValueRange inputs, Location loc) {
|
||||
llvm_unreachable("Not implemented");
|
||||
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
||||
return llvm::None;
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
@@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
|
||||
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||
triton::TritonDialect,
|
||||
triton::gpu::TritonGPUDialect,
|
||||
StandardOpsDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
|
Reference in New Issue
Block a user