[TRITONGPU] Added template for Triton -> TritonGPU conversion

This commit is contained in:
Phil Tillet
2022-04-30 14:31:18 -07:00
parent 2239ac1998
commit 2c6a213131
16 changed files with 146 additions and 2 deletions

View File

@@ -181,6 +181,7 @@ target_link_libraries(triton
TritonIR
TritonTransforms
TritonDriver
TritonToTritonGPU
# optimizations
MLIRPass
MLIRTransforms

View File

@@ -1 +1 @@
add_subdirectory(triton/Dialect)
add_subdirectory(triton)

View File

@@ -0,0 +1,2 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TritonConversionPassIncGen)

View File

@@ -0,0 +1,19 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
namespace mlir
{
namespace triton
{
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,14 @@
#ifndef TRITON_CONVERSION_PASSES
#define TRITON_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
let summary = "Convert Triton to TritonGPU";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
}
#endif

View File

@@ -0,0 +1,20 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#include <memory>
namespace mlir{
class ModuleOp;
template <typename T> class OperationPass;
namespace triton{
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass();
}
} // namespace mlir
#endif

View File

@@ -1,3 +1,4 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@@ -0,0 +1 @@
add_subdirectory(TritonToTritonGPU)

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir{
namespace triton{
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
}
}
#endif

View File

@@ -0,0 +1,18 @@
add_mlir_conversion_library(TritonToTritonGPU
TritonToTritonGPU.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU
DEPENDS
TritonConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
TritonIR
TritonGPUIR
)

View File

@@ -0,0 +1,36 @@
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
using namespace mlir;
using namespace mlir::triton;
namespace {
class ConvertTritonToTritonGPU:
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<arith::ArithmeticDialect>();
registry.insert<StandardOpsDialect>();
registry.insert<scf::SCFDialect>();
// LLVM15
// registry.insert<cf::ControlFlowDialect>()
// registry.insert<func::FuncDialect>()
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
std::cout << "Converting" << std::endl;
}
};
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
return std::make_unique<::ConvertTritonToTritonGPU>();
}

View File

@@ -12,9 +12,9 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "llvm/IR/Module.h"
@@ -1337,6 +1337,9 @@ void init_triton_ir(py::module &&m) {
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
})
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
})
;
}

View File

@@ -1308,6 +1308,15 @@ class JITFunction:
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
def compile_ttir_to_llir(self, mod, ctx):
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_convert_triton_to_tritongpu_pass()
pm.run(mod)
return mod
def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid)

View File

@@ -40,4 +40,5 @@ z = torch.empty_like(x)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
# print(add_kernel.annotations)
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()