2022-04-30 14:31:18 -07:00
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
2022-05-02 21:51:00 +08:00
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
2022-04-30 14:31:18 -07:00
|
|
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
2022-04-30 14:31:18 -07:00
|
|
|
#include "../PassDetail.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::triton;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
2022-04-30 20:42:25 -07:00
|
|
|
class ConvertArithmeticOp: public ConversionPattern {
|
|
|
|
public:
|
2022-05-02 21:51:00 +08:00
|
|
|
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
2022-04-30 20:42:25 -07:00
|
|
|
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
|
|
|
context) {}
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter& rewriter) const override {
|
|
|
|
Dialect* dialect = op->getDialect();
|
|
|
|
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
|
|
|
return failure();
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
void populateArithmeticPatternsAndLegality(
|
2022-05-02 21:51:00 +08:00
|
|
|
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
|
|
|
TritonGPUConversionTarget &target){
|
2022-04-30 20:42:25 -07:00
|
|
|
// --------------
|
|
|
|
// Add legality and rewrite pattern rules for operations
|
|
|
|
// from the Arithmetic dialect. The basic premise is that
|
|
|
|
// arithmetic operations require both inputs to have the same
|
|
|
|
// non-null encoding
|
|
|
|
// --------------
|
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
// Legality rule
|
|
|
|
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
|
|
|
// TODO: check above rule here
|
|
|
|
[](Operation *op){
|
2022-05-01 22:06:54 +08:00
|
|
|
return true;
|
2022-04-30 20:42:25 -07:00
|
|
|
}
|
|
|
|
);
|
|
|
|
// Rewrite rule
|
|
|
|
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
|
|
|
}
|
|
|
|
|
2022-05-02 21:51:00 +08:00
|
|
|
//
|
|
|
|
// Triton patterns
|
|
|
|
//
|
|
|
|
// These patterns are already inside this file's anonymous namespace, so they
// have internal linkage and need no further wrapping.
|
|
|
|
/// Rewrites triton::MakeRangeOp so that its result type is the one produced
/// by the TritonGPU type converter; start/end are forwarded unchanged from
/// the original op.
struct TritonMakeRangePattern
    : public OpConversionPattern<triton::MakeRangeOp> {
  using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // convertType() returns a null Type when it cannot convert; building an
    // op with a null result type would crash, so report failure instead.
    Type retType = getTypeConverter()->convertType(op.getType());
    if (!retType)
      return failure();

    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
        op.getOperation(), retType, op.start(), op.end());
    return success();
  }
};
|
|
|
|
|
|
|
|
/// Rewrites triton::BroadcastOp with a converted result type, taking its
/// source operand from the adaptor.
struct TritonBroadcastPattern
    : public OpConversionPattern<triton::BroadcastOp> {
  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // convertType() returns a null Type on failure; never build with it.
    Type retType = getTypeConverter()->convertType(op.getType());
    if (!retType)
      return failure();

    // Use the adaptor's operand: it is the value remapped by the conversion
    // framework, whereas op.src() may still be the pre-conversion value.
    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op.getOperation(),
                                                     retType, adaptor.src());
    return success();
  }
};
|
|
|
|
|
|
|
|
/// Rewrites triton::GEPOp with a converted result type, taking the pointer
/// and offset operands from the adaptor.
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
  using OpConversionPattern<triton::GEPOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // convertType() returns a null Type on failure; never build with it.
    Type retType = getTypeConverter()->convertType(op.getType());
    if (!retType)
      return failure();

    // Operands come from the adaptor so the new op uses the values remapped
    // by the conversion framework rather than the original ones.
    rewriter.replaceOpWithNewOp<triton::GEPOp>(
        op.getOperation(), retType, adaptor.ptr(), adaptor.offset());
    return success();
  }
};
|
|
|
|
|
|
|
|
/// Rewrites triton::LoadOp with a converted result type. Value operands
/// (ptr, mask, other) are taken from the adaptor; cache, evict and
/// isVolatile are forwarded from the original op.
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
  using OpConversionPattern<triton::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // convertType() returns a null Type on failure; never build with it.
    Type retType = getTypeConverter()->convertType(op.getType());
    if (!retType)
      return failure();

    // Value operands must come from the adaptor (remapped values); the
    // remaining builder arguments are read from the original op as before.
    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op.getOperation(), retType, adaptor.ptr(), adaptor.mask(),
        adaptor.other(), op.cache(), op.evict(), op.isVolatile());
    return success();
  }
};
|
|
|
|
|
|
|
|
/// Adds the Triton-dialect rewrite patterns (make_range, broadcast, GEP,
/// load) to `patterns`; every pattern shares `typeConverter` and the
/// pattern set's MLIR context.
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TritonMakeRangePattern>(typeConverter, ctx);
  patterns.add<TritonBroadcastPattern>(typeConverter, ctx);
  patterns.add<TritonGEPPattern>(typeConverter, ctx);
  patterns.add<TritonLoadPattern>(typeConverter, ctx);
}
|
|
|
|
|
2022-04-30 20:42:25 -07:00
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
class ConvertTritonToTritonGPU :
|
2022-04-30 14:31:18 -07:00
|
|
|
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
|
|
|
|
|
|
|
public:
|
2022-05-01 22:06:54 +08:00
|
|
|
void runOnOperation() override {
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
ModuleOp mod = getOperation();
|
|
|
|
// int numThreads = mod.getAttr();
|
|
|
|
// type converter
|
2022-05-02 21:51:00 +08:00
|
|
|
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
|
|
|
|
TritonGPUConversionTarget target(*context, typeConverter);
|
2022-05-01 22:06:54 +08:00
|
|
|
// rewrite patterns
|
|
|
|
RewritePatternSet patterns(context);
|
|
|
|
// add rules
|
|
|
|
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
2022-05-02 21:51:00 +08:00
|
|
|
populateTritonPatterns(typeConverter, patterns);
|
2022-04-30 20:42:25 -07:00
|
|
|
|
|
|
|
|
2022-05-02 21:51:00 +08:00
|
|
|
if(failed(applyPartialConversion(mod, target,
|
2022-05-01 22:06:54 +08:00
|
|
|
std::move(patterns))))
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
2022-04-30 14:31:18 -07:00
|
|
|
};
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Factory for the Triton -> TritonGPU conversion pass. The pass class is
/// defined in this file's anonymous namespace (hence the `::` qualifier);
/// ownership of the new instance is handed to the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<::ConvertTritonToTritonGPU>();
  return pass;
}
|