[DRIVER] More robust support of unsupported CUDA version (#179)
This commit is contained in:
@@ -211,16 +211,27 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup table from CUDA driver version code (the value filled in by cuDriverGetVersion) to PTX version code; a value such as 73 is later split into ptx / 10 and ptx % 10.
static std::map<int, int> vptx = {
|
//static std::map<int, int> vptx = {
|
||||||
{10000, 63},
|
// {10000, 63},
|
||||||
{10010, 64},
|
// {10010, 64},
|
||||||
{10020, 65},
|
// {10020, 65},
|
||||||
{11000, 70},
|
// {11000, 70},
|
||||||
{11010, 71},
|
// {11010, 71},
|
||||||
{11020, 72},
|
// {11020, 72},
|
||||||
{11030, 73},
|
// {11030, 73},
|
||||||
{11040, 73}
|
// {11040, 73}
|
||||||
};
|
//};
|
||||||
|
|
||||||
|
/// Selects the PTX version code to target for a given CUDA driver version.
///
/// @param version  Driver version code compared against thresholds such as
///                 11020 (the value filled in by cuDriverGetVersion).
/// @return         PTX version code encoded as 10 * major + minor (e.g. 72),
///                 which callers split with `/ 10` and `% 10`.
/// @throws std::runtime_error if `version` is below 10000 (CUDA 10).
///
/// Any version at or above the newest known threshold maps to the newest
/// known PTX version, so a driver released after this table was written does
/// not cause a failure; versions between two thresholds map to the lower one.
int vptx(int version){
  struct threshold {
    int min_version;  // smallest driver version code this entry applies to
    int ptx;          // PTX version code used from that driver version on
  };
  // Ordered from newest to oldest: the first entry whose minimum is satisfied wins.
  static constexpr threshold table[] = {
    {11030, 73},
    {11020, 72},
    {11010, 71},
    {11000, 70},
    {10020, 65},
    {10010, 64},
    {10000, 63},
  };
  for(const threshold& entry : table){
    if(version >= entry.min_version)
      return entry.ptx;
  }
  throw std::runtime_error("Triton requires CUDA 10+");
}
|
||||||
|
|
||||||
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
||||||
// LLVM version in use may not officially support target hardware
|
// LLVM version in use may not officially support target hardware
|
||||||
@@ -237,12 +248,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
|||||||
// driver version
|
// driver version
|
||||||
int version;
|
int version;
|
||||||
dispatch::cuDriverGetVersion(&version);
|
dispatch::cuDriverGetVersion(&version);
|
||||||
int major = version / 1000;
|
int ptx = vptx(version);
|
||||||
int minor = (version - major*1000) / 10;
|
|
||||||
if(major < 10)
|
|
||||||
throw std::runtime_error("Triton requires CUDA 10+");
|
|
||||||
// PTX version
|
|
||||||
int ptx = version > 11040 ? 73 : vptx.at(version);
|
|
||||||
int ptx_major = ptx / 10;
|
int ptx_major = ptx / 10;
|
||||||
int ptx_minor = ptx % 10;
|
int ptx_minor = ptx % 10;
|
||||||
// create
|
// create
|
||||||
|
Reference in New Issue
Block a user