Init TritonGPU to LLVM dialect conversion (#32)

* add toLLVM pass

* update num-warps setting in mlir
This commit is contained in:
Yan Chunwei
2022-08-04 10:15:45 +08:00
committed by GitHub
parent 3236642e8f
commit b988bae813
9 changed files with 328 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
namespace mlir {

View File

@@ -24,4 +24,25 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
];
}
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::StandardOpsDialect",
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
#endif

View File

@@ -0,0 +1,29 @@
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
mlir::LLVMTypeConverter &typeConverter;
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
} // namespace triton
} // namespace mlir
#endif