Merge triton-mlir
branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
2
lib/Target/CMakeLists.txt
Normal file
2
lib/Target/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(PTX)
|
12
lib/Target/LLVMIR/CMakeLists.txt
Normal file
12
lib/Target/LLVMIR/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_mlir_translation_library(TritonLLVMIR
|
||||
LLVMIRTranslation.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
)
|
237
lib/Target/LLVMIR/LLVMIRTranslation.cpp
Normal file
237
lib/Target/LLVMIR/LLVMIRTranslation.cpp
Normal file
@@ -0,0 +1,237 @@
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Describes NVVM Metadata. It is used to record the nvvm related meta
|
||||
// information from mlir module.
|
||||
struct NVVMMetadata {
|
||||
int maxntidx{-1};
|
||||
bool is_kernel{};
|
||||
// Free to extend with other information.
|
||||
};
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
if (metadata.maxntidx > 0) {
|
||||
auto i32_ty = llvm::IntegerType::get(ctx, 32);
|
||||
auto warps =
|
||||
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
|
||||
|
||||
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
|
||||
llvm::MDString::get(ctx, "maxntidx"),
|
||||
llvm::ValueAsMetadata::get(warps)};
|
||||
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||
->addOperand(llvm::MDNode::get(ctx, md_args));
|
||||
}
|
||||
|
||||
if (metadata.is_kernel) {
|
||||
llvm::Metadata *md_args[] = {
|
||||
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
||||
llvm::ValueAsMetadata::get(
|
||||
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||
->addOperand(llvm::MDNode::get(ctx, md_args));
|
||||
}
|
||||
}
|
||||
|
||||
void extractNVVMMetadata(mlir::ModuleOp module,
|
||||
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
|
||||
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
|
||||
NVVMMetadata meta;
|
||||
|
||||
bool hasMetadata{};
|
||||
|
||||
// maxntid
|
||||
if (op->hasAttr("nvvm.maxntid")) {
|
||||
auto attr = op->getAttr("nvvm.maxntid");
|
||||
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
|
||||
hasMetadata = true;
|
||||
}
|
||||
|
||||
// kernel
|
||||
if (op->hasAttr("nvvm.kernel")) {
|
||||
meta.is_kernel = true;
|
||||
hasMetadata = true;
|
||||
}
|
||||
|
||||
if (hasMetadata)
|
||||
dic->try_emplace(op.getNameAttr().strref(), std::move(meta));
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
||||
auto context = module->getContext();
|
||||
DialectRegistry registry;
|
||||
mlir::registerLLVMDialectTranslation(registry);
|
||||
mlir::registerNVVMDialectTranslation(registry);
|
||||
context->appendDialectRegistry(registry);
|
||||
|
||||
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
|
||||
extractNVVMMetadata(module, &nvvmMetadata);
|
||||
|
||||
auto llvmModule = mlir::translateModuleToLLVMIR(module, *llvmContext);
|
||||
if (!llvmModule) {
|
||||
llvm::errs() << "Failed to emit LLVM IR\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||
/*targetMachine=*/nullptr);
|
||||
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
amendLLVMFunc(&func, it->second);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability) {
|
||||
mlir::PassManager pm(module->getContext());
|
||||
applyPassManagerCLOptions(pm);
|
||||
auto printingFlags = mlir::OpPrintingFlags();
|
||||
printingFlags.elideLargeElementsAttrs(16);
|
||||
pm.enableIRPrinting(
|
||||
/*shouldPrintBeforePass=*/nullptr,
|
||||
/*shouldPrintAfterPass=*/
|
||||
[](mlir::Pass *pass, mlir::Operation *) {
|
||||
return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
|
||||
},
|
||||
/*printModuleScope=*/false,
|
||||
/*printAfterOnlyOnChange=*/true,
|
||||
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
||||
|
||||
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
|
||||
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
|
||||
pm.addPass(mlir::createSymbolDCEPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
llvm::errs() << "Pass execution failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> externLibs;
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
module.walk([&](LLVM::LLVMFuncOp func) {
|
||||
if (func.isExternal())
|
||||
funcs.push_back(func);
|
||||
});
|
||||
|
||||
for (auto &func : funcs) {
|
||||
if (func.getOperation()->hasAttr("libname")) {
|
||||
auto name =
|
||||
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
||||
auto path =
|
||||
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
||||
if (name) {
|
||||
std::string lib_name = name.str();
|
||||
externLibs[lib_name] = path.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
||||
auto dict = module.getOperation()
|
||||
->getAttr("triton_gpu.externs")
|
||||
.dyn_cast<DictionaryAttr>();
|
||||
for (auto &attr : dict) {
|
||||
externLibs[attr.getName().strref().trim().str()] =
|
||||
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
||||
}
|
||||
}
|
||||
|
||||
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llvm::SMDiagnostic err;
|
||||
for (auto &lib : externLibs) {
|
||||
if (linkExternLib(*llvmir, lib.second))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return llvmir;
|
||||
}
|
||||
|
||||
void addExternalLibs(mlir::ModuleOp &module,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths) {
|
||||
if (names.empty() || names.size() != paths.size())
|
||||
return;
|
||||
|
||||
llvm::SmallVector<NamedAttribute, 2> attrs;
|
||||
|
||||
for (size_t i = 0; i < names.size(); ++i) {
|
||||
auto name = StringAttr::get(module->getContext(), names[i]);
|
||||
auto path = StringAttr::get(module->getContext(), paths[i]);
|
||||
NamedAttribute attr(name, path);
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
|
||||
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
|
||||
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
||||
}
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
|
||||
llvm::SMDiagnostic err;
|
||||
auto &ctx = module.getContext();
|
||||
|
||||
auto extMod = llvm::parseIRFile(path, err, ctx);
|
||||
if (!extMod) {
|
||||
llvm::errs() << "Failed to load " << path;
|
||||
return true;
|
||||
}
|
||||
|
||||
extMod->setTargetTriple(module.getTargetTriple());
|
||||
extMod->setDataLayout(module.getDataLayout());
|
||||
|
||||
if (llvm::Linker::linkModules(module, std::move(extMod),
|
||||
llvm::Linker::Flags::LinkOnlyNeeded)) {
|
||||
llvm::errs() << "Failed to link " << path;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
9
lib/Target/PTX/CMakeLists.txt
Normal file
9
lib/Target/PTX/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_mlir_translation_library(TritonPTX
|
||||
PTXTranslation.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonLLVMIR
|
||||
)
|
144
lib/Target/PTX/PTXTranslation.cpp
Normal file
144
lib/Target/PTX/PTXTranslation.cpp
Normal file
@@ -0,0 +1,144 @@
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/MC/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include <filesystem>
|
||||
|
||||
namespace triton {
|
||||
|
||||
static void initLLVM() {
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
}
|
||||
|
||||
static bool findAndReplace(std::string &str, const std::string &begin,
|
||||
const std::string &end, const std::string &target) {
|
||||
size_t startReplace = str.find(begin);
|
||||
if (startReplace == std::string::npos)
|
||||
return false;
|
||||
size_t endReplace = str.find(end, startReplace);
|
||||
if (endReplace == std::string::npos)
|
||||
return false;
|
||||
str.replace(startReplace, endReplace + 1 - startReplace, target);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void linkExternal(llvm::Module &module) {
|
||||
bool hasExternal = false;
|
||||
for (auto &func : module) {
|
||||
if (func.hasExternalLinkage()) {
|
||||
hasExternal = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasExternal) {
|
||||
namespace fs = std::filesystem;
|
||||
// [triton root dir]/python/triton/language/libdevice.10.bc
|
||||
static const fs::path libdevice = fs::path(__FILE__)
|
||||
.parent_path()
|
||||
.parent_path()
|
||||
.parent_path()
|
||||
.parent_path() /
|
||||
"python" / "triton" / "language" /
|
||||
"libdevice.10.bc";
|
||||
if (mlir::triton::linkExternLib(module, libdevice.string()))
|
||||
llvm::errs() << "link failed for: " << libdevice.string();
|
||||
|
||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||
// this will enable fast math path in libdevice
|
||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||
// sqrt.approx.ftz.f32
|
||||
auto &ctx = module.getContext();
|
||||
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
||||
llvm::Metadata *mdFour =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
||||
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
|
||||
llvm::Metadata *mdOne =
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
||||
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
||||
module.addModuleFlag(reflect);
|
||||
}
|
||||
}
|
||||
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||
linkExternal(module);
|
||||
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int maxNNVMCC = 75;
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
auto *shortPtr =
|
||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||
assert(shortPtr);
|
||||
shortPtr->setValue(true);
|
||||
// compute capability
|
||||
std::string sm = "sm_" + std::to_string(cc);
|
||||
// max PTX version
|
||||
int ptxMajor = version / 10;
|
||||
int ptxMinor = version % 10;
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
std::string triple = "nvptx64-nvidia-cuda";
|
||||
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
|
||||
std::string layout = "";
|
||||
std::string features = "";
|
||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
||||
// max_nvvm_ptx));
|
||||
initLLVM();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(module);
|
||||
// module->print(llvm::outs(), nullptr);
|
||||
|
||||
// create machine
|
||||
module.setTargetTriple(triple);
|
||||
std::string error;
|
||||
auto target =
|
||||
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
|
||||
llvm::TargetOptions opt;
|
||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||
opt.UnsafeFPMath = false;
|
||||
opt.NoInfsFPMath = false;
|
||||
opt.NoNaNsFPMath = true;
|
||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||
// set data layout
|
||||
if (layout.empty())
|
||||
module.setDataLayout(machine->createDataLayout());
|
||||
else
|
||||
module.setDataLayout(layout);
|
||||
// emit machine code
|
||||
for (llvm::Function &f : module.functions())
|
||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
// emit
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr,
|
||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||
pass.run(module);
|
||||
|
||||
// post-process
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
findAndReplace(result, ".version", "\n",
|
||||
".version " + std::to_string(ptxMajor) + "." +
|
||||
std::to_string(ptxMinor) + "\n");
|
||||
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
|
||||
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
|
||||
;
|
||||
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
|
||||
;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace triton
|
Reference in New Issue
Block a user