[BACKEND] Extracting numWarps from tritonGPU module (#39)

This commit is contained in:
Yan Chunwei
2022-08-09 00:40:20 +08:00
committed by GitHub
parent 920723cf3d
commit 83ef74f248
8 changed files with 66 additions and 17 deletions

View File

@@ -37,12 +37,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
#endif

View File

@@ -10,10 +10,16 @@ template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps = 4);
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
}
// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps);
} // namespace triton
} // namespace mlir
#endif
#endif