#include #include #include "triton/ir/utils.h" #include "triton/ir/instructions.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/codegen/transform/coalesce.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/layout.h" namespace triton { namespace codegen{ namespace transform{ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) : align_(align), layout_(layouts) { } // Find all values that are used as pointer operands in LD/ST void coalesce::extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); if(i && i->get_pointer_operand() == v) result.insert(i); } } void coalesce::extract_ld(ir::io_inst* i, std::map>& result) { ir::value *ptr = i->get_pointer_operand(); auto contiguous = align_->contiguous(ptr); auto it = std::max_element(contiguous.begin(), contiguous.end()); int axis = std::distance(contiguous.begin(), it); result[axis].push_back(i); } ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder, std::map& seen) { if(seen.find(x) != seen.end()) return seen.at(x); auto i = dynamic_cast(x); // not an instruction -- forward value if(!i) return x; // already in shared memory -- forward value if(dynamic_cast(x)){ return x; } // set insert point auto& inst_list = i->get_parent()->get_inst_list(); auto pos = ++std::find(inst_list.begin(), inst_list.end(), i); builder.set_insert_point(pos); if(dynamic_cast(x)){ ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x)); return ret; } // default -- recursive clone ir::instruction *cloned = builder.insert(i->clone()); seen[i] = cloned; // rematerialize operands for(ir::value *op: cloned->ops()) cloned->replace_uses_of_with(op, rematerialize(op, builder, seen)); return cloned; } void coalesce::run(ir::module &mod) { size_t num_groups = layout_->num_layouts(); for(size_t id = 0; id < num_groups; id++) { if(!layout_->get(id)->to_mma()) continue; // extract memory stores const auto& values = layout_->values_of(id); ir::value* dot = nullptr; for(ir::value *v: values) if(auto x = dynamic_cast(v)) dot = x; ir::builder& builder = mod.get_builder(); std::vector worklist = {dot}; std::set seen; while(!worklist.empty()) { ir::value *current = worklist.back(); seen.insert(current); worklist.pop_back(); // stop if trunc if(auto x = dynamic_cast(current)){ builder.set_insert_point_after(x); ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x); builder.insert(rc); x->replace_all_uses_with(rc); rc->replace_uses_of_with(rc, x); break; } // recurse for(ir::user *u: current->get_users()) if(seen.find(u) == seen.end()) worklist.push_back(u); } } // find values to rematerialize std::vector remat; for(size_t id = 0; id < num_groups; id++) { const auto& values = layout_->values_of(id); // extract pointers used in ld/st operations std::set io; for(ir::value *v: values) extract_io_use(v, io); // extract leading axes std::map> axes; for(ir::io_inst *i: io){ if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){ extract_ld(i, axes); } } // update list of values to rematerialize if(axes.empty()) continue; for(auto it = ++axes.rbegin(); it != axes.rend(); it++){ if(it->second.size() == 1) continue; remat.insert(remat.begin(), it->second.begin(), it->second.end()); } } // rematerialize values for(ir::io_inst *r: remat) { ir::builder& builder = mod.get_builder(); // rematerialize operands std::map seen; for(ir::value *op: r->ops()) r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen)); // copy to shared if load auto& inst_list = r->get_parent()->get_inst_list(); auto pos = ++std::find(inst_list.begin(), inst_list.end(), r); builder.set_insert_point(pos); if(dynamic_cast(r)){ ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r)); r->replace_all_uses_with(cts); cts->replace_uses_of_with(cts, r); } } } } } }