[DRIVER] More robust support of unsupported CUDA version (#179)

This commit is contained in:
Philippe Tillet
2021-08-02 09:06:55 -07:00
committed by GitHub
parent b7cdf670c3
commit e8031fe61f

View File

@@ -211,16 +211,27 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
return true; return true;
} }
static std::map<int, int> vptx = { //static std::map<int, int> vptx = {
{10000, 63}, // {10000, 63},
{10010, 64}, // {10010, 64},
{10020, 65}, // {10020, 65},
{11000, 70}, // {11000, 70},
{11010, 71}, // {11010, 71},
{11020, 72}, // {11020, 72},
{11030, 73}, // {11030, 73},
{11040, 73} // {11040, 73}
}; //};
int vptx(int version){
if(version >= 11030) return 73;
if(version >= 11020) return 72;
if(version >= 11010) return 71;
if(version >= 11000) return 70;
if(version >= 10020) return 65;
if(version >= 10010) return 64;
if(version >= 10000) return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
// LLVM version in use may not officially support target hardware // LLVM version in use may not officially support target hardware
@@ -237,12 +248,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
// driver version // driver version
int version; int version;
dispatch::cuDriverGetVersion(&version); dispatch::cuDriverGetVersion(&version);
int major = version / 1000; int ptx = vptx(version);
int minor = (version - major*1000) / 10;
if(major < 10)
throw std::runtime_error("Triton requires CUDA 10+");
// PTX version
int ptx = version > 11040 ? 73 : vptx.at(version);
int ptx_major = ptx / 10; int ptx_major = ptx / 10;
int ptx_minor = ptx % 10; int ptx_minor = ptx % 10;
// create // create