Revert "Remove TypeConverter from TritonToTritonGPU conversion"

This reverts commit 64d0b87ef0.
This commit is contained in:
Yan Da
2022-06-18 14:57:41 +08:00
parent 64d0b87ef0
commit 53cf93ce6a
3 changed files with 27 additions and 26 deletions

View File

@@ -73,7 +73,8 @@ public:
};
void populateArithmeticPatternsAndLegality(
RewritePatternSet &patterns, TritonGPUConversionTarget &target){
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target){
// --------------
// Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that
@@ -127,7 +128,7 @@ void populateArithmeticPatternsAndLegality(
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>
>(context);
>(typeConverter, context);
}
//
@@ -234,7 +235,9 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
}
};
void populateTritonPatterns(RewritePatternSet &patterns) {
void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::BroadcastOp>,
@@ -244,7 +247,7 @@ void populateTritonPatterns(RewritePatternSet &patterns) {
TritonDotPattern,
TritonLoadPattern,
TritonStorePattern
>(context);
>(typeConverter, context);
}
//
@@ -307,9 +310,12 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
void populateSCFPatterns(RewritePatternSet &patterns) {
void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(context);
patterns.add<SCFYieldPattern, SCFForPattern
>(typeConverter, context);
}
@@ -324,16 +330,17 @@ public:
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
int numThreads = numWarps * 32;
TritonGPUConversionTarget target(*context);
// type converter
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(patterns, target);
populateTritonPatterns(patterns);
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(patterns);
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target,
std::move(patterns))))