#include #include #include #include "triton/ir/function.h" #include "triton/ir/cfg.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/module.h" #include "triton/codegen/analysis/meminfo.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" namespace triton { namespace codegen{ namespace transform{ coalesce::coalesce(analysis::align* align, analysis::meminfo *mem) : align_(align), mem_(mem) { } std::vector coalesce::get_order(ir::value* v) { return order_.at(v); } void coalesce::run(ir::module &mod) { std::set io; std::function set_order = [&](ir::value *v) -> void { if(order_.find(v) != order_.end()) return; ir::type *tile_ty = v->get_type(); if(auto *x = dynamic_cast(v)) tile_ty = x->get_operand(0)->get_type(); if(!tile_ty->is_tile_ty()) return; std::vector order(tile_ty->get_tile_shapes().size()); std::iota(order.begin(), order.end(), 0); order_[v] = order; if(ir::user* u = dynamic_cast(v)) for(ir::value* op: u->ops()) set_order(op); }; // initialize work-list for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: ir::cfg::reverse_post_order(fn)) for(ir::instruction *i: block->get_inst_list()){ if(auto *x = dynamic_cast(i)) { ir::type* ptr_ty = x->get_pointer_operand()->get_type(); if(ptr_ty->is_tile_ty()) io.insert(x); } set_order(i); } ir::builder &builder = mod.get_builder(); std::map replaced; for(ir::io_inst *i: io) { ir::value *ptr = i->get_pointer_operand(); auto max_contiguous = align_->get_max_contiguous_vec(ptr); std::vector order(max_contiguous.size()); std::iota(order.begin(), order.end(), 0); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); std::list> work_list; if(order != order_[i]) work_list.push_back({i, nullptr}); // rematerialize recursively while(!work_list.empty()) { auto pair = work_list.back(); ir::instruction* cloned = pair.first; ir::instruction* original = pair.second; order_[cloned] = order; work_list.pop_back(); for(ir::value *op: cloned->ops()) { ir::instruction* i_op = dynamic_cast(op); if(replaced.find(i_op) != replaced.end()){ cloned->replace_uses_of_with(i_op, replaced.at(i_op)); continue; } if(!i_op) continue; ir::type *ty = i_op->get_type(); if(!ty->is_tile_ty()) continue; auto& inst_list = i_op->get_parent()->get_inst_list(); auto it = std::find(inst_list.begin(), inst_list.end(), i_op); it++; builder.set_insert_point(it); // found a load; write to shared memory and stop recursion ir::instruction *n_op = nullptr; if(mem_->is_shared(i_op)){ i_op->add_use(cloned); continue; } if(auto* ld = dynamic_cast(i_op)) n_op = ir::copy_to_shared_inst::create(ld); // not a load; rematerialize and add to worklist else { n_op = i_op->clone(); work_list.push_back({n_op, i_op}); } n_op = builder.insert(n_op); replaced.insert({i_op, n_op}); order_[n_op] = order; align_->copy(n_op, i_op); mem_->copy(n_op, i_op); if(original) n_op->erase_use(original); cloned->replace_uses_of_with(i_op, n_op); } } } } } } }