Revert "Remove TypeConverter from TritonToTritonGPU conversion"

This reverts commit 64d0b87ef0.
This commit is contained in:
Yan Da
2022-06-18 14:57:41 +08:00
parent 64d0b87ef0
commit 53cf93ce6a
3 changed files with 27 additions and 26 deletions

View File

@@ -90,8 +90,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
: ConversionTarget(context) {
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
@@ -103,18 +104,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
triton::TritonDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
auto isLegal = [](Value v) -> bool {
Type type = v.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (tensorType.getEncoding())
return true;
return false;
}
return true;
};
if (llvm::all_of(op->getOperands(), isLegal) &&
llvm::all_of(op->getResults(), isLegal))
if (typeConverter.isLegal(op))
return true;
return false;
});
@@ -127,6 +117,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
return true;
// // TODO: we should delete this
// if (this->typeConverter.isLegal(dotOp))
// return true;
return false;
});