register conversion in triton-opt

This commit is contained in:
Yan Da
2022-06-07 19:33:51 +08:00
parent 0e11435448
commit 560e29229b
4 changed files with 13 additions and 1 deletions

View File

@@ -4,6 +4,8 @@
#include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/Passes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h" #include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h" #include "mlir/Support/MlirOptMain.h"
@@ -13,6 +15,7 @@ int main(int argc, char **argv) {
mlir::registerAllPasses(); mlir::registerAllPasses();
mlir::registerTritonPasses(); mlir::registerTritonPasses();
mlir::registerTritonGPUPasses(); mlir::registerTritonGPUPasses();
mlir::triton::registerConvertTritonToTritonGPUPass();
// TODO: register Triton & TritonGPU passes // TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry; mlir::DialectRegistry registry;

View File

@@ -16,6 +16,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
"mlir::scf::SCFDialect", "mlir::scf::SCFDialect",
"mlir::triton::TritonDialect", "mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"]; "mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
} }
#endif #endif

View File

@@ -11,7 +11,7 @@ template <typename T> class OperationPass;
namespace triton{ namespace triton{
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(); createConvertTritonToTritonGPUPass(int numWarps = 4);
} }
} // namespace mlir } // namespace mlir

View File

@@ -53,6 +53,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
let results = (outs TT_Type:$result); let results = (outs TT_Type:$result);
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)"; let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
} }
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU. // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.