[general] major overhaul of triton-c/triton-ir/triton-jit:
- Added alloc const - Added atomics - Pruning tuning space - Added example for dot/conv/shift - Bugfixes
This commit is contained in:
71
lib/codegen/optimize_trans.cpp
Normal file
71
lib/codegen/optimize_trans.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/optimize_trans.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
|
||||
ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
std::vector<ir::instruction*>& to_delete,
|
||||
ir::builder& builder){
|
||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||
// transpose operands
|
||||
std::vector<ir::value*> incs;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
result->add_incoming(incs[n], phi->get_incoming_block(n));
|
||||
phi->replace_all_uses_with(result);
|
||||
to_delete.push_back(phi);
|
||||
return result;
|
||||
}
|
||||
else if(auto i = dynamic_cast<ir::instruction*>(value)){
|
||||
ir::basic_block* block = i->get_parent();
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
it++;
|
||||
builder.set_insert_point(it);
|
||||
ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
|
||||
i->replace_all_uses_with(trans);
|
||||
trans->set_operand(0, i);
|
||||
return trans;
|
||||
}
|
||||
throw std::runtime_error("cannot transpose phi");
|
||||
}
|
||||
|
||||
|
||||
void optimize_trans::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
// iterate
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
// filter transposition
|
||||
if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
|
||||
auto users = trans->get_users();
|
||||
auto ops = trans->ops();
|
||||
if(users.size() > 1 || ops.size() > 1)
|
||||
continue;
|
||||
ir::value* op = *ops.begin();
|
||||
// chains of transpositions
|
||||
// TODO
|
||||
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
if(dynamic_cast<ir::phi_node*>(op)){
|
||||
ir::value* new_phi = replace_phi(op, to_delete, builder);
|
||||
to_delete.push_back(trans);
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
}
|
||||
}
|
||||
}
|
||||
// erase dead code
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user