[BACKEND] Support optional mask in TritonGPUToLLVM (#80)

Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
goostavz
2022-08-25 08:51:37 +08:00
committed by GitHub
parent 84aa7d025a
commit bedbf221c0
5 changed files with 69 additions and 17 deletions

View File

@@ -26,6 +26,7 @@
#include "llvm/Support/raw_ostream.h"
#include <Python.h>
#include <cctype>
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -99,6 +100,14 @@ long pow2_divisor(long N) {
return 1;
}
bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object &dtype) {
@@ -1589,6 +1598,21 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("enable_debug",
[](mlir::PassManager &self) {
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
self.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(),
printingFlags);
})
.def("run",
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));

View File

@@ -27,7 +27,13 @@ def test_vecadd_no_scf():
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
# TODO: add this to CI, to make sure the the compilation flow is at lease OK
# before we have GPU machines for CI.
# ptx, shem_size, kernel_name = triton.compile(kernel,
# "*fp32,i32,*fp32,i32,*fp32,i32",
# constants={"BLOCK_SIZE_N": 256},
# num_warps=NUM_WARPS,
# device=0, output="ptx")
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()

View File

@@ -772,6 +772,7 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_canonicalizer_pass()
pm.run(mod)
@@ -780,6 +781,7 @@ def optimize_triton_ir(mod):
def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
@@ -791,6 +793,7 @@ def make_tritongpu_ir(mod, num_warps):
def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_cse_pass()