From 560e29229becc505c4848087e96e1bc9515c9b70 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Tue, 7 Jun 2022 19:33:51 +0800 Subject: [PATCH] register conversion in triton-opt --- bin/triton-opt.cpp | 3 +++ include/triton/Conversion/Passes.td | 6 ++++++ .../triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h | 2 +- include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td | 3 +++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 2a41739d7..9975ed63d 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -4,6 +4,8 @@ #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Conversion/Passes.h" + #include "mlir/IR/Dialect.h" #include "mlir/InitAllPasses.h" #include "mlir/Support/MlirOptMain.h" @@ -13,6 +15,7 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::registerTritonPasses(); mlir::registerTritonGPUPasses(); + mlir::triton::registerConvertTritonToTritonGPUPass(); // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index ca3c378f7..72c330ba0 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -16,6 +16,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO "mlir::scf::SCFDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps"> + ]; } #endif diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h index 9a68f826d..b21b6a1f1 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h @@ -11,7 +11,7 @@ template class OperationPass; namespace triton{ std::unique_ptr> -createConvertTritonToTritonGPUPass(); +createConvertTritonToTritonGPUPass(int numWarps = 4); } } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 4a5309ae2..bc2ea5ed3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -53,6 +53,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async", let results = (outs TT_Type:$result); let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)"; + + // result needs to be of shared layout + let verifier = [{ return ::verify(*this); }]; } // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.