diff --git a/lib/driver/module.cc b/lib/driver/module.cc index ff8f37665..83ef7e832 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -211,16 +211,27 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s return true; } -static std::map vptx = { - {10000, 63}, - {10010, 64}, - {10020, 65}, - {11000, 70}, - {11010, 71}, - {11020, 72}, - {11030, 73}, - {11040, 73} -}; +//static std::map vptx = { +// {10000, 63}, +// {10010, 64}, +// {10020, 65}, +// {11000, 70}, +// {11010, 71}, +// {11020, 72}, +// {11030, 73}, +// {11040, 73} +//}; + +int vptx(int version){ + if(version >= 11030) return 73; + if(version >= 11020) return 72; + if(version >= 11010) return 71; + if(version >= 11000) return 70; + if(version >= 10020) return 65; + if(version >= 10010) return 64; + if(version >= 10000) return 63; + throw std::runtime_error("Triton requires CUDA 10+"); +} std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) { // LLVM version in use may not officially support target hardware @@ -237,12 +248,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* // driver version int version; dispatch::cuDriverGetVersion(&version); - int major = version / 1000; - int minor = (version - major*1000) / 10; - if(major < 10) - throw std::runtime_error("Triton requires CUDA 10+"); - // PTX version - int ptx = version > 11040 ? 73 : vptx.at(version); + int ptx = vptx(version); int ptx_major = ptx / 10; int ptx_minor = ptx % 10; // create