[DRIVER] More robust support of unsupported CUDA version (#179)

This commit is contained in:
Philippe Tillet
2021-08-02 09:06:55 -07:00
committed by GitHub
parent b7cdf670c3
commit e8031fe61f

View File

@@ -211,16 +211,27 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
return true;
}
static std::map<int, int> vptx = {
{10000, 63},
{10010, 64},
{10020, 65},
{11000, 70},
{11010, 71},
{11020, 72},
{11030, 73},
{11040, 73}
};
//static std::map<int, int> vptx = {
// {10000, 63},
// {10010, 64},
// {10020, 65},
// {11000, 70},
// {11010, 71},
// {11020, 72},
// {11030, 73},
// {11040, 73}
//};
int vptx(int version){
if(version >= 11030) return 73;
if(version >= 11020) return 72;
if(version >= 11010) return 71;
if(version >= 11000) return 70;
if(version >= 10020) return 65;
if(version >= 10010) return 64;
if(version >= 10000) return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
// LLVM version in use may not officially support target hardware
@@ -237,12 +248,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
// driver version
int version;
dispatch::cuDriverGetVersion(&version);
int major = version / 1000;
int minor = (version - major*1000) / 10;
if(major < 10)
throw std::runtime_error("Triton requires CUDA 10+");
// PTX version
int ptx = version > 11040 ? 73 : vptx.at(version);
int ptx = vptx(version);
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create