[DRIVER] More robust support of unsupported CUDA version (#179)
This commit is contained in:
@@ -211,16 +211,27 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::map<int, int> vptx = {
|
||||
{10000, 63},
|
||||
{10010, 64},
|
||||
{10020, 65},
|
||||
{11000, 70},
|
||||
{11010, 71},
|
||||
{11020, 72},
|
||||
{11030, 73},
|
||||
{11040, 73}
|
||||
};
|
||||
//static std::map<int, int> vptx = {
|
||||
// {10000, 63},
|
||||
// {10010, 64},
|
||||
// {10020, 65},
|
||||
// {11000, 70},
|
||||
// {11010, 71},
|
||||
// {11020, 72},
|
||||
// {11030, 73},
|
||||
// {11040, 73}
|
||||
//};
|
||||
|
||||
int vptx(int version){
|
||||
if(version >= 11030) return 73;
|
||||
if(version >= 11020) return 72;
|
||||
if(version >= 11010) return 71;
|
||||
if(version >= 11000) return 70;
|
||||
if(version >= 10020) return 65;
|
||||
if(version >= 10010) return 64;
|
||||
if(version >= 10000) return 63;
|
||||
throw std::runtime_error("Triton requires CUDA 10+");
|
||||
}
|
||||
|
||||
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
||||
// LLVM version in use may not officially support target hardware
|
||||
@@ -237,12 +248,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
||||
// driver version
|
||||
int version;
|
||||
dispatch::cuDriverGetVersion(&version);
|
||||
int major = version / 1000;
|
||||
int minor = (version - major*1000) / 10;
|
||||
if(major < 10)
|
||||
throw std::runtime_error("Triton requires CUDA 10+");
|
||||
// PTX version
|
||||
int ptx = version > 11040 ? 73 : vptx.at(version);
|
||||
int ptx = vptx(version);
|
||||
int ptx_major = ptx / 10;
|
||||
int ptx_minor = ptx % 10;
|
||||
// create
|
||||
|
Reference in New Issue
Block a user