[DRIVER] Added (slow) support for CUDA11 and Ampere
This commit is contained in:
@@ -223,31 +223,49 @@ static bool find_and_replace(std::string& str, const std::string& begin, const s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::map<int, int> vptx = {
|
||||||
|
{10000, 63},
|
||||||
|
{10010, 64},
|
||||||
|
{10020, 65},
|
||||||
|
{11000, 70},
|
||||||
|
{11010, 71}
|
||||||
|
};
|
||||||
|
|
||||||
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
||||||
// options
|
// LLVM version in use may not officially support target hardware
|
||||||
auto options = llvm::cl::getRegisteredOptions();
|
int max_nvvm_cc = 75;
|
||||||
// for(auto& opt: options)
|
int max_nvvm_ptx = 64;
|
||||||
// std::cout << opt.getKey().str() << std::endl;
|
// options
|
||||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
auto options = llvm::cl::getRegisteredOptions();
|
||||||
assert(short_ptr);
|
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||||
short_ptr->setValue(true);
|
assert(short_ptr);
|
||||||
// compute capability
|
short_ptr->setValue(true);
|
||||||
auto cc = ((driver::cu_device*)device)->compute_capability();
|
// compute capability
|
||||||
std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second);
|
auto _cc = ((driver::cu_device*)device)->compute_capability();
|
||||||
// create
|
int cc = _cc.first*10 + _cc.second;
|
||||||
llvm::SmallVector<char, 0> buffer;
|
cc = std::min(cc, max_nvvm_cc);
|
||||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
|
std::string sm = "sm_" + std::to_string(cc);
|
||||||
std::string result(buffer.begin(), buffer.end());
|
// driver version
|
||||||
int version;
|
int version;
|
||||||
dispatch::cuDriverGetVersion(&version);
|
dispatch::cuDriverGetVersion(&version);
|
||||||
int major = version / 1000;
|
int major = version / 1000;
|
||||||
// int minor = (version - major*1000) / 10;
|
int minor = (version - major*1000) / 10;
|
||||||
if(major < 10)
|
if(major < 10)
|
||||||
throw std::runtime_error("Triton requires CUDA 10+");
|
throw std::runtime_error("Triton requires CUDA 10+");
|
||||||
find_and_replace(result, ".version", "\n", ".version 6.4\n");
|
// PTX version
|
||||||
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
|
int ptx = vptx.at(version);
|
||||||
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
|
ptx = std::min(ptx, max_nvvm_ptx);
|
||||||
return result;
|
int ptx_major = ptx / 10;
|
||||||
|
int ptx_minor = ptx % 10;
|
||||||
|
// create
|
||||||
|
llvm::SmallVector<char, 0> buffer;
|
||||||
|
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
|
||||||
|
std::string result(buffer.begin(), buffer.end());
|
||||||
|
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||||
|
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||||
|
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
|
||||||
|
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user