register conversion in triton-opt
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "triton/Conversion/Passes.h"
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
@@ -13,6 +15,7 @@ int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTritonPasses();
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
|
@@ -16,6 +16,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -11,7 +11,7 @@ template <typename T> class OperationPass;
|
||||
namespace triton{
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass();
|
||||
createConvertTritonToTritonGPUPass(int numWarps = 4);
|
||||
|
||||
}
|
||||
} // namespace mlir
|
||||
|
@@ -53,6 +53,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
|
||||
|
||||
// result needs to be of shared layout
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||
|
Reference in New Issue
Block a user