More on SCF conversion

This commit is contained in:
Yan Da
2022-05-04 21:50:32 +08:00
parent a96fe07e1c
commit 26c59e4718
2 changed files with 101 additions and 46 deletions

View File

@@ -2,6 +2,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
@@ -45,6 +46,25 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
context, threadTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding);
});
// materailizations
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
llvm_unreachable("Not implemented");
return llvm::None;
});
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
return llvm::None;
});
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
return llvm::None;
});
}
//
@@ -53,25 +73,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::gpu::TritonGPUDialect>([&](Operation *op) {
addDynamicallyLegalDialect<arith::ArithmeticDialect,
triton::TritonDialect,
triton::gpu::TritonGPUDialect,
StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;