From 2c6a21313156b6038c6fd6d3d5a11ef5e2439011 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 30 Apr 2022 14:31:18 -0700 Subject: [PATCH] [TRITONGPU] Added template for Triton -> TritonGPU conversion --- CMakeLists.txt | 1 + include/CMakeLists.txt | 2 +- include/triton/CMakeLists.txt | 2 ++ include/triton/Conversion/CMakeLists.txt | 4 +++ include/triton/Conversion/Passes.h | 19 ++++++++++ include/triton/Conversion/Passes.td | 14 ++++++++ .../TritonToTritonGPU/TritonToTritonGPU.h | 20 +++++++++++ lib/CMakeLists.txt | 1 + lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/PassDetail.h | 15 ++++++++ .../TritonToTritonGPU/CMakeLists.txt | 18 ++++++++++ .../TritonToTritonGPU/TritonToTritonGPU.cpp | 36 +++++++++++++++++++ python/src/triton.cc | 5 ++- python/triton/code_gen.py | 9 +++++ rewrite-test/jit/vecadd.py | 1 + rewrite-test/{scf_tests.py => test_scf.py} | 0 16 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 include/triton/CMakeLists.txt create mode 100644 include/triton/Conversion/CMakeLists.txt create mode 100644 include/triton/Conversion/Passes.h create mode 100644 include/triton/Conversion/Passes.td create mode 100644 include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h create mode 100644 lib/Conversion/CMakeLists.txt create mode 100644 lib/Conversion/PassDetail.h create mode 100644 lib/Conversion/TritonToTritonGPU/CMakeLists.txt create mode 100644 lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp rename rewrite-test/{scf_tests.py => test_scf.py} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index efb47f7df..7055418c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,6 +181,7 @@ target_link_libraries(triton TritonIR TritonTransforms TritonDriver + TritonToTritonGPU # optimizations MLIRPass MLIRTransforms diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 9da937000..109c292fe 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(triton/Dialect) 
+add_subdirectory(triton) diff --git a/include/triton/CMakeLists.txt b/include/triton/CMakeLists.txt new file mode 100644 index 000000000..b5f579c1a --- /dev/null +++ b/include/triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) \ No newline at end of file diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 000000000..e25b0da63 --- /dev/null +++ b/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,4 @@ + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TritonConversionPassIncGen) \ No newline at end of file diff --git a/include/triton/Conversion/Passes.h b/include/triton/Conversion/Passes.h new file mode 100644 index 000000000..125551f5c --- /dev/null +++ b/include/triton/Conversion/Passes.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" + +namespace mlir +{ +namespace triton +{ + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/Passes.h.inc" + + +} // namespace triton +} // namespace mlir + + +#endif \ No newline at end of file diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td new file mode 100644 index 000000000..2e10e0a09 --- /dev/null +++ b/include/triton/Conversion/Passes.td @@ -0,0 +1,14 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; +} + +#endif diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h new file mode 100644 index 
000000000..9a68f826d --- /dev/null +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h @@ -0,0 +1,20 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_ +#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_ + +#include <memory> + +namespace mlir{ + +class ModuleOp; +template <typename T> class OperationPass; + +namespace triton{ + +std::unique_ptr<OperationPass<ModuleOp>> +createConvertTritonToTritonGPUPass(); + +} +} // namespace mlir + + +#endif \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 44cd839e6..480882592 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,3 +1,4 @@ # add_subdirectory(codegen) add_subdirectory(driver) +add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..5cbcea5da --- /dev/null +++ b/lib/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonToTritonGPU) diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h new file mode 100644 index 000000000..e772f41b6 --- /dev/null +++ b/lib/Conversion/PassDetail.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSDETAIL_H +#define TRITON_CONVERSION_PASSDETAIL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir{ +namespace triton{ + +#define GEN_PASS_CLASSES +#include "triton/Conversion/Passes.h.inc" + +} +} + +#endif diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..382b2c977 --- /dev/null +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(TritonToTritonGPU + TritonToTritonGPU.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU + + DEPENDS + TritonConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TritonIR + TritonGPUIR +) \ No newline at end of file diff --git 
a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp new file mode 100644 index 000000000..a0e93f48c --- /dev/null +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -0,0 +1,36 @@ +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" +#include "../PassDetail.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class ConvertTritonToTritonGPU: + public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> { + +public: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + // LLVM15 + // registry.insert() + // registry.insert() + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + std::cout << "Converting" << std::endl; + } +}; + +} + +std::unique_ptr<OperationPass<ModuleOp>> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index 97964394f..82142dfb5 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -12,9 +12,9 @@ #include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" - #include "triton/Dialect/Triton/Transforms/Passes.h" #include "llvm/IR/Module.h" @@ -1337,6 +1337,9 @@ void init_triton_ir(py::module &&m) { .def("add_triton_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createCombineOpsPass()); }) + .def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::createConvertTritonToTritonGPUPass()); + }) ; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index c231666bf..45399c5ba 
100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1307,6 +1307,15 @@ class JITFunction: raise CompilationError(self.src, node) from e # FIXME: now we need to return context, otherwise it will be deleted return generator.module, context + + def compile_ttir_to_llir(self, mod, ctx): + pm = _triton.ir.pass_manager(ctx) + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_convert_triton_to_tritongpu_pass() + pm.run(mod) + return mod def __getitem__(self, grid): diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py index 34a7fc4f1..c659bf742 100644 --- a/rewrite-test/jit/vecadd.py +++ b/rewrite-test/jit/vecadd.py @@ -40,4 +40,5 @@ z = torch.empty_like(x) # print(add_kernel[(1,)].kernel.compile_to_ttir()) # print(add_kernel.annotations) mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,)) +mod = add_kernel.compile_ttir_to_llir(mod, ctx) mod.dump() diff --git a/rewrite-test/scf_tests.py b/rewrite-test/test_scf.py similarity index 100% rename from rewrite-test/scf_tests.py rename to rewrite-test/test_scf.py