Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

View File

@@ -0,0 +1,2 @@
add_subdirectory(LLVMIR)
add_subdirectory(PTX)

View File

@@ -0,0 +1,12 @@
add_mlir_translation_library(TritonLLVMIR
LLVMIRTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
)

View File

@@ -0,0 +1,237 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/Constants.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
namespace triton {
// Describes NVVM Metadata. It is used to record the nvvm related meta
// information from mlir module.
struct NVVMMetadata {
int maxntidx{-1};
bool is_kernel{};
// Free to extend with other information.
};
// Add the nvvm related metadata to LLVM IR.
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();
if (metadata.maxntidx > 0) {
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
llvm::MDString::get(ctx, "maxntidx"),
llvm::ValueAsMetadata::get(warps)};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
}
if (metadata.is_kernel) {
llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
}
}
void extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
NVVMMetadata meta;
bool hasMetadata{};
// maxntid
if (op->hasAttr("nvvm.maxntid")) {
auto attr = op->getAttr("nvvm.maxntid");
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
hasMetadata = true;
}
// kernel
if (op->hasAttr("nvvm.kernel")) {
meta.is_kernel = true;
hasMetadata = true;
}
if (hasMetadata)
dic->try_emplace(op.getNameAttr().strref(), std::move(meta));
}
}
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
context->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
extractNVVMMetadata(module, &nvvmMetadata);
auto llvmModule = mlir::translateModuleToLLVMIR(module, *llvmContext);
if (!llvmModule) {
llvm::errs() << "Failed to emit LLVM IR\n";
return nullptr;
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
if (auto err = optPipeline(llvmModule.get())) {
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
return nullptr;
}
for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
amendLLVMFunc(&func, it->second);
}
return llvmModule;
}
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
pm.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string lib_name = name.str();
externLibs[lib_name] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;
}
llvm::SMDiagnostic err;
for (auto &lib : externLibs) {
if (linkExternLib(*llvmir, lib.second))
return nullptr;
}
return llvmir;
}
void addExternalLibs(mlir::ModuleOp &module,
const std::vector<std::string> &names,
const std::vector<std::string> &paths) {
if (names.empty() || names.size() != paths.size())
return;
llvm::SmallVector<NamedAttribute, 2> attrs;
for (size_t i = 0; i < names.size(); ++i) {
auto name = StringAttr::get(module->getContext(), names[i]);
auto path = StringAttr::get(module->getContext(), paths[i]);
NamedAttribute attr(name, path);
attrs.push_back(attr);
}
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
module.getOperation()->setAttr("triton_gpu.externs", dict);
}
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
return false;
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,9 @@
add_mlir_translation_library(TritonPTX
PTXTranslation.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
TritonLLVMIR
)

View File

@@ -0,0 +1,144 @@
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include <filesystem>
namespace triton {
static void initLLVM() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
static bool findAndReplace(std::string &str, const std::string &begin,
const std::string &end, const std::string &target) {
size_t startReplace = str.find(begin);
if (startReplace == std::string::npos)
return false;
size_t endReplace = str.find(end, startReplace);
if (endReplace == std::string::npos)
return false;
str.replace(startReplace, endReplace + 1 - startReplace, target);
return true;
}
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
linkExternal(module);
// LLVM version in use may not officially support target hardware
int maxNNVMCC = 75;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = version / 10;
int ptxMinor = version % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(module);
// module->print(llvm::outs(), nullptr);
// create machine
module.setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
else
module.setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(module);
// post-process
std::string result(buffer.begin(), buffer.end());
findAndReplace(result, ".version", "\n",
".version " + std::to_string(ptxMajor) + "." +
std::to_string(ptxMinor) + "\n");
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
;
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
} // namespace triton