More progress on TritonGPUTypeConverter & TritonGPUConversionTarget
This commit is contained in:
@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
|
||||
MLIRPass
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
)
|
||||
TritonGPUConversion
|
||||
)
|
||||
|
@@ -1,7 +1,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// TODO: check above rule here
|
||||
[](Operation *op){
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
);
|
||||
// Rewrite rule
|
||||
@@ -47,26 +48,27 @@ void populateArithmeticPatternsAndLegality(
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU:
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
// type converter
|
||||
TypeConverter typeConverter;
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
TritonGPUConversionTarget target(*context);
|
||||
ModuleOp mod = getOperation();
|
||||
// int numThreads = mod.getAttr();
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
|
||||
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
}
|
||||
if(failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user