[BACKEND] Extracting numWarps from tritonGPU module (#39)
This commit is contained in:
@@ -37,12 +37,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -10,10 +10,16 @@ template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps = 4);
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
|
||||
}
|
||||
// Create the pass with numWarps passed from cl::opt.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user