[DRIVER] Added (slow) support for CUDA11 and Ampere

This commit is contained in:
Philippe Tillet
2020-11-15 03:27:35 -05:00
parent baa858aa74
commit f42b04d925

View File

@@ -223,31 +223,49 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
return true; return true;
} }
static std::map<int, int> vptx = {
{10000, 63},
{10010, 64},
{10020, 65},
{11000, 70},
{11010, 71}
};
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) { std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
// options // LLVM version in use may not officially support target hardware
auto options = llvm::cl::getRegisteredOptions(); int max_nvvm_cc = 75;
// for(auto& opt: options) int max_nvvm_ptx = 64;
// std::cout << opt.getKey().str() << std::endl; // options
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]); auto options = llvm::cl::getRegisteredOptions();
assert(short_ptr); auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
short_ptr->setValue(true); assert(short_ptr);
// compute capability short_ptr->setValue(true);
auto cc = ((driver::cu_device*)device)->compute_capability(); // compute capability
std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second); auto _cc = ((driver::cu_device*)device)->compute_capability();
// create int cc = _cc.first*10 + _cc.second;
llvm::SmallVector<char, 0> buffer; cc = std::min(cc, max_nvvm_cc);
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly); std::string sm = "sm_" + std::to_string(cc);
std::string result(buffer.begin(), buffer.end()); // driver version
int version; int version;
dispatch::cuDriverGetVersion(&version); dispatch::cuDriverGetVersion(&version);
int major = version / 1000; int major = version / 1000;
// int minor = (version - major*1000) / 10; int minor = (version - major*1000) / 10;
if(major < 10) if(major < 10)
throw std::runtime_error("Triton requires CUDA 10+"); throw std::runtime_error("Triton requires CUDA 10+");
find_and_replace(result, ".version", "\n", ".version 6.4\n"); // PTX version
while(find_and_replace(result, "\t// begin inline asm", "\n", "")); int ptx = vptx.at(version);
while(find_and_replace(result, "\t// end inline asm", "\n", "")); ptx = std::min(ptx, max_nvvm_ptx);
return result; int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
return result;
} }