[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)

This commit is contained in:
Philippe Tillet
2021-08-30 11:50:35 -07:00
committed by GitHub
parent 85426dbaf7
commit 4ff3714d61
25 changed files with 568 additions and 399 deletions

View File

@@ -9,67 +9,48 @@ namespace triton {
namespace codegen{
namespace transform{
// NOTE(review): this span is a rendered diff hunk whose +/- markers were lost,
// so the lines of two different helpers are interleaved and the text does not
// compile as written. The two helpers that can be identified are:
//
//   * void extract_retile_chain(root, result, depth, seen)
//       Stops if `root` is already in `seen`; otherwise records `root` in
//       `result[depth]` and recurses with `depth + 1` into every operand that
//       is an `ir::user`.
//
//   * ir::instruction* rematerialize(bld, root, seen)
//       Returns `root` unchanged when its type is not a block type. Otherwise
//       inserts a clone of `root` at `root`'s own position, recursively
//       rematerializes each operand that is an `ir::instruction` (operands
//       whose id is ir::INST_REDUCE are skipped), rewires the clone to use the
//       rematerialized operands, and returns the clone.
//
// Bare `return;` statements can only belong to the void helper and
// `return root;` / `return new_root;` only to the one returning a pointer.
// For the remaining shared-looking lines (the `seen` guard and the
// make_range / splat_inst guard) the owning version cannot be determined from
// this view -- TODO confirm against the actual commit before editing.
void extract_retile_chain(ir::user *root,
std::map<int, std::set<ir::user*>>& result,
int depth,
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
std::set<ir::value*>& seen) {
// Visit each node at most once: insert() reports whether `root` was new.
if(!seen.insert(root).second)
return;
result[depth].insert(root);
// make_range / splat_inst nodes end the walk (no recursion into their operands).
if(dynamic_cast<ir::make_range*>(root) ||
dynamic_cast<ir::splat_inst*>(root)){
return;
}
return root;
// Values that are not block-typed are handed back as-is, without cloning.
if(!root->get_type()->is_block_ty())
return root;
// The clone is inserted at the position of the instruction it duplicates.
bld.set_insert_point(root);
ir::instruction *new_root = bld.insert(root->clone());
for(ir::value *op: root->ops()){
ir::user *u = dynamic_cast<ir::user*>(op);
if(!u)
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
// Non-instruction operands and reductions are left shared, not duplicated.
if(!i || i->get_id() == ir::INST_REDUCE)
continue;
extract_retile_chain(u, result, depth + 1, seen);
ir::instruction* new_op = rematerialize(bld, i, seen);
// Only the clone is rewired; `root` itself keeps its original operands.
new_root->replace_uses_of_with(op, new_op);
}
return new_root;
}
// Entry point of the disassociate pass over `mod`.
//
// NOTE(review): like the helper above it, this body is a rendered diff hunk
// with the +/- markers lost, so two variants of the driver are interleaved and
// the text does not compile as written. The two variants visible here are:
//
//   * Chain-collecting variant: for every reshape_inst whose operand 0 is an
//     `ir::user` and whose operand tile rank is not greater than the result's
//     tile rank, gather a depth-indexed chain with extract_retile_chain() into
//     `clone_info`; afterwards walk each chain from depth 1 upwards, cloning
//     every recorded user at its own position and pointing the previous
//     depth's clones (or, at depth 1, the chain's key instruction) at the new
//     clone.
//
//   * Rematerializing variant: for every reshape_inst or splat_inst, call
//     rematerialize() with a fresh `seen` set and replace all uses of the
//     instruction with the returned value.
//
// Which individual lines belong to which variant is inferred from the names
// they reference -- TODO confirm against the actual commit before editing.
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
// Key: instruction the chain starts from; value: users grouped by depth.
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
// ir::for_each_instruction(mod, [&](ir::instruction *i){
// bld.set_insert_point(i);
// for(ir::value* op: i->ops()){
// auto reshape = dynamic_cast<ir::make_range*>(op);
// if(!reshape)
// continue;
// ir::instruction* new_op = bld.insert(reshape->clone());
// i->replace_uses_of_with(op, new_op);
// }
// });
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
ir::value* op = i->get_operand(0);
// Nothing to collect when the reshaped operand is not itself a user.
if(!dynamic_cast<ir::user*>(op))
return;
// Skip reshapes whose operand has a higher tile rank than the result.
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
return;
std::map<int, std::set<ir::user*>> chains;
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
// A fresh visited set is used for each instruction handled.
std::set<ir::value*> seen;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;
ir::instruction* new_i = rematerialize(bld, i, seen);
i->replace_all_uses_with(new_i);
}
});
for(const auto& x: clone_info){
// Depth 0 holds the chain's starting instruction itself, so cloning begins at 1.
int depth = 1;
// Maps each original instruction to the clone created for it in this chain.
std::map<ir::instruction*, ir::instruction*> clone_map;
while(x.second.find(depth) != x.second.end()){
// clone all users
const auto& remat = x.second.at(depth);
for(ir::user* u: remat){
// NOTE(review): unchecked C-style downcast; presumably every recorded user
// is an ir::instruction -- confirm.
ir::instruction *y = (ir::instruction*)u;
ir::instruction *cloned = y->clone();
bld.set_insert_point(y);
bld.insert(cloned);
clone_map[y] = cloned;
// replace operands of parents
if(depth > 1)
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
else
x.first->replace_uses_of_with(y, cloned);
}
depth += 1;
}
}
}