[codegen][coalesce] more bugfixes
This commit is contained in:
@@ -28,16 +28,17 @@ void coalesce::run(ir::module &mod) {
|
||||
std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
|
||||
if(order_.find(v) != order_.end())
|
||||
return;
|
||||
order_[v] = {};
|
||||
ir::type *tile_ty = v->get_type();
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(v))
|
||||
tile_ty = x->get_operand(0)->get_type();
|
||||
if(!tile_ty->is_tile_ty())
|
||||
return;
|
||||
std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
order_[v] = order;
|
||||
if(ir::user* u = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value* op: u->ops())
|
||||
set_order(op);
|
||||
ir::type* ty = v->get_type();
|
||||
if(!ty->is_tile_ty())
|
||||
return;
|
||||
std::vector<unsigned> order(ty->get_tile_shapes().size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
order_[v] = order;
|
||||
};
|
||||
|
||||
// initialize work-list
|
||||
@@ -52,56 +53,58 @@ void coalesce::run(ir::module &mod) {
|
||||
set_order(i);
|
||||
}
|
||||
|
||||
// ir::builder &builder = mod.get_builder();
|
||||
// std::set<ir::value*> seen;
|
||||
// for(ir::io_inst *i: io) {
|
||||
// ir::value *ptr = i->get_pointer_operand();
|
||||
// auto max_contiguous = align_->get_max_contiguous_vec(ptr);
|
||||
// std::vector<unsigned> order(max_contiguous.size());
|
||||
// std::iota(order.begin(), order.end(), 0);
|
||||
// std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
|
||||
// std::list<ir::instruction*> work_list;
|
||||
// if(order != order_[i])
|
||||
// work_list.push_back(i);
|
||||
// // rematerialize recursively
|
||||
// while(!work_list.empty()) {
|
||||
// ir::instruction* current = work_list.back();
|
||||
// order_[current] = order;
|
||||
// work_list.pop_back();
|
||||
// for(ir::value *op: current->ops()) {
|
||||
// ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
||||
// if(!seen.insert(op).second)
|
||||
// continue;
|
||||
// if(!i_op)
|
||||
// continue;
|
||||
// ir::type *ty = i_op->get_type();
|
||||
// if(!ty->is_tile_ty())
|
||||
// continue;
|
||||
// auto& inst_list = i_op->get_parent()->get_inst_list();
|
||||
// auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
|
||||
// it++;
|
||||
// builder.set_insert_point(it);
|
||||
// // found a load; write to shared memory and stop recursion
|
||||
// ir::instruction *n_op = nullptr;
|
||||
// if(mem_->is_shared(i_op)){
|
||||
// continue;
|
||||
// }
|
||||
// if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
|
||||
// n_op = ir::copy_to_shared_inst::create(ld);
|
||||
// }
|
||||
// // not a load; rematerialize and recurse
|
||||
// else {
|
||||
// n_op = i_op->clone();
|
||||
// work_list.push_back(n_op);
|
||||
// }
|
||||
// n_op = builder.insert(n_op);
|
||||
// order_[n_op] = order;
|
||||
// align_->copy(n_op, i_op);
|
||||
// current->replace_uses_of_with(i_op, n_op);
|
||||
// }
|
||||
// }
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::map<ir::value*, ir::value*> replaced;
|
||||
for(ir::io_inst *i: io) {
|
||||
ir::value *ptr = i->get_pointer_operand();
|
||||
auto max_contiguous = align_->get_max_contiguous_vec(ptr);
|
||||
std::vector<unsigned> order(max_contiguous.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
|
||||
std::list<ir::instruction*> work_list;
|
||||
if(order != order_[i])
|
||||
work_list.push_back(i);
|
||||
// rematerialize recursively
|
||||
while(!work_list.empty()) {
|
||||
ir::instruction* current = work_list.back();
|
||||
order_[current] = order;
|
||||
work_list.pop_back();
|
||||
for(ir::value *op: current->ops()) {
|
||||
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
||||
if(replaced.find(i_op) != replaced.end()){
|
||||
current->replace_uses_of_with(i_op, replaced.at(i_op));
|
||||
continue;
|
||||
}
|
||||
if(!i_op)
|
||||
continue;
|
||||
ir::type *ty = i_op->get_type();
|
||||
if(!ty->is_tile_ty())
|
||||
continue;
|
||||
auto& inst_list = i_op->get_parent()->get_inst_list();
|
||||
auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
|
||||
it++;
|
||||
builder.set_insert_point(it);
|
||||
// found a load; write to shared memory and stop recursion
|
||||
ir::instruction *n_op = nullptr;
|
||||
if(mem_->is_shared(i_op))
|
||||
continue;
|
||||
if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
|
||||
n_op = ir::copy_to_shared_inst::create(ld);
|
||||
// not a load; rematerialize and recurse
|
||||
else {
|
||||
n_op = i_op->clone();
|
||||
work_list.push_back(n_op);
|
||||
}
|
||||
n_op = builder.insert(n_op);
|
||||
replaced.insert({i_op, n_op});
|
||||
order_[n_op] = order;
|
||||
align_->copy(n_op, i_op);
|
||||
// mem_->copy(n_op, i_op);
|
||||
current->replace_uses_of_with(i_op, n_op);
|
||||
}
|
||||
}
|
||||
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user