Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -11,38 +11,38 @@ namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
ir::builder &builder = mod.get_builder();
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
// ir::builder &builder = mod.get_builder();
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
ir::value* _ptr = ld->get_pointer_operand();
ir::value* _msk = ld->get_mask_operand();
ir::value* _val = ld->get_false_value_operand();
auto ptr = std::find(block->begin(), block->end(), _ptr);
auto msk = std::find(block->begin(), block->end(), _msk);
auto val = std::find(block->begin(), block->end(), _val);
if(ptr == block->end() || msk == block->end() || val == block->end())
continue;
auto it = std::find(block->begin(), block->end(), i);
int dist_ptr = std::distance(ptr, it);
int dist_msk = std::distance(msk, it);
int dist_val = std::distance(val, it);
if(dist_ptr < dist_msk && dist_ptr < dist_val)
builder.set_insert_point(++ptr);
if(dist_msk < dist_ptr && dist_msk < dist_val)
builder.set_insert_point(++msk);
if(dist_val < dist_ptr && dist_val < dist_msk)
builder.set_insert_point(++val);
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
to_replace.push_back(std::make_pair(ld, new_ld));
}
}
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction* i: block->get_inst_list()){
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
// ir::value* _ptr = ld->get_pointer_operand();
// ir::value* _msk = ld->get_mask_operand();
// ir::value* _val = ld->get_false_value_operand();
// auto ptr = std::find(block->begin(), block->end(), _ptr);
// auto msk = std::find(block->begin(), block->end(), _msk);
// auto val = std::find(block->begin(), block->end(), _val);
// if(ptr == block->end() || msk == block->end() || val == block->end())
// continue;
// auto it = std::find(block->begin(), block->end(), i);
// int dist_ptr = std::distance(ptr, it);
// int dist_msk = std::distance(msk, it);
// int dist_val = std::distance(val, it);
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
// builder.set_insert_point(++ptr);
// if(dist_msk < dist_ptr && dist_msk < dist_val)
// builder.set_insert_point(++msk);
// if(dist_val < dist_ptr && dist_val < dist_msk)
// builder.set_insert_point(++val);
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
// to_replace.push_back(std::make_pair(ld, new_ld));
// }
// }
for(auto& x: to_replace)
x.first->replace_all_uses_with(x.second);
// for(auto& x: to_replace)
// x.first->replace_all_uses_with(x.second);
}