[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)
This commit is contained in:
@@ -9,67 +9,48 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void extract_retile_chain(ir::user *root,
|
||||
std::map<int, std::set<ir::user*>>& result,
|
||||
int depth,
|
||||
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
||||
std::set<ir::value*>& seen) {
|
||||
if(!seen.insert(root).second)
|
||||
return;
|
||||
result[depth].insert(root);
|
||||
if(dynamic_cast<ir::make_range*>(root) ||
|
||||
dynamic_cast<ir::splat_inst*>(root)){
|
||||
return;
|
||||
}
|
||||
return root;
|
||||
if(!root->get_type()->is_block_ty())
|
||||
return root;
|
||||
|
||||
bld.set_insert_point(root);
|
||||
ir::instruction *new_root = bld.insert(root->clone());
|
||||
for(ir::value *op: root->ops()){
|
||||
ir::user *u = dynamic_cast<ir::user*>(op);
|
||||
if(!u)
|
||||
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
|
||||
if(!i || i->get_id() == ir::INST_REDUCE)
|
||||
continue;
|
||||
extract_retile_chain(u, result, depth + 1, seen);
|
||||
ir::instruction* new_op = rematerialize(bld, i, seen);
|
||||
new_root->replace_uses_of_with(op, new_op);
|
||||
}
|
||||
return new_root;
|
||||
}
|
||||
|
||||
void disassociate::run(ir::module &mod) {
|
||||
ir::builder &bld = mod.get_builder();
|
||||
|
||||
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
|
||||
// ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
// bld.set_insert_point(i);
|
||||
// for(ir::value* op: i->ops()){
|
||||
// auto reshape = dynamic_cast<ir::make_range*>(op);
|
||||
// if(!reshape)
|
||||
// continue;
|
||||
// ir::instruction* new_op = bld.insert(reshape->clone());
|
||||
// i->replace_uses_of_with(op, new_op);
|
||||
// }
|
||||
// });
|
||||
|
||||
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
if(dynamic_cast<ir::reshape_inst*>(i)){
|
||||
ir::value* op = i->get_operand(0);
|
||||
if(!dynamic_cast<ir::user*>(op))
|
||||
return;
|
||||
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
|
||||
return;
|
||||
std::map<int, std::set<ir::user*>> chains;
|
||||
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
|
||||
std::set<ir::value*> seen;
|
||||
extract_retile_chain(i, chains, 0, seen);
|
||||
if(chains.size())
|
||||
clone_info[i] = chains;
|
||||
ir::instruction* new_i = rematerialize(bld, i, seen);
|
||||
i->replace_all_uses_with(new_i);
|
||||
}
|
||||
});
|
||||
|
||||
for(const auto& x: clone_info){
|
||||
int depth = 1;
|
||||
std::map<ir::instruction*, ir::instruction*> clone_map;
|
||||
while(x.second.find(depth) != x.second.end()){
|
||||
// clone all users
|
||||
const auto& remat = x.second.at(depth);
|
||||
for(ir::user* u: remat){
|
||||
ir::instruction *y = (ir::instruction*)u;
|
||||
ir::instruction *cloned = y->clone();
|
||||
bld.set_insert_point(y);
|
||||
bld.insert(cloned);
|
||||
clone_map[y] = cloned;
|
||||
// replace operands of parents
|
||||
if(depth > 1)
|
||||
for(ir::user* ux: x.second.at(depth - 1))
|
||||
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
|
||||
else
|
||||
x.first->replace_uses_of_with(y, cloned);
|
||||
}
|
||||
depth += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user