From 26c984946286177a0ff4f3dd1a2d5757a6255eb5 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 5 Aug 2019 21:19:13 -0700
Subject: [PATCH] [ir][instructions] added permutations option for trans

---
Review notes (text between "---" and the diffstat is not part of the commit):
- lib/ir/instructions.cpp: get_default_perm now computes (i + 1) % size.
  The submitted "i + 1 % size" parses as i + (1 % size); for size > 1 that
  yields [1, 2, ..., size], whose last entry is out of range when
  selection.cpp evaluates idx[perm[i]->get_value()]. The parenthesized form
  is the left rotation that selection.cpp (std::rotate) and tune.cpp
  ((i + 1) % n_shapes) hard-coded before this patch.
- Line breaks, indentation and template arguments were lost in the copy this
  patch was restored from; the <...> arguments and whitespace below are
  reconstructed from usage and must be checked against the tree before
  applying (the post-image blob hash of lib/ir/instructions.cpp is stale).
- lib/dnn/dot.cpp: the added std::cout of the generated source looks like
  leftover debugging -- confirm it is intended.

 include/triton/codegen/optimize_trans.h |  3 ++-
 include/triton/ir/builder.h             |  2 +-
 include/triton/ir/instructions.h        |  9 +++++++--
 lib/codegen/optimize_dot.cpp            |  8 +++++++-
 lib/codegen/optimize_trans.cpp          |  9 +++++----
 lib/codegen/selection.cpp               |  8 +++++---
 lib/codegen/tune.cpp                    |  8 ++++----
 lib/dnn/dot.cpp                         |  5 ++++-
 lib/ir/builder.cpp                      |  4 ++--
 lib/ir/instructions.cpp                 | 24 +++++++++++++++++++++---
 10 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/include/triton/codegen/optimize_trans.h b/include/triton/codegen/optimize_trans.h
index c6ec73b4d..8af45205d 100644
--- a/include/triton/codegen/optimize_trans.h
+++ b/include/triton/codegen/optimize_trans.h
@@ -13,13 +13,14 @@ namespace ir {
   class instruction;
   class trans_inst;
   class builder;
+  class constant_int;
 }
 
 namespace codegen{
 
 class optimize_trans {
 private:
-  ir::value *replace_phi(ir::value* value, ir::builder &builder);
+  ir::value *replace_phi(ir::value* value, ir::builder &builder, const std::vector<ir::constant_int*> &perm);
 
 public:
   optimize_trans() {}
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 079a79e40..d3f5e7be4 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -132,7 +132,7 @@ public:
   value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
   value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
-  value *create_trans(value *A, const std::string &name = "");
+  value *create_trans(value *A, const std::vector<constant_int*> &perm = {}, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
   value *create_reduce(value *A, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 9886a8a0e..c0b176ebf 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -585,13 +585,18 @@ private:
 class trans_inst: public builtin_inst {
 public:
   ir::type* get_res_ty(ir::type* in);
+  std::vector<constant_int*> get_default_perm(ir::type* ty);
 
 private:
-  trans_inst(value *arg, const std::string& name, instruction* next);
+  trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
   std::string repr_impl() const { return "trans"; }
 
 public:
-  static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
+  const std::vector<constant_int*> get_perm() const;
+
+private:
+  std::vector<constant_int*> perm_;
 };
 
 class sqrt_inst: public builtin_inst {
diff --git a/lib/codegen/optimize_dot.cpp b/lib/codegen/optimize_dot.cpp
index 8688e918e..e3ebfbcdb 100644
--- a/lib/codegen/optimize_dot.cpp
+++ b/lib/codegen/optimize_dot.cpp
@@ -68,7 +68,13 @@ void optimize_dot::run(ir::module &mod) {
         }
         // dot(op(a), b)
         if(!trans_b){
-          ir::value* BB = builder.create_trans(B);
+          size_t size = B->get_type()->get_tile_shapes().size();
+          std::vector<ir::constant_int*> perm(size);
+          ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
+          for(size_t i = 0; i < size; i++)
+            perm[i] = ir::constant_int::get(int32_ty, i);
+          std::swap(perm[0], perm[1]);
+          ir::value* BB = builder.create_trans(B, perm);
           ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
           dot->replace_all_uses_with(NT);
           to_delete.push_back(dot);
diff --git a/lib/codegen/optimize_trans.cpp b/lib/codegen/optimize_trans.cpp
index 0fb96ac96..16c4605d7 100644
--- a/lib/codegen/optimize_trans.cpp
+++ b/lib/codegen/optimize_trans.cpp
@@ -7,12 +7,13 @@ namespace codegen{
 
 
 ir::value* optimize_trans::replace_phi(ir::value* value,
-                                         ir::builder& builder){
+                                         ir::builder& builder,
+                                         const std::vector<ir::constant_int*> &perm){
   if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
     // transpose operands
     std::vector<ir::value*> incs;
     for(unsigned n = 0; n < phi->get_num_incoming(); n++)
-      incs.push_back(replace_phi(phi->get_incoming_value(n), builder));
+      incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
     // create phi for transposed values
     builder.set_insert_point(phi);
     ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
@@ -26,7 +27,7 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
     auto it = std::find(block->begin(), block->end(), i);
     it++;
     builder.set_insert_point(it);
-    ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
+    ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
     i->replace_all_uses_with(trans);
     trans->set_operand(0, i);
     return trans;
@@ -53,7 +54,7 @@ void optimize_trans::run(ir::module &mod) {
 
       // trans(phi) -> phi(trans(), trans()...)
       if(dynamic_cast<ir::phi_node*>(op)){
-        ir::value* new_phi = replace_phi(op, builder);
+        ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
         trans->replace_all_uses_with(new_phi);
       }
     }
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index e419f5a8d..3b973de71 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -974,11 +974,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     });
   }
   // trans
-  else if(dynamic_cast<ir::trans_inst*>(ins)) {
+  else if(auto* x = dynamic_cast<ir::trans_inst*>(ins)) {
     distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
+    auto perm = x->get_perm();
     in->for_each([&](indices_t idx){
-      indices_t out_idx = idx;
-      std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
+      indices_t out_idx(idx.size());
+      for(size_t i = 0; i < idx.size(); i++)
+        out_idx[i] = idx[perm[i]->get_value()];
       ti->set_value(out_idx, in->get_value(idx));
     });
   }
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 35445a72d..09017e978 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -98,11 +98,11 @@ void tune::init_c_graph(ir::instruction *v) {
 
   }
   // Trans
-  else if(dynamic_cast<ir::trans_inst*>(v)){
+  else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
     ir::value *op = v->get_operand(0);
-    size_t n_shapes = shapes.size();
-    for(unsigned i = 0; i < n_shapes; i++)
-      add_constraint({v, (i + 1) % n_shapes}, {op, i});
+    auto perm = x->get_perm();
+    for(unsigned i = 0; i < perm.size(); i++)
+      add_constraint({v, perm[i]->get_value()}, {op, i});
   }
   // Broadcast
   else if(dynamic_cast<ir::broadcast_inst*>(v)){
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index 7cc7563dc..a43ea1ca4 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -75,7 +75,7 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
   std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
-  std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
+  std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
   std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
   std::string lda0 = "*lda", lda1 = "";
@@ -84,11 +84,13 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string useb = BT_ ? "trans(xb)" : "xb";
   if(AT_){
     std::swap(AS0, AS1);
+    std::swap(XAS0, XAS1);
     std::swap(bca0, bca1);
     std::swap(lda0, lda1);
   }
   if(BT_){
     std::swap(BS0, BS1);
+    std::swap(XBS0, XBS1);
     std::swap(bcb0, bcb1);
     std::swap(ldb0, ldb1);
   }
@@ -144,6 +146,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
 }
 )";
 
+  std::cout << res << std::endl;
   os << res;
 }
 
diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp
index d47fbbaa5..1f6aa7c54 100644
--- a/lib/ir/builder.cpp
+++ b/lib/ir/builder.cpp
@@ -324,8 +324,8 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name
   return insert(dot_inst::create_nn(A, B, C, name));
 }
 
-value *builder::create_trans(value *A, const std::string &name) {
-  return insert(trans_inst::create(A, name));
+value *builder::create_trans(value *A, const std::vector<constant_int*>& perm, const std::string &name) {
+  return insert(trans_inst::create(A, perm, name));
 }
 
 value *builder::create_sqrt(value *A, const std::string &name) {
diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp
index e6e85ff85..58a81cd3b 100644
--- a/lib/ir/instructions.cpp
+++ b/lib/ir/instructions.cpp
@@ -572,13 +572,31 @@ ir::type* trans_inst::get_res_ty(ir::type* ty) {
   return tile_type::get(ty->get_scalar_ty(), shapes);
 }
 
-trans_inst::trans_inst(value *arg, const std::string &name, instruction *next)
+std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) {
+  auto size = ty->get_tile_shapes().size();
+  ir::type* int32_ty = type::get_int32_ty(ty->get_context());
+  std::vector<constant_int*> result;
+  for(size_t i = 0; i < size; i++)
+    result.push_back(ir::constant_int::get(int32_ty, (i + 1) % size));
+  return result;
+}
+
+trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
   : builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
+  perm_ = perm;
+  if(perm_.empty())
+    perm_ = get_default_perm(arg->get_type());
+  auto size = arg->get_type()->get_tile_shapes().size();
+  assert(perm_.size() == size);
   set_operand(0, arg);
 }
 
-instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) {
-  return new trans_inst(arg, name, next);
+instruction* trans_inst::create(value *arg, const std::vector<constant_int*> &perm, const std::string &name, instruction *next) {
+  return new trans_inst(arg, perm, name, next);
+}
+
+const std::vector<constant_int*> trans_inst::get_perm() const {
+  return perm_;
 }
 
 //===----------------------------------------------------------------------===//