More on TritonGPU conversion
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -12,7 +13,7 @@ namespace {
|
||||
|
||||
class ConvertArithmeticOp: public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context)
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
@@ -21,14 +22,13 @@ public:
|
||||
Dialect* dialect = op->getDialect();
|
||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
// Arithmetic op to legalize here. Create layout conversion if necessary
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
TypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target){
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target){
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
@@ -47,6 +47,75 @@ void populateArithmeticPatternsAndLegality(
|
||||
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// Triton patterns
|
||||
//
|
||||
// TODO: Do we need to put them in anonymous namespace?
|
||||
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
|
||||
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op.getOperation(), retType, op.start(), op.end()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
|
||||
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||
op.getOperation(), retType, op.src()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
||||
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
||||
op.getOperation(), retType, op.ptr(), op.offset()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op.getOperation(), retType,
|
||||
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonMakeRangePattern,
|
||||
TritonBroadcastPattern,
|
||||
TritonGEPPattern,
|
||||
TritonLoadPattern
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
@@ -54,18 +123,19 @@ class ConvertTritonToTritonGPU :
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
TritonGPUConversionTarget target(*context);
|
||||
ModuleOp mod = getOperation();
|
||||
// int numThreads = mod.getAttr();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
|
||||
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
if(failed(applyPartialConversion(mod, target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
Reference in New Issue
Block a user