Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See documentations for more information on the new API
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi->get_incoming_value(phi_idx);
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize(builder, op, phi_idx));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
@@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) {
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
|
||||
false_value = masked_load->get_false_value_operand();
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
|
||||
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
|
Reference in New Issue
Block a user