diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
index 32c99df95..2642ffd07 100644
--- a/lib/Target/PTX/PTXTranslation.cpp
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -31,27 +31,28 @@ static bool findAndReplace(std::string &str, const std::string &begin,
 }
 
 std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
-  // LLVM version in use may not officially support target hardware
-  int maxNNVMCC = 75;
+  // LLVM version in use may not officially support target hardware.
+  // Supported versions for LLVM 14 are here:
+  // https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
+  int maxPTX = std::min(75, version);
+  int maxCC = std::min(86, cc);
   // options
   auto options = llvm::cl::getRegisteredOptions();
   auto *shortPtr =
       static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
   assert(shortPtr);
   shortPtr->setValue(true);
-  // compute capability
-  std::string sm = "sm_" + std::to_string(cc);
+  std::string sm = "sm_" + std::to_string(maxCC);
   // max PTX version
-  int ptxMajor = version / 10;
-  int ptxMinor = version % 10;
+  int ptxMajor = maxPTX / 10;
+  int ptxMinor = maxPTX % 10;
   // create
   llvm::SmallVector<char, 0> buffer;
   std::string triple = "nvptx64-nvidia-cuda";
-  std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
+  std::string proc = "sm_" + std::to_string(maxCC);
   std::string layout = "";
   std::string features = "";
-  // std::string features = "+ptx" + std::to_string(std::min(ptx,
-  // max_nvvm_ptx));
+  // std::string features = "+ptx" + std::to_string(maxPTX);
   initLLVM();
   // verify and store llvm
   llvm::legacy::PassManager pm;
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index aa2760e64..57833215d 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -963,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
     '''
     assert isinstance(cuda_version, str)
     major, minor = map(int, cuda_version.split('.'))
-    version = major * 1000 + minor * 10
-    if version >= 11040:
-        return 74
-    if version >= 11030:
-        return 73
-    if version >= 11020:
-        return 72
-    if version >= 11010:
-        return 71
-    if version >= 11000:
-        return 70
-    if version >= 10020:
-        return 65
-    if version >= 10010:
-        return 64
-    if version >= 10000:
-        return 63
+    if major == 12:
+        return 80 + minor
+    if major == 11:
+        return 70 + minor
+    if major == 10:
+        return 63 + minor
     raise RuntimeError("Triton only support CUDA 10.0 or higher")
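
Note (not part of the patch): a minimal standalone sketch of the new CUDA-to-PTX version mapping from python/triton/compiler.py above, with a few spot-checks against the values the old if-ladder produced. The __main__ harness and the sample version strings are assumptions for illustration only.

    def ptx_get_version(cuda_version: str) -> int:
        # Map a CUDA toolkit version string to the PTX ISA version it ships:
        # CUDA 10.x -> 63 + x, CUDA 11.x -> 70 + x, CUDA 12.x -> 80 + x.
        major, minor = map(int, cuda_version.split('.'))
        if major == 12:
            return 80 + minor
        if major == 11:
            return 70 + minor
        if major == 10:
            return 63 + minor
        raise RuntimeError("Triton only support CUDA 10.0 or higher")

    if __name__ == "__main__":
        # Values the old if-ladder also produced:
        assert ptx_get_version("10.2") == 65
        assert ptx_get_version("11.0") == 70
        assert ptx_get_version("11.4") == 74
        # Newer toolkits the old ladder could not express (it capped at 74):
        assert ptx_get_version("11.8") == 78
        assert ptx_get_version("12.1") == 81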