@@ -31,27 +31,28 @@ static bool findAndReplace(std::string &str, const std::string &begin,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
|
||||||
// LLVM version in use may not officially support target hardware
|
// LLVM version in use may not officially support target hardware.
|
||||||
int maxNNVMCC = 75;
|
// Supported versions for LLVM 14 are here:
|
||||||
|
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
|
||||||
|
int maxPTX = std::min(75, version);
|
||||||
|
int maxCC = std::min(86, cc);
|
||||||
// options
|
// options
|
||||||
auto options = llvm::cl::getRegisteredOptions();
|
auto options = llvm::cl::getRegisteredOptions();
|
||||||
auto *shortPtr =
|
auto *shortPtr =
|
||||||
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
|
||||||
assert(shortPtr);
|
assert(shortPtr);
|
||||||
shortPtr->setValue(true);
|
shortPtr->setValue(true);
|
||||||
// compute capability
|
std::string sm = "sm_" + std::to_string(maxCC);
|
||||||
std::string sm = "sm_" + std::to_string(cc);
|
|
||||||
// max PTX version
|
// max PTX version
|
||||||
int ptxMajor = version / 10;
|
int ptxMajor = maxPTX / 10;
|
||||||
int ptxMinor = version % 10;
|
int ptxMinor = maxPTX % 10;
|
||||||
// create
|
// create
|
||||||
llvm::SmallVector<char, 0> buffer;
|
llvm::SmallVector<char, 0> buffer;
|
||||||
std::string triple = "nvptx64-nvidia-cuda";
|
std::string triple = "nvptx64-nvidia-cuda";
|
||||||
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
|
std::string proc = "sm_" + std::to_string(maxCC);
|
||||||
std::string layout = "";
|
std::string layout = "";
|
||||||
std::string features = "";
|
std::string features = "";
|
||||||
// std::string features = "+ptx" + std::to_string(std::min(ptx,
|
// std::string features = "+ptx" + std::to_string(maxPTX);
|
||||||
// max_nvvm_ptx));
|
|
||||||
initLLVM();
|
initLLVM();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
|
@@ -963,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
|
|||||||
'''
|
'''
|
||||||
assert isinstance(cuda_version, str)
|
assert isinstance(cuda_version, str)
|
||||||
major, minor = map(int, cuda_version.split('.'))
|
major, minor = map(int, cuda_version.split('.'))
|
||||||
version = major * 1000 + minor * 10
|
if major == 12:
|
||||||
if version >= 11040:
|
return 80 + minor
|
||||||
return 74
|
if major == 11:
|
||||||
if version >= 11030:
|
return 70 + minor
|
||||||
return 73
|
if major == 10:
|
||||||
if version >= 11020:
|
return 63 + minor
|
||||||
return 72
|
|
||||||
if version >= 11010:
|
|
||||||
return 71
|
|
||||||
if version >= 11000:
|
|
||||||
return 70
|
|
||||||
if version >= 10020:
|
|
||||||
return 65
|
|
||||||
if version >= 10010:
|
|
||||||
return 64
|
|
||||||
if version >= 10000:
|
|
||||||
return 63
|
|
||||||
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user