[Triton-MLIR] Fix warnings and variable names (#898)
We have been seeing the following error message for a while: > NO target: Unable to find target for this triple (no targets are registered) Seems that it's not necessary to setup the target triple at that point, so we can just take it out to get rid of the error message. Variable names have been changed to the camel style.
This commit is contained in:
@@ -99,9 +99,6 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize LLVM targets.
|
|
||||||
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
|
|
||||||
|
|
||||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||||
/*optLevel=*/3, /*sizeLevel=*/0,
|
/*optLevel=*/3, /*sizeLevel=*/0,
|
||||||
/*targetMachine=*/nullptr);
|
/*targetMachine=*/nullptr);
|
||||||
|
@@ -12,29 +12,28 @@
|
|||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
static void init_llvm() {
|
static void initLLVM() {
|
||||||
LLVMInitializeNVPTXTargetInfo();
|
LLVMInitializeNVPTXTargetInfo();
|
||||||
LLVMInitializeNVPTXTarget();
|
LLVMInitializeNVPTXTarget();
|
||||||
LLVMInitializeNVPTXTargetMC();
|
LLVMInitializeNVPTXTargetMC();
|
||||||
LLVMInitializeNVPTXAsmPrinter();
|
LLVMInitializeNVPTXAsmPrinter();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool find_and_replace(std::string &str, const std::string &begin,
|
static bool findAndReplace(std::string &str, const std::string &begin,
|
||||||
const std::string &end,
|
const std::string &end, const std::string &target) {
|
||||||
const std::string &target) {
|
size_t startReplace = str.find(begin);
|
||||||
size_t start_replace = str.find(begin);
|
if (startReplace == std::string::npos)
|
||||||
if (start_replace == std::string::npos)
|
|
||||||
return false;
|
return false;
|
||||||
size_t end_replace = str.find(end, start_replace);
|
size_t endReplace = str.find(end, startReplace);
|
||||||
if (end_replace == std::string::npos)
|
if (endReplace == std::string::npos)
|
||||||
return false;
|
return false;
|
||||||
str.replace(start_replace, end_replace + 1 - start_replace, target);
|
str.replace(startReplace, endReplace + 1 - startReplace, target);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
static void linkExternal(llvm::Module &module) {
|
||||||
bool hasExternal = false;
|
bool hasExternal = false;
|
||||||
for (auto &func : *module) {
|
for (auto &func : module) {
|
||||||
if (func.hasExternalLinkage()) {
|
if (func.hasExternalLinkage()) {
|
||||||
hasExternal = true;
|
hasExternal = true;
|
||||||
break;
|
break;
|
||||||
@@ -51,16 +50,14 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
|||||||
.parent_path() /
|
.parent_path() /
|
||||||
"python" / "triton" / "language" /
|
"python" / "triton" / "language" /
|
||||||
"libdevice.10.bc";
|
"libdevice.10.bc";
|
||||||
if (mlir::triton::linkExternLib(*module, libdevice.string()))
|
if (mlir::triton::linkExternLib(module, libdevice.string()))
|
||||||
llvm::errs() << "link failed for: " << libdevice.string();
|
llvm::errs() << "link failed for: " << libdevice.string();
|
||||||
}
|
|
||||||
|
|
||||||
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
|
||||||
// this will enable fast math path in libdevice
|
// this will enable fast math path in libdevice
|
||||||
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
|
||||||
// sqrt.approx.ftz.f32
|
// sqrt.approx.ftz.f32
|
||||||
{
|
auto &ctx = module.getContext();
|
||||||
auto &ctx = module->getContext();
|
|
||||||
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
|
||||||
llvm::Metadata *mdFour =
|
llvm::Metadata *mdFour =
|
||||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
|
||||||
@@ -68,81 +65,80 @@ static std::string llir_to_ptx(llvm::Module *module, int capability, int ptx) {
|
|||||||
llvm::Metadata *mdOne =
|
llvm::Metadata *mdOne =
|
||||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
|
||||||
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
|
||||||
module->addModuleFlag(reflect);
|
module.addModuleFlag(reflect);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||||
|
linkExternal(module);
|
||||||
|
|
||||||
// LLVM version in use may not officially support target hardware
|
// LLVM version in use may not officially support target hardware
|
||||||
int max_nvvm_cc = 75;
|
int maxNNVMCC = 75;
|
||||||
// int max_nvvm_ptx = 74;
|
|
||||||
// options
|
// options
|
||||||
auto options = llvm::cl::getRegisteredOptions();
|
auto options = llvm::cl::getRegisteredOptions();
|
||||||
auto *short_ptr =
|
auto *shortPtr =
|
||||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||||
assert(short_ptr);
|
assert(shortPtr);
|
||||||
short_ptr->setValue(true);
|
shortPtr->setValue(true);
|
||||||
// compute capability
|
// compute capability
|
||||||
std::string sm = "sm_" + std::to_string(capability);
|
std::string sm = "sm_" + std::to_string(cc);
|
||||||
// max PTX version
|
// max PTX version
|
||||||
int ptx_major = ptx / 10;
|
int ptxMajor = version / 10;
|
||||||
int ptx_minor = ptx % 10;
|
int ptxMinor = version % 10;
|
||||||
// create
|
// create
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
std::string triple = "nvptx64-nvidia-cuda";
|
std::string triple = "nvptx64-nvidia-cuda";
|
||||||
std::string proc = "sm_" + std::to_string(std::min(capability, max_nvvm_cc));
|
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
|
||||||
std::string layout = "";
|
std::string layout = "";
|
||||||
std::string features = "";
|
std::string features = "";
|
||||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
||||||
// max_nvvm_ptx));
|
// max_nvvm_ptx));
|
||||||
init_llvm();
|
initLLVM();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
pm.add(llvm::createVerifierPass());
|
pm.add(llvm::createVerifierPass());
|
||||||
pm.run(*module);
|
pm.run(module);
|
||||||
// module->print(llvm::outs(), nullptr);
|
// module->print(llvm::outs(), nullptr);
|
||||||
|
|
||||||
// create machine
|
// create machine
|
||||||
module->setTargetTriple(triple);
|
module.setTargetTriple(triple);
|
||||||
std::string error;
|
std::string error;
|
||||||
auto target =
|
auto target =
|
||||||
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
|
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
|
||||||
llvm::TargetOptions opt;
|
llvm::TargetOptions opt;
|
||||||
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
|
||||||
opt.UnsafeFPMath = false;
|
opt.UnsafeFPMath = false;
|
||||||
opt.NoInfsFPMath = false;
|
opt.NoInfsFPMath = false;
|
||||||
opt.NoNaNsFPMath = true;
|
opt.NoNaNsFPMath = true;
|
||||||
llvm::TargetMachine *machine = target->createTargetMachine(
|
llvm::TargetMachine *machine = target->createTargetMachine(
|
||||||
module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
|
||||||
llvm::None, llvm::CodeGenOpt::Aggressive);
|
llvm::None, llvm::CodeGenOpt::Aggressive);
|
||||||
// set data layout
|
// set data layout
|
||||||
if (layout.empty())
|
if (layout.empty())
|
||||||
module->setDataLayout(machine->createDataLayout());
|
module.setDataLayout(machine->createDataLayout());
|
||||||
else
|
else
|
||||||
module->setDataLayout(layout);
|
module.setDataLayout(layout);
|
||||||
// emit machine code
|
// emit machine code
|
||||||
for (llvm::Function &f : module->functions())
|
for (llvm::Function &f : module.functions())
|
||||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||||
llvm::legacy::PassManager pass;
|
llvm::legacy::PassManager pass;
|
||||||
llvm::raw_svector_ostream stream(buffer);
|
llvm::raw_svector_ostream stream(buffer);
|
||||||
// emit
|
// emit
|
||||||
machine->addPassesToEmitFile(pass, stream, nullptr,
|
machine->addPassesToEmitFile(pass, stream, nullptr,
|
||||||
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
llvm::CodeGenFileType::CGFT_AssemblyFile);
|
||||||
pass.run(*module);
|
pass.run(module);
|
||||||
|
|
||||||
// post-process
|
// post-process
|
||||||
std::string result(buffer.begin(), buffer.end());
|
std::string result(buffer.begin(), buffer.end());
|
||||||
find_and_replace(result, ".version", "\n",
|
findAndReplace(result, ".version", "\n",
|
||||||
".version " + std::to_string(ptx_major) + "." +
|
".version " + std::to_string(ptxMajor) + "." +
|
||||||
std::to_string(ptx_minor) + "\n");
|
std::to_string(ptxMinor) + "\n");
|
||||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
|
||||||
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
|
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
|
||||||
;
|
;
|
||||||
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
|
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
|
||||||
;
|
;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
|
||||||
auto ptxCode = llir_to_ptx(&module, cc, version);
|
|
||||||
return ptxCode;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
Reference in New Issue
Block a user