register conversion in triton-opt
This commit is contained in:
@@ -4,6 +4,8 @@
|
|||||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#include "triton/Conversion/Passes.h"
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/InitAllPasses.h"
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Support/MlirOptMain.h"
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
@@ -13,6 +15,7 @@ int main(int argc, char **argv) {
|
|||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
mlir::registerTritonPasses();
|
mlir::registerTritonPasses();
|
||||||
mlir::registerTritonGPUPasses();
|
mlir::registerTritonGPUPasses();
|
||||||
|
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||||
|
|
||||||
// TODO: register Triton & TritonGPU passes
|
// TODO: register Triton & TritonGPU passes
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
|
@@ -16,6 +16,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
|||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
"mlir::triton::TritonDialect",
|
"mlir::triton::TritonDialect",
|
||||||
"mlir::triton::gpu::TritonGPUDialect"];
|
"mlir::triton::gpu::TritonGPUDialect"];
|
||||||
|
|
||||||
|
let options = [
|
||||||
|
Option<"numWarps", "num-warps",
|
||||||
|
"int32_t", /*default*/"4",
|
||||||
|
"number of warps">
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -11,7 +11,7 @@ template <typename T> class OperationPass;
|
|||||||
namespace triton{
|
namespace triton{
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createConvertTritonToTritonGPUPass();
|
createConvertTritonToTritonGPUPass(int numWarps = 4);
|
||||||
|
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -53,6 +53,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
|||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
|
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($ptr) `->` type($result)";
|
||||||
|
|
||||||
|
// result needs to be of shared layout
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||||
|
Reference in New Issue
Block a user