Remove TypeConverter from TritonToTritonGPU conversion
This commit is contained in:
@@ -20,9 +20,8 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
class TritonGPUConversionTarget : public ConversionTarget {
|
class TritonGPUConversionTarget : public ConversionTarget {
|
||||||
TritonGPUTypeConverter &typeConverter;
|
|
||||||
public:
|
public:
|
||||||
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
|
explicit TritonGPUConversionTarget(MLIRContext &ctx);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -73,8 +73,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
void populateArithmeticPatternsAndLegality(
|
void populateArithmeticPatternsAndLegality(
|
||||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
RewritePatternSet &patterns, TritonGPUConversionTarget &target){
|
||||||
TritonGPUConversionTarget &target){
|
|
||||||
// --------------
|
// --------------
|
||||||
// Add legality and rewrite pattern rules for operations
|
// Add legality and rewrite pattern rules for operations
|
||||||
// from the Arithmetic dialect. The basic premise is that
|
// from the Arithmetic dialect. The basic premise is that
|
||||||
@@ -128,7 +127,7 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
// Cast Ops
|
// Cast Ops
|
||||||
ArithGenericPattern<arith::TruncIOp>,
|
ArithGenericPattern<arith::TruncIOp>,
|
||||||
ArithGenericPattern<arith::TruncFOp>
|
ArithGenericPattern<arith::TruncFOp>
|
||||||
>(typeConverter, context);
|
>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -235,9 +234,7 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateTritonPatterns(
|
void populateTritonPatterns(RewritePatternSet &patterns) {
|
||||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
|
||||||
) {
|
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||||
TritonGenericPattern<triton::BroadcastOp>,
|
TritonGenericPattern<triton::BroadcastOp>,
|
||||||
@@ -247,7 +244,7 @@ void populateTritonPatterns(
|
|||||||
TritonDotPattern,
|
TritonDotPattern,
|
||||||
TritonLoadPattern,
|
TritonLoadPattern,
|
||||||
TritonStorePattern
|
TritonStorePattern
|
||||||
>(typeConverter, context);
|
>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -310,12 +307,9 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateSCFPatterns(
|
void populateSCFPatterns(RewritePatternSet &patterns) {
|
||||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
|
||||||
) {
|
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<SCFYieldPattern, SCFForPattern
|
patterns.add<SCFYieldPattern, SCFForPattern>(context);
|
||||||
>(typeConverter, context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -330,17 +324,16 @@ public:
|
|||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ModuleOp mod = getOperation();
|
ModuleOp mod = getOperation();
|
||||||
int numThreads = numWarps * 32;
|
int numThreads = numWarps * 32;
|
||||||
// type converter
|
|
||||||
TritonGPUTypeConverter typeConverter(context, numThreads);
|
TritonGPUConversionTarget target(*context);
|
||||||
TritonGPUConversionTarget target(*context, typeConverter);
|
|
||||||
// rewrite patterns
|
// rewrite patterns
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
// add rules
|
// add rules
|
||||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
populateArithmeticPatternsAndLegality(patterns, target);
|
||||||
populateTritonPatterns(typeConverter, patterns);
|
populateTritonPatterns(patterns);
|
||||||
// TODO: can we use
|
// TODO: can we use
|
||||||
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
||||||
populateSCFPatterns(typeConverter, patterns);
|
populateSCFPatterns(patterns);
|
||||||
|
|
||||||
if(failed(applyPartialConversion(mod, target,
|
if(failed(applyPartialConversion(mod, target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
@@ -90,9 +90,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
//
|
//
|
||||||
// TritonGPUConversion
|
// TritonGPUConversion
|
||||||
//
|
//
|
||||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
: ConversionTarget(context) {
|
||||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
|
||||||
// TODO: we should also verify ops of TritonGPUDialect
|
// TODO: we should also verify ops of TritonGPUDialect
|
||||||
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||||
|
|
||||||
@@ -104,7 +103,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|||||||
triton::TritonDialect,
|
triton::TritonDialect,
|
||||||
StandardOpsDialect,
|
StandardOpsDialect,
|
||||||
scf::SCFDialect>([&](Operation *op) {
|
scf::SCFDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op))
|
auto isLegal = [](Value v) -> bool {
|
||||||
|
Type type = v.getType();
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
if (tensorType.getEncoding())
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (llvm::all_of(op->getOperands(), isLegal) &&
|
||||||
|
llvm::all_of(op->getResults(), isLegal))
|
||||||
return true;
|
return true;
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
@@ -117,9 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|||||||
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||||
return true;
|
return true;
|
||||||
// // TODO: we should delete this
|
|
||||||
// if (this->typeConverter.isLegal(dotOp))
|
|
||||||
// return true;
|
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user