add basic template for legalizing arithmetic op

This commit is contained in:
Phil Tillet
2022-04-30 20:42:25 -07:00
parent 2c6a213131
commit d9017f8593

View File

@@ -1,5 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h" #include "../PassDetail.h"
@@ -8,6 +9,44 @@ using namespace mlir::triton;
namespace { namespace {
class ConvertArithmeticOp: public ConversionPattern {
public:
ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override {
Dialect* dialect = op->getDialect();
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
return failure();
// Arithmetic op to legalize here. Create layout conversion if necessary
return success();
}
};
void populateArithmeticPatternsAndLegality(
TypeConverter& typeConverter, RewritePatternSet &patterns,
ConversionTarget &target){
// --------------
// Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that
// arithmetic operations require both inputs to have the same
// non-null encoding
// --------------
MLIRContext *context = patterns.getContext();
// Legality rule
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here
[](Operation *op){
return false;
}
);
// Rewrite rule
patterns.add<ConvertArithmeticOp>(typeConverter, context);
}
class ConvertTritonToTritonGPU: class ConvertTritonToTritonGPU:
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> { public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
@@ -24,7 +63,18 @@ public:
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
std::cout << "Converting" << std::endl; // type converter
TypeConverter typeConverter;
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
if(failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
} }
}; };