[FRONTEND] Update PTX/SM support for LLVM14 (PR #1038 redux) (#1039)

=
This commit is contained in:
Connor Baker
2023-01-09 13:31:55 -05:00
committed by GitHub
parent 733301ff31
commit c20215dad1
2 changed files with 16 additions and 26 deletions

View File

@@ -31,27 +31,28 @@ static bool findAndReplace(std::string &str, const std::string &begin,
} }
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
// LLVM version in use may not officially support target hardware // LLVM version in use may not officially support target hardware.
int maxNNVMCC = 75; // Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
int maxPTX = std::min(75, version);
int maxCC = std::min(86, cc);
// options // options
auto options = llvm::cl::getRegisteredOptions(); auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr = auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]); static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr); assert(shortPtr);
shortPtr->setValue(true); shortPtr->setValue(true);
// compute capability std::string sm = "sm_" + std::to_string(maxCC);
std::string sm = "sm_" + std::to_string(cc);
// max PTX version // max PTX version
int ptxMajor = version / 10; int ptxMajor = maxPTX / 10;
int ptxMinor = version % 10; int ptxMinor = maxPTX % 10;
// create // create
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda"; std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC)); std::string proc = "sm_" + std::to_string(maxCC);
std::string layout = ""; std::string layout = "";
std::string features = ""; std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, // std::string features = "+ptx" + std::to_string(maxPTX);
// max_nvvm_ptx));
initLLVM(); initLLVM();
// verify and store llvm // verify and store llvm
llvm::legacy::PassManager pm; llvm::legacy::PassManager pm;

View File

@@ -963,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
''' '''
assert isinstance(cuda_version, str) assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.')) major, minor = map(int, cuda_version.split('.'))
version = major * 1000 + minor * 10 if major == 12:
if version >= 11040: return 80 + minor
return 74 if major == 11:
if version >= 11030: return 70 + minor
return 73 if major == 10:
if version >= 11020: return 63 + minor
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
raise RuntimeError("Triton only support CUDA 10.0 or higher") raise RuntimeError("Triton only support CUDA 10.0 or higher")