diff --git a/include/triton/codegen/transform/disassociate.h b/include/triton/codegen/transform/disassociate.h new file mode 100644 index 000000000..f2363f3fe --- /dev/null +++ b/include/triton/codegen/transform/disassociate.h @@ -0,0 +1,22 @@ +#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_ +#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_ + + +namespace triton { +namespace ir { + class module; +} + +namespace codegen{ +namespace transform{ +/* Codegen transform pass declared in triton::codegen::transform. Its only member is run, which takes the ir::module to operate on by non-const reference. */ +class disassociate { +public: + void run(ir::module &mod); +}; + +} +} +} + +#endif diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc new file mode 100644 index 000000000..70134b186 --- /dev/null +++ b/lib/codegen/transform/disassociate.cc @@ -0,0 +1,83 @@ +#include "triton/codegen/transform/disassociate.h" +#include "triton/ir/utils.h" +#include "triton/ir/instructions.h" +#include "triton/ir/builder.h" +#include "triton/ir/module.h" +#include + +namespace triton { +namespace codegen{ +namespace transform{ +/* Depth-first walk over the operands of root that collects operand paths into result. current is the path accumulated by the caller; seen records nodes already visited, and a node that is already in seen is not expanded again. A node that passes either dynamic_cast test below ends the walk: a copy of the path with that node appended is pushed onto result. Otherwise each operand whose dynamic_cast result (stored in the ir::user pointer u) is non-null is appended to a copy of the path and visited recursively. The cast target types and container element types are missing from this view of the file, so they are not named here. */ +void extract_retile_chain(ir::user *root, + const std::vector& current, + std::vector>& result, + std::set& seen) { + if(!seen.insert(root).second) + return; + if(dynamic_cast(root) || dynamic_cast(root)){ /* NOTE(review): on recursive calls the caller has already pushed this node onto current, so root ends up twice at the tail of the recorded chain - confirm this is intended. */ + std::vector next = current; + next.push_back(root); + result.push_back(next); + return; + } + for(ir::value *op: root->ops()){ + ir::user *u = dynamic_cast(op); + if(!u) + continue; + std::vector next = current; + next.push_back(u); + extract_retile_chain(u, next, result, seen); + } +} +/* Pass entry point. First phase: visits the instructions of mod through ir::for_each_instruction and, for each one that passes the first dynamic_cast test and whose operand 0 passes a second dynamic_cast test, gathers operand chains with extract_retile_chain (using a fresh seen set per instruction) and stores them in clone_info when at least one chain was found. Second phase: for every element of every stored chain, clones the element, inserts the clone through the module builder after setting the insert point to the original, and redirects uses of the original to the clone in the previous chain element, or in the clone_info key for the first element. */ +void disassociate::run(ir::module &mod) { + ir::builder &bld = mod.get_builder(); + + std::map>> clone_info; + ir::for_each_instruction(mod, [&](ir::instruction *i){ + if(dynamic_cast(i)){ + std::vector> chains; + std::set seen; + if(!dynamic_cast(i->get_operand(0))) + return; + extract_retile_chain(i, {}, chains, seen); + if(chains.size()) + clone_info[i] = chains; + } + }); + + + for(auto x: clone_info){ + for(auto chain: x.second){ /* x and chain are declared with plain auto, so each map entry and each chain is copied for the iteration. */ + for(int i = 0; i < chain.size(); i++) { /* signed int index compared against the value returned by chain.size() */ 
ir::instruction *y = (ir::instruction*)chain[i]; /* Unchecked C-style cast. extract_retile_chain pushes its elements from ir::user pointers - NOTE(review): confirm every chain element really is an ir::instruction. */ + ir::instruction *cloned = y->clone(); + bld.set_insert_point(y); + bld.insert(cloned); + if(i > 0) + chain[i-1]->replace_uses_of_with(y, cloned); /* NOTE(review): chain[i-1] is the original element, not the clone made for it in the previous iteration - confirm that rewriting the original is intended. */ + else + x.first->replace_uses_of_with(y, cloned); + } + +/* The commented-out lines below appear to be a disabled alternative form of the cloning loop above. */ +// ir::instruction *y = (ir::instruction*)parent; +// for(ir::user *u: chain){ +// ir::instruction *cloned = y->clone(); +// bld.set_insert_point(y); +// bld.insert(cloned); +// std::cout << typeid(*u).name() << std::endl; +// u->replace_uses_of_with(y, cloned); +// y = (ir::instruction*)u; +// } + } + } + + +} + + +} +} +} diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index e017982f8..f9eab4ee4 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -13,6 +13,7 @@ #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/cts.h" +#include "triton/codegen/transform/disassociate.h" #include "triton/codegen/selection/generator.h" #include "triton/runtime/function.h" #include "triton/lang/cpp.h" @@ -208,6 +209,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // create passes codegen::analysis::align align; codegen::analysis::axes axes; + codegen::transform::disassociate disassociate; codegen::analysis::layout layouts(&axes, &align, opt.num_warps); codegen::analysis::liveness liveness(&layouts); codegen::analysis::allocation allocation(&liveness); @@ -219,7 +221,8 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::cts cts; codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps); // run passes -// ir::print(module, std::cout); + disassociate.run(module); + dce.run(module); peephole.run(module); dce.run(module); align.run(module); diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index dc71d86bb..05ed68a7b 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -10,14 +10,13 @@ void dot(TYPE * A, TYPE * 
B, TYPE * C, // prologue int ridx = get_program_id(0); int ridy = get_program_id(1); - int rxa[TM] = ridx * TM + 0 ... TM; - int ryb[TN] = ridy * TN + 0 ... TN; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... TN; + int rk[TK] = 0 ... TK; float c[TM, TN] = 0; // pointers to operands - TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN; + TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; + TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; // prefetches operands TYPE a[SHAPE_A] = *pa; TYPE b[SHAPE_B] = *pb; @@ -32,9 +31,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C, b = checkb ? *pb : 0; } // epilogue - int rxc[TM] = ridx * TM + 0 ... TM; - int ryc[TN] = ridy * TN + 0 ... TN; - TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc; + TYPE* pc[TM, TN] = C + rm[:, newaxis] + rn[newaxis, :] * ldc; *pc = c; } )";