More on SCF conversion
This commit is contained in:
@@ -5,6 +5,7 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
@@ -12,7 +13,7 @@ using namespace mlir::triton;
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template<class Op>
|
template<class Op>
|
||||||
class ArithBinaryPattern : public OpConversionPattern<Op> {
|
class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<Op>::OpConversionPattern;
|
using OpConversionPattern<Op>::OpConversionPattern;
|
||||||
|
|
||||||
@@ -48,8 +49,10 @@ public:
|
|||||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||||
|
assert(value);
|
||||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||||
op, retType, adaptor.getValue()
|
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -90,40 +93,42 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
// Rewrite rule
|
// Rewrite rule
|
||||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||||
patterns.add<ArithConstantPattern,
|
patterns.add<ArithConstantPattern,
|
||||||
ArithBinaryPattern<arith::AddIOp>,
|
ArithGenericPattern<arith::AddIOp>,
|
||||||
ArithBinaryPattern<arith::SubIOp>,
|
ArithGenericPattern<arith::SubIOp>,
|
||||||
ArithBinaryPattern<arith::MulIOp>,
|
ArithGenericPattern<arith::MulIOp>,
|
||||||
ArithBinaryPattern<arith::DivUIOp>,
|
ArithGenericPattern<arith::DivUIOp>,
|
||||||
ArithBinaryPattern<arith::DivSIOp>,
|
ArithGenericPattern<arith::DivSIOp>,
|
||||||
ArithBinaryPattern<arith::CeilDivUIOp>,
|
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||||
ArithBinaryPattern<arith::CeilDivSIOp>,
|
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||||
ArithBinaryPattern<arith::FloorDivSIOp>,
|
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||||
ArithBinaryPattern<arith::RemUIOp>,
|
ArithGenericPattern<arith::RemUIOp>,
|
||||||
ArithBinaryPattern<arith::RemSIOp>,
|
ArithGenericPattern<arith::RemSIOp>,
|
||||||
ArithBinaryPattern<arith::AndIOp>,
|
ArithGenericPattern<arith::AndIOp>,
|
||||||
ArithBinaryPattern<arith::OrIOp>,
|
ArithGenericPattern<arith::OrIOp>,
|
||||||
ArithBinaryPattern<arith::XOrIOp>,
|
ArithGenericPattern<arith::XOrIOp>,
|
||||||
ArithBinaryPattern<arith::ShLIOp>,
|
ArithGenericPattern<arith::ShLIOp>,
|
||||||
ArithBinaryPattern<arith::ShRUIOp>,
|
ArithGenericPattern<arith::ShRUIOp>,
|
||||||
ArithBinaryPattern<arith::ShRSIOp>, // NegFOp
|
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||||
// Floating point
|
// Floating point
|
||||||
ArithBinaryPattern<arith::AddFOp>,
|
ArithGenericPattern<arith::AddFOp>,
|
||||||
ArithBinaryPattern<arith::SubFOp>,
|
ArithGenericPattern<arith::SubFOp>,
|
||||||
// MaxMin
|
// MaxMin
|
||||||
ArithBinaryPattern<arith::MaxFOp>,
|
ArithGenericPattern<arith::MaxFOp>,
|
||||||
ArithBinaryPattern<arith::MaxSIOp>,
|
ArithGenericPattern<arith::MaxSIOp>,
|
||||||
ArithBinaryPattern<arith::MaxUIOp>,
|
ArithGenericPattern<arith::MaxUIOp>,
|
||||||
ArithBinaryPattern<arith::MinFOp>,
|
ArithGenericPattern<arith::MinFOp>,
|
||||||
ArithBinaryPattern<arith::MinSIOp>,
|
ArithGenericPattern<arith::MinSIOp>,
|
||||||
ArithBinaryPattern<arith::MinUIOp>,
|
ArithGenericPattern<arith::MinUIOp>,
|
||||||
// Floating point
|
// Floating point
|
||||||
ArithBinaryPattern<arith::MulFOp>,
|
ArithGenericPattern<arith::MulFOp>,
|
||||||
ArithBinaryPattern<arith::DivFOp>,
|
ArithGenericPattern<arith::DivFOp>,
|
||||||
ArithBinaryPattern<arith::RemFOp>,
|
ArithGenericPattern<arith::RemFOp>,
|
||||||
// Cmp
|
// Cmp
|
||||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
|
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||||
// Cast Ops
|
// Cast Ops
|
||||||
|
ArithGenericPattern<arith::TruncIOp>,
|
||||||
|
ArithGenericPattern<arith::TruncFOp>
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,6 +217,46 @@ void populateTritonPatterns(
|
|||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// SCF patterns
|
||||||
|
//
|
||||||
|
/// Conversion pattern for `scf.for`.
///
/// Produces a new `scf.for` whose operands are the already-remapped values
/// supplied by the adaptor, whose loop body is the original body with its
/// block-argument types run through the pattern's type converter, and whose
/// results carry the converted result types. The original op is then replaced
/// by the results of the new loop.
///
/// The match fails (without aborting) when a result type cannot be converted
/// or when the body's argument types cannot be converted.
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern<scf::ForOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Work out the converted result types up front so that an unconvertible
    // type is reported before any new IR is created.
    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    for (Type oldType : op->getResultTypes()) {
      Type newType = getTypeConverter()->convertType(oldType);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    // Clone only the op itself and then move the original body into the
    // clone. Moving (rather than deep-cloning) the region keeps the nested
    // operations themselves, instead of creating fresh copies of them.
    auto newFor =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getLoopBody(), newFor.getLoopBody(),
                                newFor.getLoopBody().end());

    // Rewrite the block-argument types of the loop body (induction variable
    // and iteration arguments) according to the type converter.
    if (failed(rewriter.convertRegionTypes(&newFor.getLoopBody(),
                                           *getTypeConverter())))
      return rewriter.notifyMatchFailure(op, "could not convert body types");

    // The clone still refers to the original operands; switch it over to the
    // remapped lower bound, upper bound, step and init args.
    newFor->setOperands(adaptor.getOperands());

    // Give the clone's results their converted types.
    for (unsigned i = 0, e = newFor->getNumResults(); i != e; ++i)
      newFor->getResult(i).setType(newResultTypes[i]);

    rewriter.replaceOp(op, newFor.getResults());
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Conversion pattern for `scf.yield`.
///
/// Replaces the matched yield with a new `scf.yield` built from the values the
/// adaptor reports for the op's `results` operand group. Always succeeds.
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
  using OpConversionPattern<scf::YieldOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Values to yield, as seen through the conversion adaptor.
    ValueRange yieldedValues = adaptor.getResults();
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, yieldedValues);
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Adds the SCF conversion patterns to `patterns`.
///
/// Registers `SCFForPattern` followed by `SCFYieldPattern`, each constructed
/// with `typeConverter` and the context owned by `patterns`.
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                         RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<SCFForPattern>(typeConverter, ctx);
  patterns.add<SCFYieldPattern>(typeConverter, ctx);
}
|
||||||
|
|
||||||
|
|
||||||
class ConvertTritonToTritonGPU :
|
class ConvertTritonToTritonGPU :
|
||||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||||
@@ -229,7 +274,7 @@ public:
|
|||||||
// add rules
|
// add rules
|
||||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||||
populateTritonPatterns(typeConverter, patterns);
|
populateTritonPatterns(typeConverter, patterns);
|
||||||
|
populateSCFPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
if(failed(applyPartialConversion(mod, target,
|
if(failed(applyPartialConversion(mod, target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
context, threadTileSize, blockTileSize, order);
|
context, threadTileSize, blockTileSize, order);
|
||||||
return RankedTensorType::get(shape, elementType, encoding);
|
return RankedTensorType::get(shape, elementType, encoding);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// materializations
|
||||||
|
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
return llvm::None;
|
||||||
|
});
|
||||||
|
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
return llvm::None;
|
||||||
|
});
|
||||||
|
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
|
ValueRange inputs, Location loc) {
|
||||||
|
llvm_unreachable("Not implemented");
|
||||||
|
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
||||||
|
return llvm::None;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||||
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
|
|
||||||
|
|
||||||
// Some ops from SCF are illegal
|
// Some ops from SCF are illegal
|
||||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||||
|
|
||||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||||
if (typeConverter.isLegal(op))
|
triton::TritonDialect,
|
||||||
return true;
|
triton::gpu::TritonGPUDialect,
|
||||||
return false;
|
StandardOpsDialect,
|
||||||
});
|
scf::SCFDialect>([&](Operation *op) {
|
||||||
|
|
||||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
|
||||||
if (typeConverter.isLegal(op))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
|
|
||||||
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
|
|
||||||
if (typeConverter.isLegal(op))
|
if (typeConverter.isLegal(op))
|
||||||
return true;
|
return true;
|
||||||
return false;
|
return false;
|
||||||
|
Reference in New Issue
Block a user