[BACKEND] add triton-translate to translate mlir to llvmir or PTX code (#37)
This commit is contained in:
@@ -5,10 +5,10 @@ add_subdirectory(FileCheck)
|
|||||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||||
|
|
||||||
add_llvm_executable(triton-opt triton-opt.cpp)
|
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
|
||||||
|
|
||||||
# TODO: what's this?
|
# TODO: what's this?
|
||||||
# llvm_update_compile_flags(triton-opt)
|
llvm_update_compile_flags(triton-opt)
|
||||||
target_link_libraries(triton-opt PRIVATE
|
target_link_libraries(triton-opt PRIVATE
|
||||||
TritonAnalysis
|
TritonAnalysis
|
||||||
TritonTransforms
|
TritonTransforms
|
||||||
@@ -24,3 +24,35 @@ target_link_libraries(triton-opt PRIVATE
|
|||||||
)
|
)
|
||||||
|
|
||||||
mlir_check_all_link_libraries(triton-opt)
|
mlir_check_all_link_libraries(triton-opt)
|
||||||
|
|
||||||
|
|
||||||
|
add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
|
||||||
|
llvm_update_compile_flags(triton-translate)
|
||||||
|
target_link_libraries(triton-translate PRIVATE
|
||||||
|
TritonAnalysis
|
||||||
|
TritonTransforms
|
||||||
|
TritonGPUTransforms
|
||||||
|
TritonLLVMIR
|
||||||
|
TritonDriver
|
||||||
|
${dialect_libs}
|
||||||
|
${conversion_libs}
|
||||||
|
# tests
|
||||||
|
TritonTestAnalysis
|
||||||
|
|
||||||
|
LLVMCore
|
||||||
|
LLVMSupport
|
||||||
|
LLVMOption
|
||||||
|
LLVMCodeGen
|
||||||
|
LLVMAsmParser
|
||||||
|
|
||||||
|
# MLIR core
|
||||||
|
MLIROptLib
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRSupport
|
||||||
|
MLIRTransforms
|
||||||
|
MLIRExecutionEngine
|
||||||
|
MLIRTransformUtils
|
||||||
|
MLIRLLVMToLLVMIRTranslation
|
||||||
|
)
|
||||||
|
mlir_check_all_link_libraries(triton-translate)
|
||||||
|
141
bin/triton-translate.cpp
Normal file
141
bin/triton-translate.cpp
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
|
#include "mlir/IR/AsmState.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/Parser.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Support/FileUtilities.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||||
|
#include "mlir/Target/LLVMIR/Export.h"
|
||||||
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||||
|
#include "triton/driver/llvm.h"
|
||||||
|
#include "llvm/IR/LLVMContext.h"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/InitLLVM.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||||
|
MLIRContext &context) {
|
||||||
|
std::string errorMessage;
|
||||||
|
auto input = openInputFile(inputFilename, &errorMessage);
|
||||||
|
if (!input) {
|
||||||
|
llvm::errs() << errorMessage << "\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::DialectRegistry registry;
|
||||||
|
registry
|
||||||
|
.insert<TritonDialect, gpu::TritonGPUDialect, arith::ArithmeticDialect,
|
||||||
|
StandardOpsDialect, scf::SCFDialect>();
|
||||||
|
|
||||||
|
context.appendDialectRegistry(registry);
|
||||||
|
|
||||||
|
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
|
||||||
|
-> OwningOpRef<ModuleOp> {
|
||||||
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
||||||
|
|
||||||
|
context.loadAllAvailableDialects();
|
||||||
|
context.allowUnregisteredDialects();
|
||||||
|
|
||||||
|
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
||||||
|
if (!module) {
|
||||||
|
llvm::errs() << "Parse MLIR file failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return module;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto module = processBuffer(std::move(input));
|
||||||
|
if (!module) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::PassManager pm(module->getContext());
|
||||||
|
applyPassManagerCLOptions(pm);
|
||||||
|
|
||||||
|
pm.addPass(createConvertTritonGPUToLLVMPass());
|
||||||
|
|
||||||
|
if (failed(pm.run(module->getOperation()))) {
|
||||||
|
llvm::errs() << "Pass execution failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return module;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||||
|
llvm::StringRef toolName) {
|
||||||
|
static llvm::cl::opt<std::string> inputFilename(
|
||||||
|
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
||||||
|
llvm::cl::init("-"));
|
||||||
|
|
||||||
|
static llvm::cl::opt<std::string> outputFilename(
|
||||||
|
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||||
|
llvm::cl::init("-"));
|
||||||
|
|
||||||
|
static llvm::cl::opt<std::string> targetKind(
|
||||||
|
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
|
||||||
|
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
|
||||||
|
|
||||||
|
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
|
||||||
|
llvm::cl::init(80));
|
||||||
|
|
||||||
|
static llvm::cl::opt<int> ptxVersion(
|
||||||
|
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
|
||||||
|
|
||||||
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
|
registerAsmPrinterCLOptions();
|
||||||
|
|
||||||
|
registerMLIRContextCLOptions();
|
||||||
|
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
|
||||||
|
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
auto module = loadMLIRModule(inputFilename, context);
|
||||||
|
if (!module) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string errorMessage;
|
||||||
|
auto output = openOutputFile(outputFilename, &errorMessage);
|
||||||
|
if (!output) {
|
||||||
|
llvm::errs() << errorMessage << "\n";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::LLVMContext llvmContext;
|
||||||
|
auto llvmir = TranslateLLVMToLLVMIR(&llvmContext, *module);
|
||||||
|
if (!llvmir) {
|
||||||
|
llvm::errs() << "Translate to LLVM IR failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (targetKind == "llvmir")
|
||||||
|
llvm::outs() << *llvmir << '\n';
|
||||||
|
else if (targetKind == "ptx")
|
||||||
|
llvm::outs() << ::triton::driver::llir_to_ptx(
|
||||||
|
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
return failed(mlir::triton::tritonTranslateMain(
|
||||||
|
argc, argv, "Triton Translate Testing Tool."));
|
||||||
|
}
|
@@ -20,6 +20,15 @@ public:
|
|||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
|
// Names for identifying different NVVM annotations. It is used as attribute
|
||||||
|
// names in MLIR modules. Refer to
|
||||||
|
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
|
||||||
|
// the full list.
|
||||||
|
struct NVVMMetadataField {
|
||||||
|
static constexpr char MaxNTid[] = "nvvm.maxntid";
|
||||||
|
static constexpr char Kernel[] = "nvvm.kernel";
|
||||||
|
};
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
24
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
24
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||||
|
#define TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Module;
|
||||||
|
class LLVMContext;
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class ModuleOp;
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||||
|
std::unique_ptr<llvm::Module>
|
||||||
|
TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H
|
@@ -3,3 +3,4 @@ add_subdirectory(driver)
|
|||||||
add_subdirectory(Analysis)
|
add_subdirectory(Analysis)
|
||||||
add_subdirectory(Conversion)
|
add_subdirectory(Conversion)
|
||||||
add_subdirectory(Dialect)
|
add_subdirectory(Dialect)
|
||||||
|
add_subdirectory(Target)
|
||||||
|
@@ -155,7 +155,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|||||||
auto i32 = IntegerType::get(ctx, 32);
|
auto i32 = IntegerType::get(ctx, 32);
|
||||||
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
||||||
// for `nvvm.annotation` metadata.
|
// for `nvvm.annotation` metadata.
|
||||||
newFuncOp->setAttr("nvvm.maxntidx",
|
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
|
||||||
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
||||||
|
|
||||||
rewriter.eraseOp(funcOp);
|
rewriter.eraseOp(funcOp);
|
||||||
|
1
lib/Target/CMakeLists.txt
Normal file
1
lib/Target/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
add_subdirectory(LLVMIR)
|
12
lib/Target/LLVMIR/CMakeLists.txt
Normal file
12
lib/Target/LLVMIR/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
add_mlir_translation_library(TritonLLVMIR
|
||||||
|
LLVMIRTranslation.cpp
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRLLVMIR
|
||||||
|
MLIRSupport
|
||||||
|
MLIRTargetLLVMIRExport
|
||||||
|
)
|
118
lib/Target/LLVMIR/LLVMIRTranslation.cpp
Normal file
118
lib/Target/LLVMIR/LLVMIRTranslation.cpp
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||||
|
#include "mlir/Target/LLVMIR/Export.h"
|
||||||
|
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||||
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
|
#include "triton/driver/llvm.h"
|
||||||
|
#include "llvm/IR/Constants.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
// Describes NVVM Metadata. It is used to record the nvvm related meta
|
||||||
|
// information from mlir module.
|
||||||
|
struct NVVMMetadata {
|
||||||
|
int maxntidx{-1};
|
||||||
|
bool is_kernel{};
|
||||||
|
// Free to extend with other information.
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add the nvvm related metadata to LLVM IR.
|
||||||
|
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
||||||
|
auto *module = func->getParent();
|
||||||
|
auto &ctx = func->getContext();
|
||||||
|
|
||||||
|
if (metadata.maxntidx > 0) {
|
||||||
|
auto i32_ty = llvm::IntegerType::get(ctx, 32);
|
||||||
|
auto warps =
|
||||||
|
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
|
||||||
|
|
||||||
|
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
|
||||||
|
llvm::MDString::get(ctx, "maxntidx"),
|
||||||
|
llvm::ValueAsMetadata::get(warps)};
|
||||||
|
|
||||||
|
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||||
|
->addOperand(llvm::MDNode::get(ctx, md_args));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (metadata.is_kernel) {
|
||||||
|
llvm::Metadata *md_args[] = {
|
||||||
|
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
||||||
|
llvm::ValueAsMetadata::get(
|
||||||
|
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
||||||
|
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||||
|
->addOperand(llvm::MDNode::get(ctx, md_args));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void extractNVVMMetadata(mlir::ModuleOp module,
|
||||||
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
|
||||||
|
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
|
||||||
|
NVVMMetadata meta;
|
||||||
|
|
||||||
|
bool hasMetadata{};
|
||||||
|
|
||||||
|
// maxntid
|
||||||
|
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
|
||||||
|
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
|
||||||
|
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
||||||
|
hasMetadata = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// kernel
|
||||||
|
if (op->hasAttr(NVVMMetadataField::Kernel)) {
|
||||||
|
meta.is_kernel = true;
|
||||||
|
hasMetadata = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasMetadata)
|
||||||
|
dic->try_emplace(op.getNameAttr().strref(), std::move(meta));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<llvm::Module>
|
||||||
|
TranslateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||||
|
auto context = module->getContext();
|
||||||
|
DialectRegistry registry;
|
||||||
|
registerLLVMDialectTranslation(registry);
|
||||||
|
context->appendDialectRegistry(registry);
|
||||||
|
|
||||||
|
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
|
||||||
|
extractNVVMMetadata(module, &nvvmMetadata);
|
||||||
|
|
||||||
|
auto llvmModule = mlir::translateModuleToLLVMIR(module, *llvmContext);
|
||||||
|
if (!llvmModule) {
|
||||||
|
llvm::errs() << "Failed to emit LLVM IR\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize LLVM targets.
|
||||||
|
::triton::driver::init_llvm();
|
||||||
|
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
|
||||||
|
|
||||||
|
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||||
|
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||||
|
/*targetMachine=*/nullptr);
|
||||||
|
|
||||||
|
if (auto err = optPipeline(llvmModule.get())) {
|
||||||
|
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &func : llvmModule->functions()) {
|
||||||
|
auto it = nvvmMetadata.find(func.getName());
|
||||||
|
if (it != nvvmMetadata.end())
|
||||||
|
amendLLVMFunc(&func, it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvmModule;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
13
test/Target/tritongpu_to_llvmir.mlir
Normal file
13
test/Target/tritongpu_to_llvmir.mlir
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// RUN: triton-translate %s --target=llvmir | FileCheck %s
|
||||||
|
|
||||||
|
// == LLVM IR check begin ==
|
||||||
|
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||||
|
// CHECK: define void @test_empty_kernel
|
||||||
|
// CHECK: !nvvm.annotations
|
||||||
|
// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
|
||||||
|
|
||||||
|
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
11
test/Target/tritongpu_to_ptx.mlir
Normal file
11
test/Target/tritongpu_to_ptx.mlir
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// RUN: triton-translate %s --target=ptx --sm=80 --ptx-version=10000 | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: // Generated by LLVM NVPTX Back-End
|
||||||
|
// CHECK: .version 6.3
|
||||||
|
// CHECK: .target sm_80
|
||||||
|
// CHECK: .address_size 64
|
||||||
|
|
||||||
|
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
Reference in New Issue
Block a user