[DRIVER] Added (slow) support for CUDA11 and Ampere

This commit is contained in:
Philippe Tillet
2020-11-15 03:27:35 -05:00
parent baa858aa74
commit f42b04d925

View File

@@ -223,31 +223,49 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
return true;
}
// Lookup table from CUDA driver version to PTX ISA version.
//   key   : driver version as reported by cuDriverGetVersion, encoded as
//           1000 * major + 10 * minor (e.g. 11010 for CUDA 11.1)
//   value : PTX ISA version encoded as 10 * major + minor (e.g. 71 for PTX 7.1)
// The table is read-only, hence const.
static const std::map<int, int> vptx = {
  {10000, 63},
  {10010, 64},
  {10020, 65},
  {11000, 70},
  {11010, 71}
};
// Compiles `module` for the NVPTX64 target and returns the generated assembly text.
//
// module : LLVM module to compile; ownership is passed on to
//          module::compile_llvm_module.
// device : target device; it is cast to driver::cu_device to read its
//          compute capability.
//
// The compute capability and PTX version requested from LLVM are clamped to
// max_nvvm_cc / max_nvvm_ptx. The `.version` and `.target` directives of the
// generated text are then rewritten with those values, and the
// "// begin inline asm" / "// end inline asm" marker lines are removed.
//
// Throws std::runtime_error when the driver reports a CUDA major version
// below 10, or when no PTX version is known for the reported driver version.
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
  // LLVM version in use may not officially support target hardware
  const int max_nvvm_cc = 75;
  const int max_nvvm_ptx = 64;
  // options
  auto options = llvm::cl::getRegisteredOptions();
  auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
  assert(short_ptr);
  short_ptr->setValue(true);
  // compute capability, e.g. (8, 0) -> 80, clamped to what LLVM can target
  auto _cc = static_cast<driver::cu_device*>(device)->compute_capability();
  int cc = _cc.first*10 + _cc.second;
  cc = std::min(cc, max_nvvm_cc);
  std::string sm = "sm_" + std::to_string(cc);
  // driver version, encoded as 1000 * major + 10 * minor
  int version;
  dispatch::cuDriverGetVersion(&version);
  int major = version / 1000;
  if(major < 10)
    throw std::runtime_error("Triton requires CUDA 10+");
  // PTX version: take the entry of the closest known driver version that is
  // not newer than the installed one, so that driver releases missing from
  // the table do not make the lookup throw std::out_of_range.
  auto known = vptx.upper_bound(version);
  if(known == vptx.begin())
    throw std::runtime_error("Triton: no PTX version known for CUDA driver version " + std::to_string(version));
  --known;
  int ptx = std::min(known->second, max_nvvm_ptx);
  int ptx_major = ptx / 10;
  int ptx_minor = ptx % 10;
  // create
  llvm::SmallVector<char, 0> buffer;
  module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
  std::string result(buffer.begin(), buffer.end());
  // rewrite the directives so the emitted text carries the clamped values
  find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
  find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
  // drop every inline-asm marker line; each loop runs until no match is left
  while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
  while(find_and_replace(result, "\t// end inline asm", "\n", ""));
  return result;
}