[Triton-MLIR] Fix warnings and variable names (#898)

We have been seeing the following error message for a while:

> NO target: Unable to find target for this triple (no targets are
registered)

Seems that it's not necessary to setup the target triple at that point,
so we can just take it out to get rid of the error message.

Variable names have been changed to the camel style.
This commit is contained in:
Keren Zhou
2022-11-20 22:25:27 -08:00
committed by GitHub
parent 85cccfb81f
commit 04b852e031
2 changed files with 45 additions and 52 deletions

View File

@@ -99,9 +99,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return nullptr;
}
// Initialize LLVM targets.
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);

View File

@@ -12,29 +12,28 @@
namespace triton {
static void init_llvm() {
static void initLLVM() {
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
}
static bool find_and_replace(std::string &str, const std::string &begin,
const std::string &end,
const std::string &target) {
size_t start_replace = str.find(begin);
if (start_replace == std::string::npos)
static bool findAndReplace(std::string &str, const std::string &begin,
const std::string &end, const std::string &target) {
size_t startReplace = str.find(begin);
if (startReplace == std::string::npos)
return false;
size_t end_replace = str.find(end, start_replace);
if (end_replace == std::string::npos)
size_t endReplace = str.find(end, startReplace);
if (endReplace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
str.replace(startReplace, endReplace + 1 - startReplace, target);
return true;
}
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : *module) {
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
@@ -51,16 +50,14 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(*module, libdevice.string()))
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
}
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
{
auto &ctx = module->getContext();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
@@ -68,81 +65,80 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module->addModuleFlag(reflect);
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
linkExternal(module);
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
// int max_nvvm_ptx = 74;
int maxNNVMCC = 75;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr =
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
assert(shortPtr);
shortPtr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(capability);
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
int ptxMajor = version / 10;
int ptxMinor = version % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
init_llvm();
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
pm.run(module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
module.setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
module.setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
module.setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
pass.run(module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n",
".version " + std::to_string(ptx_major) + "." +
std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
findAndReplace(result, ".version", "\n",
".version " + std::to_string(ptxMajor) + "." +
std::to_string(ptxMinor) + "\n");
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto ptxCode = llir_to_ptx(&module, cc, version);
return ptxCode;
}
} // namespace triton