Revert "Remove TypeConverter from TritonToTritonGPU conversion"
This reverts commit 64d0b87ef0
.
This commit is contained in:
@@ -20,8 +20,9 @@ private:
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
TritonGPUTypeConverter &typeConverter;
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx);
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
@@ -73,7 +73,8 @@ public:
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
RewritePatternSet &patterns, TritonGPUConversionTarget &target){
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target){
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
@@ -127,7 +128,7 @@ void populateArithmeticPatternsAndLegality(
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>
|
||||
>(context);
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -234,7 +235,9 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(RewritePatternSet &patterns) {
|
||||
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
@@ -244,7 +247,7 @@ void populateTritonPatterns(RewritePatternSet &patterns) {
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(context);
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -307,9 +310,12 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateSCFPatterns(RewritePatternSet &patterns) {
|
||||
void populateSCFPatterns(
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
@@ -324,16 +330,17 @@ public:
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
int numThreads = numWarps * 32;
|
||||
|
||||
TritonGPUConversionTarget target(*context);
|
||||
// type converter
|
||||
TritonGPUTypeConverter typeConverter(context, numThreads);
|
||||
TritonGPUConversionTarget target(*context, typeConverter);
|
||||
// rewrite patterns
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(patterns, target);
|
||||
populateTritonPatterns(patterns);
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(patterns);
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target,
|
||||
std::move(patterns))))
|
||||
|
@@ -90,8 +90,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
//
|
||||
// TritonGPUConversion
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||
: ConversionTarget(context) {
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
// TODO: we should also verify ops of TritonGPUDialect
|
||||
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||
|
||||
@@ -103,18 +104,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||
triton::TritonDialect,
|
||||
StandardOpsDialect,
|
||||
scf::SCFDialect>([&](Operation *op) {
|
||||
auto isLegal = [](Value v) -> bool {
|
||||
Type type = v.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if (llvm::all_of(op->getOperands(), isLegal) &&
|
||||
llvm::all_of(op->getResults(), isLegal))
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
@@ -127,6 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
return true;
|
||||
// // TODO: we should delete this
|
||||
// if (this->typeConverter.isLegal(dotOp))
|
||||
// return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user