add basic template for legalizing arithmetic op
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
@@ -8,6 +9,44 @@ using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
class ConvertArithmeticOp: public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
Dialect* dialect = op->getDialect();
|
||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
// Arithmetic op to legalize here. Create layout conversion if necessary
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
TypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target){
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
// arithmetic operations require both inputs to have the same
|
||||
// non-null encoding
|
||||
// --------------
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Legality rule
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// TODO: check above rule here
|
||||
[](Operation *op){
|
||||
return false;
|
||||
}
|
||||
);
|
||||
// Rewrite rule
|
||||
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU:
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
|
||||
@@ -24,7 +63,18 @@ public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
std::cout << "Converting" << std::endl;
|
||||
// type converter
|
||||
TypeConverter typeConverter;
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
|
||||
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user