From f42b04d92535e8ada507de331f5d55ad4e276b21 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 15 Nov 2020 03:27:35 -0500 Subject: [PATCH] [DRIVER] Added (slow) support for CUDA11 and Ampere --- lib/driver/module.cc | 66 ++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/lib/driver/module.cc b/lib/driver/module.cc index c3eeb5f27..1f9b97a92 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -223,31 +223,49 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s return true; } +static std::map vptx = { + {10000, 63}, + {10010, 64}, + {10020, 65}, + {11000, 70}, + {11010, 71} +}; + std::string cu_module::compile_llvm_module(std::unique_ptr module, driver::device* device) { - // options - auto options = llvm::cl::getRegisteredOptions(); -// for(auto& opt: options) -// std::cout << opt.getKey().str() << std::endl; - auto* short_ptr = static_cast*>(options["nvptx-short-ptr"]); - assert(short_ptr); - short_ptr->setValue(true); - // compute capability - auto cc = ((driver::cu_device*)device)->compute_capability(); - std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second); - // create - llvm::SmallVector buffer; - module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly); - std::string result(buffer.begin(), buffer.end()); - int version; - dispatch::cuDriverGetVersion(&version); - int major = version / 1000; -// int minor = (version - major*1000) / 10; - if(major < 10) - throw std::runtime_error("Triton requires CUDA 10+"); - find_and_replace(result, ".version", "\n", ".version 6.4\n"); - while(find_and_replace(result, "\t// begin inline asm", "\n", "")); - while(find_and_replace(result, "\t// end inline asm", "\n", "")); - return result; + // LLVM version in use may not officially support target hardware + int max_nvvm_cc = 75; + int max_nvvm_ptx = 64; + // options + auto options = llvm::cl::getRegisteredOptions(); + auto* short_ptr = static_cast*>(options["nvptx-short-ptr"]); + assert(short_ptr); + short_ptr->setValue(true); + // compute capability + auto _cc = ((driver::cu_device*)device)->compute_capability(); + int cc = _cc.first*10 + _cc.second; + cc = std::min(cc, max_nvvm_cc); + std::string sm = "sm_" + std::to_string(cc); + // driver version + int version; + dispatch::cuDriverGetVersion(&version); + int major = version / 1000; + int minor = (version - major*1000) / 10; + if(major < 10) + throw std::runtime_error("Triton requires CUDA 10+"); + // PTX version + int ptx = vptx.at(version); + ptx = std::min(ptx, max_nvvm_ptx); + int ptx_major = ptx / 10; + int ptx_minor = ptx % 10; + // create + llvm::SmallVector buffer; + module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly); + std::string result(buffer.begin(), buffer.end()); + find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); + find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); + while(find_and_replace(result, "\t// begin inline asm", "\n", "")); + while(find_and_replace(result, "\t// end inline asm", "\n", "")); + return result; }