Revert "Remove TypeConverter from TritonToTritonGPU conversion"

This reverts commit 64d0b87ef0.
This commit is contained in:
Yan Da
2022-06-18 14:57:41 +08:00
parent 64d0b87ef0
commit 53cf93ce6a
3 changed files with 27 additions and 26 deletions

View File

@@ -20,8 +20,9 @@ private:
}; };
class TritonGPUConversionTarget : public ConversionTarget { class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter;
public: public:
explicit TritonGPUConversionTarget(MLIRContext &ctx); explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
}; };
} // namespace mlir } // namespace mlir

View File

@@ -73,7 +73,8 @@ public:
}; };
void populateArithmeticPatternsAndLegality( void populateArithmeticPatternsAndLegality(
RewritePatternSet &patterns, TritonGPUConversionTarget &target){ TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target){
// -------------- // --------------
// Add legality and rewrite pattern rules for operations // Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that // from the Arithmetic dialect. The basic premise is that
@@ -127,7 +128,7 @@ void populateArithmeticPatternsAndLegality(
// Cast Ops // Cast Ops
ArithGenericPattern<arith::TruncIOp>, ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp> ArithGenericPattern<arith::TruncFOp>
>(context); >(typeConverter, context);
} }
// //
@@ -234,7 +235,9 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
} }
}; };
void populateTritonPatterns(RewritePatternSet &patterns) { void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>, patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::BroadcastOp>, TritonGenericPattern<triton::BroadcastOp>,
@@ -244,7 +247,7 @@ void populateTritonPatterns(RewritePatternSet &patterns) {
TritonDotPattern, TritonDotPattern,
TritonLoadPattern, TritonLoadPattern,
TritonStorePattern TritonStorePattern
>(context); >(typeConverter, context);
} }
// //
@@ -307,9 +310,12 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
} }
}; };
void populateSCFPatterns(RewritePatternSet &patterns) { void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern>(context); patterns.add<SCFYieldPattern, SCFForPattern
>(typeConverter, context);
} }
@@ -324,16 +330,17 @@ public:
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
ModuleOp mod = getOperation(); ModuleOp mod = getOperation();
int numThreads = numWarps * 32; int numThreads = numWarps * 32;
// type converter
TritonGPUConversionTarget target(*context); TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUConversionTarget target(*context, typeConverter);
// rewrite patterns // rewrite patterns
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
// add rules // add rules
populateArithmeticPatternsAndLegality(patterns, target); populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(patterns); populateTritonPatterns(typeConverter, patterns);
// TODO: can we use // TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(patterns); populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target, if(failed(applyPartialConversion(mod, target,
std::move(patterns)))) std::move(patterns))))

View File

@@ -90,8 +90,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// //
// TritonGPUConversion // TritonGPUConversion
// //
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context) TritonGPUConversionTarget::TritonGPUConversionTarget(
: ConversionTarget(context) { MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect // TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>(); addLegalDialect<triton::gpu::TritonGPUDialect>();
@@ -103,18 +104,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
triton::TritonDialect, triton::TritonDialect,
StandardOpsDialect, StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) { scf::SCFDialect>([&](Operation *op) {
auto isLegal = [](Value v) -> bool { if (typeConverter.isLegal(op))
Type type = v.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding())
return true;
return false;
}
return true;
};
if (llvm::all_of(op->getOperands(), isLegal) &&
llvm::all_of(op->getResults(), isLegal))
return true; return true;
return false; return false;
}); });
@@ -127,6 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() && if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true; return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
return false; return false;
}); });