[TRITONGPU] Added template for Triton -> TritonGPU conversion

This commit is contained in:
Phil Tillet
2022-04-30 14:31:18 -07:00
parent 2239ac1998
commit 2c6a213131
16 changed files with 146 additions and 2 deletions

View File

@@ -12,9 +12,9 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "llvm/IR/Module.h"
@@ -1337,6 +1337,9 @@ void init_triton_ir(py::module &&m) {
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
})
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
})
;
}

View File

@@ -1307,6 +1307,15 @@ class JITFunction:
raise CompilationError(self.src, node) from e
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
def compile_ttir_to_llir(self, mod, ctx):
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_convert_triton_to_tritongpu_pass()
pm.run(mod)
return mod
def __getitem__(self, grid):