Remove TypeConverter from TritonToTritonGPU conversion

This commit is contained in:
Yan Da
2022-06-18 14:34:59 +08:00
parent 9feb256b71
commit 64d0b87ef0
3 changed files with 26 additions and 27 deletions

View File

@@ -20,9 +20,8 @@ private:
};
class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter;
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
explicit TritonGPUConversionTarget(MLIRContext &ctx);
};
} // namespace mlir

View File

@@ -73,8 +73,7 @@ public:
};
void populateArithmeticPatternsAndLegality(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
TritonGPUConversionTarget &target){
RewritePatternSet &patterns, TritonGPUConversionTarget &target){
// --------------
// Add legality and rewrite pattern rules for operations
// from the Arithmetic dialect. The basic premise is that
@@ -128,7 +127,7 @@ void populateArithmeticPatternsAndLegality(
// Cast Ops
ArithGenericPattern<arith::TruncIOp>,
ArithGenericPattern<arith::TruncFOp>
>(typeConverter, context);
>(context);
}
//
@@ -235,9 +234,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
}
};
void populateTritonPatterns(
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
) {
void populateTritonPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::BroadcastOp>,
@@ -247,7 +244,7 @@ void populateTritonPatterns(
TritonDotPattern,
TritonLoadPattern,
TritonStorePattern
>(typeConverter, context);
>(context);
}
//
@@ -310,12 +307,9 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
}
};
void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
void populateSCFPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFYieldPattern, SCFForPattern
>(typeConverter, context);
patterns.add<SCFYieldPattern, SCFForPattern>(context);
}
@@ -330,17 +324,16 @@ public:
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
int numThreads = numWarps * 32;
// type converter
TritonGPUTypeConverter typeConverter(context, numThreads);
TritonGPUConversionTarget target(*context, typeConverter);
TritonGPUConversionTarget target(*context);
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);
populateArithmeticPatternsAndLegality(patterns, target);
populateTritonPatterns(patterns);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
populateSCFPatterns(patterns);
if(failed(applyPartialConversion(mod, target,
std::move(patterns))))

View File

@@ -90,9 +90,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
: ConversionTarget(context) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
@@ -104,7 +103,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
triton::TritonDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
auto isLegal = [](Value v) -> bool {
Type type = v.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding())
return true;
return false;
}
return true;
};
if (llvm::all_of(op->getOperands(), isLegal) &&
llvm::all_of(op->getResults(), isLegal))
return true;
return false;
});
@@ -117,9 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
return false;
});