[GENERAL] Merged v1.0alpha into master. Added features are:
- A100 support via mma.16816 - Thread swizzling for conflict-free shared memory accesses without padding - Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering - Added debugging capabilities in the Python binding - Compilation error for kernels that spill
This commit is contained in:
@@ -48,46 +48,6 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
|
||||
// architecture
|
||||
cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
|
||||
switch(sm.first) {
|
||||
case 7:
|
||||
switch(sm.second){
|
||||
case 0: return Architecture::SM_7_0;
|
||||
}
|
||||
|
||||
case 6:
|
||||
switch(sm.second){
|
||||
case 0: return Architecture::SM_6_0;
|
||||
case 1: return Architecture::SM_6_1;
|
||||
}
|
||||
|
||||
case 5:
|
||||
switch(sm.second){
|
||||
case 0: return Architecture::SM_5_0;
|
||||
case 2: return Architecture::SM_5_2;
|
||||
default: return Architecture::UNKNOWN;
|
||||
}
|
||||
|
||||
case 3:
|
||||
switch(sm.second){
|
||||
case 0: return Architecture::SM_3_0;
|
||||
case 5: return Architecture::SM_3_5;
|
||||
case 7: return Architecture::SM_3_7;
|
||||
default: return Architecture::UNKNOWN;
|
||||
}
|
||||
|
||||
case 2:
|
||||
switch(sm.second){
|
||||
case 0: return Architecture::SM_2_0;
|
||||
case 1: return Architecture::SM_2_1;
|
||||
default: return Architecture::UNKNOWN;
|
||||
}
|
||||
|
||||
default: return Architecture::UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
// information query
|
||||
template<CUdevice_attribute attr>
|
||||
int cu_device::cuGetInfo() const{
|
||||
@@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
|
||||
return map.at(key);
|
||||
}
|
||||
|
||||
// architecture
|
||||
cu_device::Architecture cu_device::architecture() const{
|
||||
return nv_arch(compute_capability());
|
||||
}
|
||||
|
||||
// number of address bits
|
||||
size_t cu_device::address_bits() const{
|
||||
return sizeof(size_t)*8;
|
||||
@@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
|
||||
}
|
||||
|
||||
// force the device to be interpreted as a particular cc
|
||||
void cu_device::interpret_as(std::pair<size_t, size_t> cc){
|
||||
interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
|
||||
void cu_device::interpret_as(int cc){
|
||||
interpreted_as_ = std::make_shared<int>(cc);
|
||||
}
|
||||
|
||||
// compute capability
|
||||
std::pair<size_t, size_t> cu_device::compute_capability() const {
|
||||
int cu_device::compute_capability() const {
|
||||
if(interpreted_as_)
|
||||
return *interpreted_as_;
|
||||
size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
|
||||
size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
|
||||
return std::make_pair(_major, _minor);
|
||||
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
|
||||
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
|
||||
return major*10 + minor;
|
||||
}
|
||||
|
||||
// maximum number of threads per block
|
||||
@@ -218,7 +173,7 @@ std::string cu_device::infos() const{
|
||||
|
||||
// target
|
||||
std::unique_ptr<codegen::target> cu_device::make_target() const {
|
||||
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
|
||||
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
|
||||
}
|
||||
|
||||
|
||||
|
@@ -93,6 +93,7 @@ namespace driver
|
||||
|
||||
bool dispatch::cuinit(){
|
||||
if(cuda_==nullptr){
|
||||
putenv((char*)"CUDA_CACHE_DISABLE=1");
|
||||
std::string libcuda = tools::getenv("TRITON_LIBCUDA");
|
||||
if(libcuda.empty())
|
||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||
|
@@ -20,7 +20,9 @@
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <unistd.h>
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/error.h"
|
||||
@@ -41,6 +43,19 @@
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
|
||||
std::string exec(const char* cmd) {
|
||||
std::array<char, 128> buffer;
|
||||
std::string result;
|
||||
std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
|
||||
if (!pipe) {
|
||||
throw std::runtime_error("popen() failed!");
|
||||
}
|
||||
while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
|
||||
result += buffer.data();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
@@ -63,11 +78,11 @@ void module::init_llvm() {
|
||||
}
|
||||
|
||||
module::module(CUmodule mod, bool has_ownership)
|
||||
: polymorphic_resource(mod, has_ownership) {
|
||||
: polymorphic_resource(mod, has_ownership), spilled_(0) {
|
||||
}
|
||||
|
||||
module::module(host_module_t mod, bool has_ownership)
|
||||
: polymorphic_resource(mod, has_ownership) {
|
||||
: polymorphic_resource(mod, has_ownership), spilled_(0) {
|
||||
}
|
||||
|
||||
|
||||
@@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
||||
file_type_t ft) {
|
||||
init_llvm();
|
||||
// // debug
|
||||
// llvm::legacy::PassManager pm;
|
||||
llvm::legacy::PassManager pm;
|
||||
std::string tmp;
|
||||
// llvm::raw_string_ostream oss(llir_);
|
||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pm.add(llvm::createVerifierPass());
|
||||
// pm.run(*module);
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(*module);
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
@@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_
|
||||
|
||||
// create execution engine
|
||||
for(llvm::Function& fn: src->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
hst_->functions[fn.getName().str()] = &fn;
|
||||
|
||||
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
// auto DL = JTMB.getDefaultDataLayoutForTarget();
|
||||
@@ -225,7 +242,8 @@ static std::map<int, int> vptx = {
|
||||
{10010, 64},
|
||||
{10020, 65},
|
||||
{11000, 70},
|
||||
{11010, 71}
|
||||
{11010, 71},
|
||||
{11020, 72}
|
||||
};
|
||||
|
||||
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
||||
@@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
assert(short_ptr);
|
||||
short_ptr->setValue(true);
|
||||
// compute capability
|
||||
auto _cc = ((driver::cu_device*)device)->compute_capability();
|
||||
int cc = _cc.first*10 + _cc.second;
|
||||
cc = std::min(cc, max_nvvm_cc);
|
||||
int cc = ((driver::cu_device*)device)->compute_capability();
|
||||
std::string sm = "sm_" + std::to_string(cc);
|
||||
// driver version
|
||||
int version;
|
||||
@@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
throw std::runtime_error("Triton requires CUDA 10+");
|
||||
// PTX version
|
||||
int ptx = vptx.at(version);
|
||||
ptx = std::min(ptx, max_nvvm_ptx);
|
||||
int ptx_major = ptx / 10;
|
||||
int ptx_minor = ptx % 10;
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
|
||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
|
||||
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
|
||||
@@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
}
|
||||
|
||||
|
||||
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }
|
||||
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }
|
||||
|
||||
cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
|
||||
cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
std::string errbuf(errbufsize, 0);
|
||||
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
|
||||
|
||||
try{
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
||||
}catch(exception::cuda::invalid_ptx const &){
|
||||
// // compile ptx with ptxas
|
||||
// char _fsrc[] = "/tmp/triton_k_XXXXXX";
|
||||
// char _flog[] = "/tmp/triton_l_XXXXXX";
|
||||
// int fdsrc = mkstemp(_fsrc);
|
||||
// int fdlog = mkstemp(_flog);
|
||||
// std::string fsrc = _fsrc;
|
||||
// std::string flog = _flog;
|
||||
// std::ofstream ofs(fsrc);
|
||||
// ofs << source;
|
||||
// ofs.close();
|
||||
// std::string cmd;
|
||||
// int err;
|
||||
// driver::cu_device* cu_device = (driver::cu_device*)device;
|
||||
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
// err = system(cmd.c_str());
|
||||
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
|
||||
// std::ifstream file(flog);
|
||||
// std::string log;
|
||||
// if(file)
|
||||
// while (!file.eof()) log.push_back(file.get());
|
||||
// unlink(_fsrc);
|
||||
// unlink(_flog);
|
||||
|
||||
// std::smatch match;
|
||||
// std::regex expr ("\\b([0-9]+) bytes spill");
|
||||
// spilled_ = 0;
|
||||
// while (std::regex_search (log,match,expr)){
|
||||
// spilled_ += std::stoi(match[1]);
|
||||
// log = match.suffix();
|
||||
// }
|
||||
// std::cout << log << std::endl;
|
||||
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
|
||||
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
|
||||
CU_JIT_LOG_VERBOSE};
|
||||
unsigned int errbufsize = 8192;
|
||||
unsigned int logbufsize = 8192;
|
||||
char _err[errbufsize];
|
||||
char _log[logbufsize];
|
||||
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
|
||||
std::string err(_err);
|
||||
std::string log(_log);
|
||||
|
||||
// std::cout << log << std::endl;
|
||||
std::smatch match;
|
||||
std::regex expr ("\\b([0-9]+) bytes spill");
|
||||
spilled_ = 0;
|
||||
while (std::regex_search(log,match,expr)){
|
||||
spilled_ += std::stoi(match[1]);
|
||||
log = match.suffix();
|
||||
}
|
||||
}
|
||||
catch(exception::cuda::invalid_ptx const &){
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
std::cout << source << std::endl;
|
||||
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
||||
std::cerr << errbuf << std::endl;
|
||||
// exit(1);
|
||||
//#endif
|
||||
throw;
|
||||
|
Reference in New Issue
Block a user