Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the Triton frontend and replaces Triton-C with a pure Python API in which kernels are defined as @triton.jit-decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See the documentation for more information on the new API.
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -37,7 +37,7 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
|
||||
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
|
||||
val_set_t ret;
|
||||
for(ir::value* a: as){
|
||||
if(!a->get_type()->is_tile_ty())
|
||||
if(!a->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
|
||||
if(!a_layout)
|
||||
@@ -45,7 +45,7 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
|
||||
int a_start = alloc_->offset(a_layout);
|
||||
int a_end = a_start + a_layout->get_size();
|
||||
for(ir::value* b: bs){
|
||||
if(!b->get_type()->is_tile_ty())
|
||||
if(!b->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
|
||||
if(!b_layout)
|
||||
@@ -80,7 +80,7 @@ void membar::transfer(ir::basic_block *block,
|
||||
// Get shared memory reads
|
||||
std::set<ir::value*> read;
|
||||
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
|
||||
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
|
||||
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
|
||||
// RAW (async)
|
||||
val_set_t tmp;
|
||||
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
|
||||
|
@@ -58,7 +58,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
|
||||
}
|
||||
|
||||
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, 0) + c -> dot(a, b, c)
|
||||
// dot(a, b, c) + d -> dot(a, b, c + d)
|
||||
// d + dot(a, b, c) -> dot(a, b, c + d)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(add && add->get_op() == ir::binary_op_t::FAdd) {
|
||||
ir::value *lhs = add->get_operand(0);
|
||||
@@ -131,10 +132,10 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
if(!x)
|
||||
return false;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto shapes = arg->get_type()->get_tile_shapes();
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
if(shapes[x->get_axis()] == 1){
|
||||
builder.set_insert_point(x);
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
|
||||
x->replace_all_uses_with(new_red);
|
||||
return true;
|
||||
}
|
||||
|
@@ -23,6 +23,24 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
// Returns a value equivalent to `v` in which every phi node reached through
// the operand chain is replaced by its incoming value number `phi_idx`.
//
// builder: used to insert each freshly cloned instruction.
// v:       value to rebuild; returned unchanged when it is not an instruction.
// phi_idx: index of the incoming value selected on every phi node encountered.
//
// NOTE(review): the recursion stops only at non-instruction values and at phi
// nodes; every other instruction on the operand chain is cloned, and shared
// operands are cloned once per use.
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
// Anything that is not an instruction is reused as-is.
if(!i)
|
||||
return v;
|
||||
// A phi node is resolved to the incoming value selected by phi_idx; that
// incoming value itself is returned without further recursion.
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi->get_incoming_value(phi_idx);
|
||||
|
||||
// Rebuild every operand first, so the clone below can be rewired to them.
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize(builder, op, phi_idx));
|
||||
}
|
||||
// Clone the instruction and point each operand slot at its rebuilt operand.
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
// Operands were rebuilt (and inserted) by the recursive calls above, so the
// clone is inserted through the builder after them.
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
@@ -55,21 +73,27 @@ void pipeline::run(ir::module &mod) {
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
|
||||
false_value = masked_load->get_false_value_operand();
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
|
||||
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
|
@@ -40,7 +40,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
|
||||
// handle retiling
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_tile_shapes();
|
||||
auto shapes = op->get_type()->get_block_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
@@ -54,19 +54,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -84,10 +84,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(is_cst(llhs))
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
@@ -95,10 +95,10 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(is_cst(lrhs))
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
|
||||
}
|
||||
}
|
||||
// extract constant and non-constant
|
||||
@@ -166,7 +166,7 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
||||
ir::value* cst = *sta->idx_begin();
|
||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
||||
auto shapes = rt->get_type()->get_tile_shapes();
|
||||
auto shapes = rt->get_type()->get_block_shapes();
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
@@ -202,7 +202,7 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
@@ -235,7 +235,8 @@ void reassociate::run(ir::module &mod) {
|
||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
||||
// re-add the offset
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
|
||||
phi_sta->set_name( phi->get_name() + "_sta");
|
||||
phi->replace_all_uses_with(phi_sta);
|
||||
// remove offset from pz
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
||||
@@ -245,8 +246,8 @@ void reassociate::run(ir::module &mod) {
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *_0 = builder.get_int32(0);
|
||||
if(off->get_type()->is_tile_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
|
||||
if(off->get_type()->is_block_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
|
||||
ir::value *neg_off = builder.create_sub(_0, off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
|
@@ -11,38 +11,38 @@ namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void reorder::run(ir::module& mod){
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
// ir::builder &builder = mod.get_builder();
|
||||
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
ir::value* _ptr = ld->get_pointer_operand();
|
||||
ir::value* _msk = ld->get_mask_operand();
|
||||
ir::value* _val = ld->get_false_value_operand();
|
||||
auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
auto val = std::find(block->begin(), block->end(), _val);
|
||||
if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
continue;
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
int dist_ptr = std::distance(ptr, it);
|
||||
int dist_msk = std::distance(msk, it);
|
||||
int dist_val = std::distance(val, it);
|
||||
if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
builder.set_insert_point(++ptr);
|
||||
if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
builder.set_insert_point(++msk);
|
||||
if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
builder.set_insert_point(++val);
|
||||
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
}
|
||||
}
|
||||
// for(ir::function *fn: mod.get_function_list())
|
||||
// for(ir::basic_block *block: fn->blocks())
|
||||
// for(ir::instruction* i: block->get_inst_list()){
|
||||
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
// ir::value* _ptr = ld->get_pointer_operand();
|
||||
// ir::value* _msk = ld->get_mask_operand();
|
||||
// ir::value* _val = ld->get_false_value_operand();
|
||||
// auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
// auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
// auto val = std::find(block->begin(), block->end(), _val);
|
||||
// if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
// continue;
|
||||
// auto it = std::find(block->begin(), block->end(), i);
|
||||
// int dist_ptr = std::distance(ptr, it);
|
||||
// int dist_msk = std::distance(msk, it);
|
||||
// int dist_val = std::distance(val, it);
|
||||
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
// builder.set_insert_point(++ptr);
|
||||
// if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
// builder.set_insert_point(++msk);
|
||||
// if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
// builder.set_insert_point(++val);
|
||||
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
// to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
// }
|
||||
// }
|
||||
|
||||
for(auto& x: to_replace)
|
||||
x.first->replace_all_uses_with(x.second);
|
||||
// for(auto& x: to_replace)
|
||||
// x.first->replace_all_uses_with(x.second);
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user