More progress on TritonGPUTypeConverter & TritonGPUConversionTarget

This commit is contained in:
Yan Da
2022-05-01 22:06:54 +08:00
parent 4ece9fd1f3
commit 1428185c9c
12 changed files with 182 additions and 22 deletions

View File

@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
MLIRPass
TritonIR
TritonGPUIR
)
TritonGPUConversion
)

View File

@@ -1,7 +1,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "../PassDetail.h"
using namespace mlir;
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here
[](Operation *op){
return false;
return true;
}
);
// Rewrite rule
@@ -47,26 +48,27 @@ void populateArithmeticPatternsAndLegality(
}
class ConvertTritonToTritonGPU:
class ConvertTritonToTritonGPU :
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// type converter
TypeConverter typeConverter;
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
void runOnOperation() override {
MLIRContext *context = &getContext();
TritonGPUConversionTarget target(*context);
ModuleOp mod = getOperation();
// int numThreads = mod.getAttr();
// type converter
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
if(failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
if(failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
}