More on TritonGPU conversion

This commit is contained in:
Yan Da
2022-05-02 21:51:00 +08:00
parent 1428185c9c
commit 75d32e2442
7 changed files with 114 additions and 18 deletions

View File

@@ -11,7 +11,12 @@ using namespace mlir;
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numThreads)
: context(context), numThreads(numThreads) {
addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
// TODO: how does MLIR pick the right conversion?
addConversion([](Type type) { return type; });
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
MLIRContext *context = this->context;
int numThreads = this->numThreads;
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
Type elementType = tensorType.getElementType();
int64_t rank = tensorType.getRank();
@@ -45,15 +50,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
: ConversionTarget(context) {
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context), typeConverter(typeConverter) {
addLegalDialect<triton::TritonDialect,
arith::ArithmeticDialect,
StandardOpsDialect,
scf::SCFDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
// // We have requirements for the data layouts
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {