Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See documentations for more information on the new API
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -212,7 +212,7 @@ static std::map<int, int> vptx = {
|
||||
{11020, 72}
|
||||
};
|
||||
|
||||
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
||||
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
||||
// LLVM version in use may not officially support target hardware
|
||||
int max_nvvm_cc = 75;
|
||||
int max_nvvm_ptx = 64;
|
||||
@@ -316,6 +316,7 @@ void cu_module::init_from_ptx(const std::string& ptx) {
|
||||
// log = match.suffix();
|
||||
// }
|
||||
// std::cout << log << std::endl;
|
||||
// std::cout << ptx_ << std::endl;
|
||||
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
|
||||
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
|
||||
@@ -351,8 +352,9 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
|
||||
oss << *ll_module;
|
||||
oss.flush();
|
||||
std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
|
||||
if(cache_path.empty())
|
||||
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||
if(cache_path.empty()){
|
||||
ptx_ = compile_llvm_module(ll_module.get(), device);
|
||||
}
|
||||
else{
|
||||
tools::mkdir(cache_path);
|
||||
// update cache path to PTX file
|
||||
@@ -370,7 +372,7 @@ cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_mo
|
||||
ptx_ = _ptx.str();
|
||||
// compile and write-back if read empty
|
||||
if(ptx_.empty()){
|
||||
ptx_ = compile_llvm_module(std::move(ll_module), device);
|
||||
ptx_ = compile_llvm_module(ll_module.get(), device);
|
||||
std::ofstream ofs(cache_path);
|
||||
ofs << ptx_;
|
||||
}
|
||||
|
Reference in New Issue
Block a user