[TRITONGPU] Added template for Triton -> TritonGPU conversion
This commit is contained in:
@@ -181,6 +181,7 @@ target_link_libraries(triton
|
|||||||
TritonIR
|
TritonIR
|
||||||
TritonTransforms
|
TritonTransforms
|
||||||
TritonDriver
|
TritonDriver
|
||||||
|
TritonToTritonGPU
|
||||||
# optimizations
|
# optimizations
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
|
@@ -1 +1 @@
|
|||||||
add_subdirectory(triton/Dialect)
|
add_subdirectory(triton)
|
||||||
|
2
include/triton/CMakeLists.txt
Normal file
2
include/triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
add_subdirectory(Conversion)
|
||||||
|
add_subdirectory(Dialect)
|
4
include/triton/Conversion/CMakeLists.txt
Normal file
4
include/triton/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(TritonConversionPassIncGen)
|
19
include/triton/Conversion/Passes.h
Normal file
19
include/triton/Conversion/Passes.h
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#ifndef TRITON_CONVERSION_PASSES_H
|
||||||
|
#define TRITON_CONVERSION_PASSES_H
|
||||||
|
|
||||||
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
|
||||||
|
namespace mlir
|
||||||
|
{
|
||||||
|
namespace triton
|
||||||
|
{
|
||||||
|
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "triton/Conversion/Passes.h.inc"
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
14
include/triton/Conversion/Passes.td
Normal file
14
include/triton/Conversion/Passes.td
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#ifndef TRITON_CONVERSION_PASSES
|
||||||
|
#define TRITON_CONVERSION_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
|
||||||
|
let summary = "Convert Triton to TritonGPU";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -0,0 +1,20 @@
|
|||||||
|
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||||
|
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir{
|
||||||
|
|
||||||
|
class ModuleOp;
|
||||||
|
template <typename T> class OperationPass;
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createConvertTritonToTritonGPUPass();
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
@@ -1,3 +1,4 @@
|
|||||||
# add_subdirectory(codegen)
|
# add_subdirectory(codegen)
|
||||||
add_subdirectory(driver)
|
add_subdirectory(driver)
|
||||||
|
add_subdirectory(Conversion)
|
||||||
add_subdirectory(Dialect)
|
add_subdirectory(Dialect)
|
||||||
|
1
lib/Conversion/CMakeLists.txt
Normal file
1
lib/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
add_subdirectory(TritonToTritonGPU)
|
15
lib/Conversion/PassDetail.h
Normal file
15
lib/Conversion/PassDetail.h
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
||||||
|
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir{
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "triton/Conversion/Passes.h.inc"
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
18
lib/Conversion/TritonToTritonGPU/CMakeLists.txt
Normal file
18
lib/Conversion/TritonToTritonGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
add_mlir_conversion_library(TritonToTritonGPU
|
||||||
|
TritonToTritonGPU.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TritonConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
TritonIR
|
||||||
|
TritonGPUIR
|
||||||
|
)
|
36
lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Normal file
36
lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::triton;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class ConvertTritonToTritonGPU:
|
||||||
|
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||||
|
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<arith::ArithmeticDialect>();
|
||||||
|
registry.insert<StandardOpsDialect>();
|
||||||
|
registry.insert<scf::SCFDialect>();
|
||||||
|
// LLVM15
|
||||||
|
// registry.insert<cf::ControlFlowDialect>()
|
||||||
|
// registry.insert<func::FuncDialect>()
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
std::cout << "Converting" << std::endl;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::triton::createConvertTritonToTritonGPUPass() {
|
||||||
|
return std::make_unique<::ConvertTritonToTritonGPU>();
|
||||||
|
}
|
@@ -12,9 +12,9 @@
|
|||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
|
||||||
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/IR/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
|
|
||||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
@@ -1337,6 +1337,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
|
.def("add_triton_combine_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::createCombineOpsPass());
|
self.addPass(mlir::triton::createCombineOpsPass());
|
||||||
})
|
})
|
||||||
|
.def("add_convert_triton_to_tritongpu_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
|
||||||
|
})
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1308,6 +1308,15 @@ class JITFunction:
|
|||||||
# FIXME: now we need to return context, otherwise it will be deleted
|
# FIXME: now we need to return context, otherwise it will be deleted
|
||||||
return generator.module, context
|
return generator.module, context
|
||||||
|
|
||||||
|
def compile_ttir_to_llir(self, mod, ctx):
|
||||||
|
pm = _triton.ir.pass_manager(ctx)
|
||||||
|
pm.add_inliner_pass()
|
||||||
|
pm.add_triton_combine_pass()
|
||||||
|
pm.add_canonicalizer_pass()
|
||||||
|
pm.add_convert_triton_to_tritongpu_pass()
|
||||||
|
pm.run(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, grid):
|
def __getitem__(self, grid):
|
||||||
return Launcher(self._init_kernel(), grid)
|
return Launcher(self._init_kernel(), grid)
|
||||||
|
@@ -40,4 +40,5 @@ z = torch.empty_like(x)
|
|||||||
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
# print(add_kernel[(1,)].kernel.compile_to_ttir())
|
||||||
# print(add_kernel.annotations)
|
# print(add_kernel.annotations)
|
||||||
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
|
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
|
||||||
|
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
|
||||||
mod.dump()
|
mod.dump()
|
||||||
|
Reference in New Issue
Block a user