diff --git a/include/triton/codegen/reassociate.h b/include/triton/codegen/reassociate.h index 9be8ed6bd..f6d30ea72 100644 --- a/include/triton/codegen/reassociate.h +++ b/include/triton/codegen/reassociate.h @@ -14,6 +14,7 @@ class module; class value; class builder; class instruction; +class getelementptr_inst; } namespace codegen{ @@ -21,9 +22,15 @@ namespace codegen{ class tune; class reassociate { + struct cst_info { + ir::value* sta; + ir::value* dyn; + }; + private: ir::instruction* is_bin_add(ir::value *x); - ir::value *reorder_op(ir::value *value, ir::builder &builder, std::vector& to_delete, ir::value *&noncst, ir::value *&cst); + ir::value *reassociate_idx(ir::value *value, ir::builder &builder, std::vector& to_delete, ir::value *&noncst, ir::value *&cst); + ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map &offsets); public: reassociate(tune *params); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 1bce5bd47..29b2678a3 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -384,6 +384,7 @@ public: type *get_source_elt_ty() { return source_elt_ty; } op_iterator idx_begin() { return op_begin() + 1; } op_iterator idx_end() { return op_end(); } + value *get_pointer_operand() { return *op_begin(); } // factory methods static getelementptr_inst* create(value *ptr, const std::vector &idx, diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index 8a1157940..a1dcfc578 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -79,6 +79,7 @@ public: alignment_info.run(module); reassociate.run(module); ir::print(module, std::cout); + //exit(EXIT_FAILURE); if(target_->is_gpu()){ shmem_info.run(module); shmem_liveness.run(module); diff --git a/lib/codegen/reassociate.cpp b/lib/codegen/reassociate.cpp index 2ca8828d7..a06141d0e 100644 --- a/lib/codegen/reassociate.cpp +++ b/lib/codegen/reassociate.cpp @@ -49,11 +49,34 @@ inline bool is_cst(ir::value *x) { } -inline ir::value *reassociate::reorder_op(ir::value *old_value, - ir::builder &builder, - std::vector& to_delete, - ir::value *&noncst, - ir::value *&cst){ +// reassociate pointer +// pz = py + a = (px + (cst + b)) + a -> (px + b) + (cst + a) +ir::value *reassociate::reassociate_ptr(ir::getelementptr_inst* pz, + ir::builder &builder, + std::map &info) { + ir::value *a = *pz->idx_begin(); + ir::value *vpy = pz->get_pointer_operand(); + if(info.find(vpy) == info.end()) + return nullptr; + ir::getelementptr_inst *py = (ir::getelementptr_inst*)vpy; + ir::value *px = py->get_pointer_operand(); + ir::value *cst = info.at(py).sta; + ir::value *b = info.at(py).dyn; + ir::value *new_py = builder.create_gep(px, {b}); + ir::value *new_a = builder.create_add(cst, a); + ir::value *new_pz = builder.create_gep(new_py, {new_a}); + params_->copy(new_pz, pz); + params_->copy(new_py, vpy); + params_->copy(new_a, a); + pz->replace_all_uses_with(new_pz); + return pz; +} + +ir::value *reassociate::reassociate_idx(ir::value *old_value, + ir::builder &builder, + std::vector& to_delete, + ir::value *&noncst, + ir::value *&cst){ // value doesn't change by default ir::value* new_value = old_value; cst = nullptr; @@ -63,7 +86,7 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value, if(ir::instruction* op = dynamic_cast(old_value)){ auto shapes = op->get_type()->get_tile_shapes(); ir::value *old_arg = op->get_operand(0); - ir::value *new_arg = reorder_op(old_arg, builder, to_delete, noncst, cst); + ir::value *new_arg = reassociate_idx(old_arg, builder, to_delete, noncst, cst); // retile(x + y) = retile(x) + retile(y) if(ir::instruction* bin_add = is_bin_add(new_arg)) if(cst){ @@ -102,8 +125,8 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value, if(ir::instruction* op = is_bin_add(old_value)){ builder.set_insert_point(op); std::string name = op->get_name(); - ir::value *lhs = reorder_op(op->get_operand (0), builder, to_delete, noncst, cst); - ir::value *rhs = reorder_op(op->get_operand(1), builder, to_delete, noncst, cst); + ir::value *lhs = reassociate_idx(op->get_operand (0), builder, to_delete, noncst, cst); + ir::value *rhs = reassociate_idx(op->get_operand(1), builder, to_delete, noncst, cst); builder.set_insert_point(op); // (x + y) + z if(ir::instruction* bin_lhs = is_bin_add(lhs)){ @@ -167,6 +190,8 @@ reassociate::reassociate(tune* params) : params_(params) { } + +/* run */ void reassociate::run(ir::module &mod) { ir::builder &builder = mod.get_builder(); std::vector to_delete; @@ -196,25 +221,63 @@ void reassociate::run(ir::module &mod) { } // reassociate + std::map infos; + std::map> re_ordered; + for(ir::function *fn: mod.get_function_list()){ std::vector rpo = ir::cfg::reverse_post_order(fn); - bool done = false; - do{ - // iterate through blocks - for(ir::basic_block *block: rpo){ - // iterate through instruction - for(ir::instruction *i: block->get_inst_list()){ - if(auto *gep = dynamic_cast(i)){ - std::vector idxs(gep->idx_begin(), gep->idx_end()); - ir::value *cst = nullptr; - ir::value *noncst = idxs[0]; - reorder_op(noncst, builder, to_delete, noncst, cst); -// std::cout << gep->get_name() << " " << noncst << " " << cst << std::endl; - } - } - done = true; + // iterate through blocks + for(ir::basic_block *block: rpo){ + // iterate through instruction + for(ir::instruction *i: block->get_inst_list()){ + // getelementptr instruction + if(ir::getelementptr_inst *pz = dynamic_cast(i)){ + + // pz = py + offset + // tries to achieve pz = py + (cst + a) + // by modifying py and/or offset + ir::value* py = pz->get_pointer_operand(); + ir::value* offset = *pz->idx_begin(); + + // reassociate index + ir::value *sta = nullptr; + ir::value *dyn = offset; + reassociate_idx(pz, builder, to_delete, dyn, sta); + if(sta){ + infos[pz] = {sta, dyn}; + re_ordered[block].insert(pz); } - }while(!done); + +// // reassociate pointer +// reassociate_ptr(pz, builder, offsets); + +// // reassociate phi-node +// if(ir::phi_node* phi = dynamic_cast(py)){ +// // only optimize the case where py = phi pa, pz +// std::vector ops = phi->ops(); +// if(!(ops.size() == 2 && (ops[0] == pz || ops[1] == pz))) +// continue; +// size_t idx_z = (ops[0] == pz) ? 0 : 1; +// size_t idx_a = (idx_z + 1) % 2; +// ir::value *vpa = phi->get_incoming_value(idx_a); +// ir::value *block_a = phi->get_incoming_block(idx_a); +// ir::value *block_z = phi->get_incoming_value(idx_z); +// auto it = infos.find(vpa); +// if(it == infos.end()) +// continue; +// ir::value *b = it->a; +// // pa = px + (cst + b) +// ir::getelementptr_inst *pa = (ir::getelementptr_inst*)vpa; +// ir::getelementptr_inst *px = pa->get_pointer_operand(); +// // new_pa = px + b +// ir::getelementptr_inst *new_pa = builder.create_gep(px, {b}); +// // new_pz = py + (offset + a) +// ir::getelementptr_inst *new_offset = builder.create_add(it->cst, dyn); +// ir::getelementptr_inst *new_pz = builder.create_gep(pz->get_pointer_operand(), {new_offset}); +// } + } + } + } } // erase dead code for(ir::instruction* i: to_delete)