[BACKEND] Support optional mask in TritonGPUToLLVM (#80)
Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <Python.h>
|
||||
#include <cctype>
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -99,6 +100,14 @@ long pow2_divisor(long N) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool getBoolEnv(const std::string &env) {
|
||||
const char *s = std::getenv(env.c_str());
|
||||
std::string str(s ? s : "");
|
||||
std::transform(str.begin(), str.end(), str.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return (str == "on" || str == "true" || str == "1");
|
||||
}
|
||||
|
||||
// Returns something like "int16", whether dtype is a torch.dtype or
|
||||
// triton.language.dtype.
|
||||
std::string dtype_cache_key_part(const py::object &dtype) {
|
||||
@@ -1589,6 +1598,21 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
.def("enable_debug",
|
||||
[](mlir::PassManager &self) {
|
||||
auto printingFlags = mlir::OpPrintingFlags();
|
||||
printingFlags.elideLargeElementsAttrs(16);
|
||||
self.enableIRPrinting(
|
||||
/*shouldPrintBeforePass=*/nullptr,
|
||||
/*shouldPrintAfterPass=*/
|
||||
[](mlir::Pass *pass, mlir::Operation *) {
|
||||
return getBoolEnv("MLIR_ENABLE_DUMP");
|
||||
},
|
||||
/*printModuleScope=*/false,
|
||||
/*printAfterOnlyOnChange=*/true,
|
||||
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(),
|
||||
printingFlags);
|
||||
})
|
||||
.def("run",
|
||||
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
|
Reference in New Issue
Block a user