[TRITONGPU] Added template for Triton -> TritonGPU conversion
This commit is contained in:
@@ -181,6 +181,7 @@ target_link_libraries(triton
|
||||
TritonIR
|
||||
TritonTransforms
|
||||
TritonDriver
|
||||
TritonToTritonGPU
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
|
@@ -1 +1 @@
|
||||
add_subdirectory(triton/Dialect)
|
||||
add_subdirectory(triton)
|
||||
|
2
include/triton/CMakeLists.txt
Normal file
2
include/triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
4
include/triton/Conversion/CMakeLists.txt
Normal file
4
include/triton/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(TritonConversionPassIncGen)
|
19
include/triton/Conversion/Passes.h
Normal file
19
include/triton/Conversion/Passes.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES_H
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
|
||||
namespace mlir
|
||||
{
|
||||
namespace triton
|
||||
{
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
#endif
|
14
include/triton/Conversion/Passes.td
Normal file
14
include/triton/Conversion/Passes.td
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES
|
||||
#define TRITON_CONVERSION_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
|
||||
let summary = "Convert Triton to TritonGPU";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
|
||||
}
|
||||
|
||||
#endif
|
@@ -0,0 +1,20 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir{
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton{
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass();
|
||||
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
#endif
|
@@ -1,3 +1,4 @@
|
||||
# add_subdirectory(codegen)
|
||||
add_subdirectory(driver)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
|
1
lib/Conversion/CMakeLists.txt
Normal file
1
lib/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(TritonToTritonGPU)
|
15
lib/Conversion/PassDetail.h
Normal file
15
lib/Conversion/PassDetail.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace triton{
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
18
lib/Conversion/TritonToTritonGPU/CMakeLists.txt
Normal file
18
lib/Conversion/TritonToTritonGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
add_mlir_conversion_library(TritonToTritonGPU
|
||||
TritonToTritonGPU.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU
|
||||
|
||||
DEPENDS
|
||||
TritonConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
)
|
36
lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Normal file
36
lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
class ConvertTritonToTritonGPU:
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
registry.insert<StandardOpsDialect>();
|
||||
registry.insert<scf::SCFDialect>();
|
||||
// LLVM15
|
||||
// registry.insert<cf::ControlFlowDialect>()
|
||||
// registry.insert<func::FuncDialect>()
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
std::cout << "Converting" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass() {
|
||||
return std::make_unique<::ConvertTritonToTritonGPU>();
|
||||
}
|
@@ -12,9 +12,9 @@
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/IR/Module.h"
|
||||
@@ -1337,6 +1337,9 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createCombineOpsPass());
|
||||
})
|
||||
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
||||
})
|
||||
;
|
||||
}
|
||||
|
||||
|
@@ -1308,6 +1308,15 @@ class JITFunction:
|
||||
# FIXME: now we need to return context, otherwise it will be deleted
|
||||
return generator.module, context
|
||||
|
||||
def compile_ttir_to_llir(self, mod, ctx):
|
||||
pm = _triton.ir.pass_manager(ctx)
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_convert_triton_to_tritongpu_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return Launcher(self._init_kernel(), grid)
|
||||
|
@@ -40,4 +40,5 @@ z = torch.empty_like(x)
|
||||
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
||||
# print(add_kernel.annotations)
|
||||
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
|
||||
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
|
||||
mod.dump()
|
||||
|
Reference in New Issue
Block a user