diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h index bbc5fe440..647db3984 100644 --- a/include/triton/codegen/analysis/align.h +++ b/include/triton/codegen/analysis/align.h @@ -55,6 +55,8 @@ private: std::vector populate_starting_multiple_gep(ir::getelementptr_inst* x); std::vector populate_starting_multiple_default(ir::value* v); std::vector populate_starting_multiple(ir::value *v); + // populate all maps + void populate(ir::value *v); public: void run(ir::module &mod); diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h index f625c4193..d22fa5fa8 100644 --- a/include/triton/codegen/analysis/axes.h +++ b/include/triton/codegen/analysis/axes.h @@ -23,15 +23,24 @@ class axes { private: void add_constraint(node_t x, node_t y); - void init_c_phi(ir::instruction *i); - void init_c_graph(ir::instruction *v); + // update graph + void update_graph_store(ir::instruction *i); + void update_graph_reduce(ir::instruction *i); + void update_graph_reshape(ir::instruction *i); + void update_graph_splat(ir::instruction *i); + void update_graph_trans(ir::instruction *i); + void update_graph_broadcast(ir::instruction *i); + void update_graph_dot(ir::instruction *i); + void update_graph_elementwise(ir::instruction *i); + void update_graph(ir::instruction *i); + // connected components void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id); public: axes(); void run(ir::module &mod); - unsigned get(ir::value *value, unsigned ax); - bool has(ir::value *value, unsigned ax); + unsigned get_id(ir::value *value, unsigned ax); + bool has_id(ir::value *value, unsigned ax); private: // constraints graph diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 7bc14b08f..2e7fbd830 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -24,8 +24,9 @@ class layout { typedef std::map > graph_t; private: - // create edge + // graph creation void connect(ir::value *x, ir::value *y); + void make_graph(ir::instruction *i); // connected components void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned id); // list the axes of the given value diff --git a/include/triton/codegen/transform/vectorize.h b/include/triton/codegen/transform/vectorize.h deleted file mode 100644 index 0a6571b61..000000000 --- a/include/triton/codegen/transform/vectorize.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_VECTORIZE_H -#define TDL_INCLUDE_CODEGEN_VECTORIZE_H - -namespace triton { - -namespace ir { - class module; -} - -namespace codegen{ - -namespace analysis{ - class tiles; -} - -namespace transform{ - -class vectorize { -public: - vectorize(analysis::tiles *params): params_(params){} - void run(ir::module &mod); - -private: - analysis::tiles *params_; -}; - -} -} -} - -#endif diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 0b6c859b1..d5707265a 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -140,7 +140,6 @@ public: value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); // Intrinsics value *create_copy_to_shared(value *arg, const std::string &name = ""); - value *create_vectorize(value *arg, const std::string &name = ""); value *create_barrier(const std::string &name = ""); private: diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 609fb2d46..19cf82086 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -77,6 +77,67 @@ enum cmp_pred_t { LAST_ICMP_PREDICATE }; +enum value_id_t: unsigned { + /* ------------ * + INSTRUCTIONS + * ------------ */ + INST_BEGIN, + // phi + INST_PHI, + // arithmetic + INST_BINOP, + INST_GETELEMENTPTR, + INST_SELECT, + INST_SQRT, + // cmp + INST_ICMP, + INST_FCMP, + // cast + INST_CAST_TRUNC, + INST_CAST_ZEXT, + INST_CAST_SEXT, + INST_CAST_FP_TRUNC, + INST_CAST_FP_EXT, + INST_CAST_UI_TO_FP, + INST_CAST_SI_TO_FP, + INST_CAST_FP_TO_UI, + INST_CAST_FP_TO_SI, + INST_CAST_PTR_TO_INT, + INST_CAST_INT_TO_PTR, + INST_CAST_BIT_CAST, + INST_CAST_ADDR_SPACE_CAST, + // terminators + INST_RETURN, + INST_COND_BRANCH, + INST_UNCOND_BRANCH, + // io + INST_UNMASKED_LOAD, + INST_MASKED_LOAD, + INST_UNMASKED_STORE, + INST_MASKED_STORE, + // retile + INST_RESHAPE, + INST_SPLAT, + INST_BROADCAST, + INST_DOWNCAST, + // builtin + INST_GET_PROGRAM_ID, + INST_GET_NUM_PROGRAMS, + // atomics + INST_ATOMIC_CAS, + INST_ATOMIC_EXCH, + INST_ATOMIC_ADD, + // array arithmetic + INST_TRANS, + INST_REDUCE, + INST_DOT, + // intrinsics + INST_COPY_TO_SHARED, + INST_BARRIER, + INST_MAKE_RANGE_DYN, + INST_MAKE_RANGE_STA, + INST_MAKE_RANGE +}; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index dd85fd3a0..bafc1c2c3 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -40,7 +40,8 @@ private: protected: // constructors - instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr); + instruction(type *ty, value_id_t ity, unsigned num_ops, + const std::string &name = "", instruction *next = nullptr); public: // parent @@ -59,32 +60,21 @@ public: // cloning ir::instruction* clone() { ir::instruction* res = clone_impl(); -// for(auto it = op_begin(); it != op_end(); it++){ -// (*it)->add_use(res); -// } - res->set_name("testcloned"); + for(auto it = op_begin(); it != op_end(); it++) + (*it)->add_use(res); res->parent_ = nullptr; return res; } + // instruction id + value_id_t get_id() const { return id_; } private: basic_block *parent_; std::map metadatas_; + value_id_t id_; }; -// result reference -class result_reference: public value { -public: - result_reference(instruction *ref, unsigned arg_id, const std::string &name = ""); - instruction *get_ref(); - unsigned get_arg_id(); - -private: - instruction *ref_; - unsigned arg_id_; -}; - //===----------------------------------------------------------------------===// // phi_node classes //===----------------------------------------------------------------------===// @@ -173,11 +163,13 @@ public: class cmp_inst: public instruction{ public: typedef cmp_pred_t pred_t; + private: std::string repr_impl() const; protected: - cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next); + cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, + value *lhs, value *rhs, const std::string &name, instruction *next); static bool is_fp_predicate(cmp_pred_t pred); static bool is_int_predicate(cmp_pred_t pred); static type* make_cmp_result_type(type *ty); @@ -190,7 +182,8 @@ private: }; class icmp_inst: public cmp_inst { - using cmp_inst::cmp_inst; + icmp_inst(type *ty, cmp_pred_t pred, + value *lhs, value *rhs, const std::string &name, instruction *next); public: static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, @@ -199,7 +192,8 @@ public: }; class fcmp_inst: public cmp_inst { - using cmp_inst::cmp_inst; + fcmp_inst(type *ty, cmp_pred_t pred, + value *lhs, value *rhs, const std::string &name, instruction *next); public: static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, @@ -213,7 +207,7 @@ public: class unary_inst: public instruction { protected: - unary_inst(type *Ty, value *v, const std::string &name, instruction *next); + unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next); }; @@ -226,8 +220,8 @@ private: std::string repr_impl() const; protected: - cast_inst(type *ty, value *v, const std::string &name, instruction *next, cast_op_t op) - : unary_inst(ty, v, name, next), op_(op) { } + cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op) + : unary_inst(ty, id, v, name, next), op_(op) { } private: static bool is_valid(cast_op_t op, value *arg, type *ty); @@ -246,27 +240,27 @@ private: cast_op_t op_; }; -#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \ +#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \ class name : public cast_inst { \ _TRITON_DEFINE_CLONE(name); \ friend class cast_inst; \ name(type *ty, value *v, const std::string &name, instruction *next) \ - : cast_inst(ty, v, name, next, op){ } \ + : cast_inst(ty, id, v, name, next, op){ } \ }; -TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, cast_op_t::Trunc) -TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, cast_op_t::ZExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, cast_op_t::SExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, cast_op_t::FPTrunc) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, cast_op_t::FPExt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, cast_op_t::UIToFP) -TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, cast_op_t::SIToFP) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, cast_op_t::FPToUI) -TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, cast_op_t::FPToSI) -TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, cast_op_t::PtrToInt) -TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, cast_op_t::IntToPtr) -TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, cast_op_t::BitCast) -TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, cast_op_t::AddrSpaceCast) +TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc) +TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP) +TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI) +TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI) +TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt) +TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr) +TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast) +TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast) //===----------------------------------------------------------------------===// // terminator_inst classes @@ -372,33 +366,38 @@ private: class io_inst: public instruction { protected: - io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr); + io_inst(type *ty, value_id_t id, unsigned num_ops, + const std::string &name = "", instruction *next = nullptr); public: // accessors value *get_pointer_operand() { return get_operand(0); } - -// value *get_mask() const; -// value *get_false_value() const; }; +// load class load_inst: public io_inst { protected: - load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next); + load_inst(value *ptr, value_id_t id, unsigned num_ops, + const std::string &name = "", instruction *next = nullptr); private: - std::string repr_impl() const { return "load"; } static type *get_pointee_type(type *ty); - -public: - - // factory method - static load_inst* create(value *ptr, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(load_inst) }; +// unmasked load +class unmasked_load_inst: public load_inst { +private: + std::string repr_impl() const { return "unmasked_load"; } + unmasked_load_inst(value *ptr, const std::string &name, instruction *next); + +public: + static unmasked_load_inst* create(value *ptr, + const std::string &name = "", + instruction *next = nullptr); + _TRITON_DEFINE_CLONE(unmasked_load_inst) +}; + +// masked load class masked_load_inst: public load_inst { private: std::string repr_impl() const { return "masked_load"; } @@ -416,22 +415,28 @@ public: _TRITON_DEFINE_CLONE(masked_load_inst) }; -class store_inst: public io_inst{ +// store +class store_inst: public io_inst { protected: - store_inst(value *ptr, value *v, unsigned num_extra_ops, - const std::string &name, instruction *next); - -private: - std::string repr_impl() const { return "store"; } + store_inst(value *ptr, value_id_t id, unsigned num_ops, + const std::string &name = "", instruction *next = nullptr); public: - // accessors value *get_value_operand() { return get_operand(1); } +}; + +// unmasked_store +class unmasked_store_inst: public store_inst{ +private: + std::string repr_impl() const { return "unmasked_store"; } + unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next); + +public: // factory method - static store_inst* create(value* ptr, value *v, - const std::string &name = "", - instruction *next = nullptr); - _TRITON_DEFINE_CLONE(store_inst) + static unmasked_store_inst* create(value* ptr, value *v, + const std::string &name = "", + instruction *next = nullptr); + _TRITON_DEFINE_CLONE(unmasked_store_inst) }; class masked_store_inst: public store_inst{ @@ -458,7 +463,7 @@ public: class retile_inst: public unary_inst { protected: - retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next); + retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next); }; // reshape @@ -690,16 +695,6 @@ public: instruction *next = nullptr); }; -class vectorize_inst: public unary_inst{ -private: - using unary_inst::unary_inst; - std::string repr_impl() const { return "vectorize"; } - _TRITON_DEFINE_CLONE(vectorize_inst) - -public: - static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); -}; - // On NVIDIA, implementation is such that // constant_range = nv_dynamic_program_idx + nv_static_program_idx // so as to enable re-association on nv_static_program_idx which is constant diff --git a/include/triton/ir/cfg.h b/include/triton/ir/utils.h similarity index 50% rename from include/triton/ir/cfg.h rename to include/triton/ir/utils.h index a61ff6dee..3b9e2f5f3 100644 --- a/include/triton/ir/cfg.h +++ b/include/triton/ir/utils.h @@ -4,18 +4,25 @@ #define _TRITON_IR_CFG_H_ #include +#include namespace triton{ namespace ir{ +class module; class function; class basic_block; +class instruction; +class value; class cfg { public: static std::vector reverse_post_order(function* fn); }; +void for_each_instruction(ir::module& mod, const std::function &fn); +void for_each_value(ir::module& mod, const std::function &fn); + } } diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 9d04cad78..0eaa9a33d 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -20,7 +20,6 @@ #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" -#include "triton/codegen/transform/vectorize.h" #include "triton/lang/parser.h" #include "triton/runtime/arg.h" diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index fda2f6e32..f84e8d692 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -1,4 +1,5 @@ #include "triton/codegen/analysis/align.h" +#include "triton/ir/utils.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -14,15 +15,19 @@ namespace analysis{ inline int gcd(int a, int b) { - if (a == 0) - return b; - if (b == 0) - return a; - if (a == b) - return a; - if (a > b) - return gcd(a-b, b); - return gcd(a, b-a); + if (a == 0) + return b; + if (b == 0) + return a; + if (a == b) + return a; + if (a > b) + return gcd(a - b, b); + return gcd(a, b - a); +} + +inline int lcm(int a, int b) { + return (a * b) / gcd(a, b); } template @@ -64,8 +69,8 @@ std::vector align::populate_is_constant_phi(ir::phi_node* x) { std::vector align::populate_is_constant_splat(ir::splat_inst* x) { auto shapes = get_shapes(x); - std::vector result; ir::value* op = x->get_operand(0); + std::vector result; auto op_cst = populate_is_constant(op); for(auto d: shapes) result.push_back(cst_info{d, op_cst[0].value}); @@ -478,28 +483,15 @@ std::vector align::contiguous(ir::value* v) const { return max_contiguous_.at(v); } + +void align::populate(ir::value *v) { + populate_is_constant(v); + populate_starting_multiple(v); + populate_max_contiguous(v); +} + void align::run(ir::module &mod) { - - // populate constant - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - populate_is_constant(i); - } - - // populate starting multiple - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - populate_starting_multiple(i); - } - - // populate maximum contiguous - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - populate_max_contiguous(i); - } + ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); } diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 99fc59234..790c8a36b 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -1,5 +1,6 @@ #include "triton/codegen/analysis/axes.h" #include "triton/ir/instructions.h" +#include "triton/ir/utils.h" #include "triton/ir/type.h" #include "triton/ir/module.h" #include "triton/ir/function.h" @@ -30,91 +31,113 @@ void axes::add_constraint(node_t x, node_t y) { nodes_.insert(y); } -void axes::init_c_graph(ir::instruction *v) { - // Reference shape - ir::type::tile_shapes_t shapes; - if(auto *store = dynamic_cast(v)) - shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); - else if(auto *atom = dynamic_cast(v)) - shapes = atom->get_operand(0)->get_type()->get_tile_shapes(); - else if(dynamic_cast(v)) + +void axes::update_graph_reduce(ir::instruction *i) { + auto* red = static_cast(i); + unsigned axis = red->get_axis(); + ir::value *arg = red->get_operand(0); + auto in_shapes = arg->get_type()->get_tile_shapes(); + unsigned current = 0; + for(unsigned d = 0; d < in_shapes.size(); d++){ + if(d == axis) + continue; + add_constraint({i, current++}, {arg, d}); + } +} + +void axes::update_graph_reshape(ir::instruction *i) { + auto* reshape = static_cast(i); + // operands + ir::value *op = reshape->get_operand(0); + // shapes + auto op_shapes = op->get_type()->get_tile_shapes(); + auto res_shapes = reshape->get_type()->get_tile_shapes(); + // construct edges + unsigned current = 0; + bool is_skewed = false; + for(unsigned d = 0; d < res_shapes.size(); d ++){ + bool same_shape = res_shapes[d] == op_shapes[current]; + // either add edge between axis or just add a node in the graph + if(!is_skewed && same_shape) + add_constraint({i, d}, {op, current++}); + else + add_constraint({i, d}, {i, d}); + // reshaping is skewed + if(res_shapes[d] > 1 && !same_shape) + is_skewed = true; + } +} + +void axes::update_graph_splat(ir::instruction *) { + // argument is scalar so don't make any edge + return; +} + +void axes::update_graph_trans(ir::instruction *i) { + auto *trans = static_cast(i); + ir::value *op = trans->get_operand(0); + auto perm = trans->get_perm(); + // add edge between axis perm[d] and axis d + for(unsigned d = 0; d < perm.size(); d++) + add_constraint({i, perm[d]->get_value()}, {op, d}); +} + +void axes::update_graph_broadcast(ir::instruction *i) { + auto *broadcast = static_cast(i); + auto shapes = broadcast->get_type()->get_tile_shapes(); + ir::value *op = broadcast->get_operand(0); + ir::type *op_ty = op->get_type(); + const auto& op_shapes = op_ty->get_tile_shapes(); + // add edge between non-broadcast axes + for(unsigned d = 0; d < shapes.size(); d ++) + if(op_shapes[d] == shapes[d]) + add_constraint({i, d}, {op, d}); +} + +void axes::update_graph_dot(ir::instruction *i) { + auto *dot = static_cast(i); + auto shapes = dot->get_type()->get_tile_shapes(); + ir::value *A = dot->get_operand(0); + ir::value *B = dot->get_operand(1); + ir::value *D = dot->get_operand(2); + // add edges between result and accumulator + for(unsigned d = 0; d < shapes.size(); d++) + add_constraint({dot, d}, {D, d}); + // add edge for batch dimension + for(unsigned d = 2; d < shapes.size(); d++){ + add_constraint({dot, d}, {A, d}); + add_constraint({dot, d}, {B, d}); + } +} + +void axes::update_graph_elementwise(ir::instruction *i) { + if(i->get_num_operands() == 0) return; - else if(dynamic_cast(v)) - return; - else if(auto *reduce = dynamic_cast(v)) { - unsigned axis = reduce->get_axis(); - ir::value *arg = reduce->get_operand(0); - auto in_shapes = arg->get_type()->get_tile_shapes(); - unsigned current = 0; - for(unsigned i = 0; i < in_shapes.size(); i++){ - if(i == axis) - continue; - add_constraint({reduce, current++}, {arg, i}); - } + ir::value *op = i->get_operand(0); + if(!op->get_type()->is_tile_ty()) return; + auto rank = op->get_type()->get_tile_rank(); + for(unsigned d = 0; d < rank; d++) + for(ir::value* opx: i->ops()) + for(ir::value* opy: i->ops()){ + if(!i->get_type()->is_void_ty()) + add_constraint({i, d}, {opx, d}); + add_constraint({opx, d}, {opy, d}); } - else - shapes = v->get_type()->get_tile_shapes(); - // Reshape - if(dynamic_cast(v)) { - ir::value *op = v->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); - unsigned current = 0; - bool is_skewed = false; - for(unsigned i = 0; i < shapes.size(); i ++){ - if(shapes[i] == 1){ - add_constraint({v, i}, {v, i}); - } - else if(!is_skewed && - shapes[i] == op_shapes[current]) - add_constraint({v, i}, {op, current++}); - else{ - is_skewed = true; - add_constraint({v, i}, {v, i}); - } - } - } - // Splat - else if(dynamic_cast(v)){ - return; - } - // Trans - else if(auto *x = dynamic_cast(v)){ - ir::value *op = v->get_operand(0); - auto perm = x->get_perm(); - for(unsigned i = 0; i < perm.size(); i++) - add_constraint({v, perm[i]->get_value()}, {op, i}); - } - // Broadcast - else if(dynamic_cast(v)){ - ir::value *op = v->get_operand(0); - ir::type *op_ty = op->get_type(); - const auto& op_shapes = op_ty->get_tile_shapes(); - for(unsigned i = 0; i < shapes.size(); i ++){ - if(op_shapes[i] == shapes[i] && v != op) - add_constraint({v, i}, {op, i}); - } - } - // Matrix multiplication - else if(dynamic_cast(v)){ - ir::value *A = v->get_operand(0); - ir::value *B = v->get_operand(1); - ir::value *D = v->get_operand(2); - for(unsigned i = 0; i < shapes.size(); i++) - add_constraint({v, i}, {D, i}); - for(unsigned i = 2; i < shapes.size(); i++){ - add_constraint({v, i}, {A, i}); - add_constraint({v, i}, {B, i}); - } - } - // Element-wise - else if(dynamic_cast(v)) { - for(unsigned i = 0; i < shapes.size(); i ++){ - std::vector ops = v->ops(); - for(ir::value* op: ops) - add_constraint({v, i}, {op, i}); - } +} + + +void axes::update_graph(ir::instruction *i) { + switch (i->get_id()) { + case ir::INST_REDUCE: return update_graph_reduce(i); + case ir::INST_RESHAPE: return update_graph_reshape(i); + case ir::INST_SPLAT: return update_graph_splat(i); + case ir::INST_TRANS: return update_graph_trans(i); + case ir::INST_BROADCAST: return update_graph_broadcast(i); + case ir::INST_DOT: return update_graph_dot(i); + default: return update_graph_elementwise(i); } + return; } void axes::connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id) { @@ -126,12 +149,12 @@ void axes::connected_components(node_t x, std::set &nodes, graph_t &grap } } -unsigned axes::get(ir::value *value, unsigned ax) { +unsigned axes::get_id(ir::value *value, unsigned ax) { unsigned result = groups_.at(value).at(ax); return result; } -bool axes::has(ir::value *value, unsigned ax) { +bool axes::has_id(ir::value *value, unsigned ax) { auto it = groups_.find(value); if(it == groups_.end()) return false; @@ -146,15 +169,9 @@ void axes::run(ir::module &mod) { nodes_.clear(); dependencies_.clear(); groups_.clear(); - // Create graph - for(ir::function *fn: mod.get_function_list()){ - // Build constraints graph - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i : block->get_inst_list()) - if(i->has_tile_result_or_op()) - init_c_graph(i); - } - // Axes + // make graph + ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); }); + // connected components unsigned group_id = 0; while(!nodes_.empty()) connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 0f376b4fc..77b25e0bb 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -4,6 +4,7 @@ #include "triton/codegen/analysis/layout.h" #include "triton/ir/function.h" #include "triton/ir/module.h" +#include "triton/ir/utils.h" namespace triton{ namespace codegen{ @@ -20,8 +21,8 @@ std::set layout::axes_of(ir::value *value) { // create result std::set result; for(size_t d = 0; d < rank; d++){ - if(axes_->has(value, d)) - result.insert(axes_->get(value, d)); + if(axes_->has_id(value, d)) + result.insert(axes_->get_id(value, d)); } return result; } @@ -74,24 +75,23 @@ void layout::connect(ir::value *x, ir::value *y) { } } +void layout::make_graph(ir::instruction *i) { + for(ir::value* opx: i->ops()) + for(ir::value* opy: i->ops()){ + connect(i, opx); + connect(opx, opy); + } +} + // run void layout::run(ir::module &mod) { nodes_.clear(); dependencies_.clear(); groups_.clear(); values_.clear(); - // Create graph - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i : block->get_inst_list()) { - for(ir::value* opx: i->ops()) - for(ir::value* opy: i->ops()){ - connect(i, opx); - connect(opx, opy); - } - - } - // Grids + // make graph + ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); }); + // connected components unsigned group_id = 0; while(!nodes_.empty()){ connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); diff --git a/lib/codegen/analysis/meminfo.cc b/lib/codegen/analysis/meminfo.cc index 314c272c0..be55d6ac7 100644 --- a/lib/codegen/analysis/meminfo.cc +++ b/lib/codegen/analysis/meminfo.cc @@ -82,9 +82,9 @@ void add_copy(ir::value *x, ir::builder &builder) { } void meminfo::run(ir::module &mod) { -// shared_.clear(); -// refs_.clear(); -// double_.clear(); + shared_.clear(); + refs_.clear(); + double_.clear(); // Add shared copies for(ir::function *fn: mod.get_function_list()){ diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index d1b26a6f9..7d4d81376 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -43,19 +43,19 @@ bool tiles::hmma(ir::value *value) { } int tiles::mts(ir::value *value, unsigned ax) { - return mts_.at(axes_->get(value, ax)); + return mts_.at(axes_->get_id(value, ax)); } int tiles::nts(ir::value *value, unsigned ax) { - return nts_.at(axes_->get(value, ax)); + return nts_.at(axes_->get_id(value, ax)); } int tiles::fpw(ir::value *value, unsigned ax) { - return fpw_.at(axes_->get(value, ax)); + return fpw_.at(axes_->get_id(value, ax)); } int tiles::wpt(ir::value *value, unsigned ax) { - return wpt_.at(axes_->get(value, ax)); + return wpt_.at(axes_->get_id(value, ax)); } std::vector tiles::order(ir::value *v) { @@ -92,7 +92,7 @@ void tiles::init_hmma_tile(ir::value *i) { }while(fpw_nm1 != fpw); // store parameters for(unsigned d = 0; d < shapes.size(); d++) - fpw_[axes_->get(i, d)] = fpw[d]; + fpw_[axes_->get_id(i, d)] = fpw[d]; /* warps per tile */ // try to make things as square as possible to maximize data re-use std::vector wpt = {1, 1, 1}; @@ -106,11 +106,11 @@ void tiles::init_hmma_tile(ir::value *i) { }while(wpt_nm1 != wpt); // store parameters for(unsigned d = 0; d < shapes.size(); d++) - wpt_[axes_->get(i, d)] = wpt[d]; + wpt_[axes_->get_id(i, d)] = wpt[d]; /* sanity check */ unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) - effective_num_warps *= wpt_[axes_->get(i, d)]; + effective_num_warps *= wpt_[axes_->get_id(i, d)]; if(num_warps_ != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); } @@ -122,19 +122,19 @@ void tiles::init_scanline_tile(ir::value *i) { unsigned ld = ord[0]; unsigned num_threads = num_warps_*32; unsigned current = num_threads; - nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4); - mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]); - current = current / mts_[axes_->get(i, ld)]; + nts_[axes_->get_id(i, ld)] = clamp(size / num_threads, 1, 4); + mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get_id(i, ld)]); + current = current / mts_[axes_->get_id(i, ld)]; for(size_t d = 1; d < shapes.size(); d++){ ld = ord[d]; - nts_[axes_->get(i, ld)] = 1; - mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]); - current = current / mts_[axes_->get(i, ld)]; + nts_[axes_->get_id(i, ld)] = 1; + mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld]); + current = current / mts_[axes_->get_id(i, ld)]; } /* sanity check */ unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) - effective_num_threads *= mts_[axes_->get(i, d)]; + effective_num_threads *= mts_[axes_->get_id(i, d)]; if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index c6592a59c..d89a4e1c5 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -615,7 +615,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); } - axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; + axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; } } @@ -720,10 +720,10 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre /* axes */ - axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; - axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; + axes_[a_axes_->get_id(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; + axes_[a_axes_->get_id(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; if(is_batched) - axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; + axes_[a_axes_->get_id(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; } @@ -791,7 +791,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { std::vector axes(shapes.size()); for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d] > 1){ - unsigned x = a_axes_->get(v, d); + unsigned x = a_axes_->get_id(v, d); axes[d] = axes_.at(x); } else{ @@ -942,7 +942,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); for(auto& x: partial) { // current element being computed - Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id; + Value *lane = axes_.at(a_axes_->get_id(op, axis)).thread_id; Value *&result = x.second; indices_t write_idx = x.first; write_idx.insert(write_idx.begin() + axis, lane); diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc index 4116bcca7..f63b4b899 100644 --- a/lib/codegen/target.cc +++ b/lib/codegen/target.cc @@ -103,39 +103,9 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi Intrinsic::nvvm_read_ptx_sreg_ctaid_y, Intrinsic::nvvm_read_ptx_sreg_ctaid_z }; -// bool z_order = true; -// if(z_order && ax < 2){ -// static std::array n_cta_ids = { -// Intrinsic::nvvm_read_ptx_sreg_nctaid_x, -// Intrinsic::nvvm_read_ptx_sreg_nctaid_y, -// Intrinsic::nvvm_read_ptx_sreg_nctaid_z -// }; -// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {}); -// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {}); -// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {}); -// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {}); -// // global block ID -// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0)); -// // helper for minimum -// auto Min = [&](Value *x, Value *y){ -// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x); -// }; -// // super-tile size -// Value* sts = Min(builder.getInt32(16), n_cta_id_1); -// // number of CTAs per super-block -// Value *nscta = builder.CreateMul(n_cta_id_0, sts); -// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0); -// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts)); -// if(ax == 0) -// return bid0; -// else -// return bid1; -// } -// else{ - Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]); - Value* cta_id = builder.CreateCall(get_cta_id, {}); - return cta_id; -// } + Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]); + Value* cta_id = builder.CreateCall(get_cta_id, {}); + return cta_id; } Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 873f7a9f5..455f2fb5d 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -2,7 +2,7 @@ #include #include #include "triton/ir/function.h" -#include "triton/ir/cfg.h" +#include "triton/ir/utils.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/module.h" diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc index a1b5880c5..18406b4ab 100644 --- a/lib/codegen/transform/dce.cc +++ b/lib/codegen/transform/dce.cc @@ -1,7 +1,7 @@ #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/module.h" -#include "triton/ir/cfg.h" +#include "triton/ir/utils.h" #include "triton/codegen/transform/dce.h" #include diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index fc6891ea8..b8b029d9a 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -9,7 +9,7 @@ #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" -#include "triton/ir/cfg.h" +#include "triton/ir/utils.h" namespace triton { diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index 8ca89cda2..38e8c79ed 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -6,7 +6,7 @@ #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" -#include "triton/ir/cfg.h" +#include "triton/ir/utils.h" namespace triton { namespace codegen{ diff --git a/lib/codegen/transform/vectorize.cc b/lib/codegen/transform/vectorize.cc deleted file mode 100644 index ef120f903..000000000 --- a/lib/codegen/transform/vectorize.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "triton/codegen/transform/vectorize.h" -#include "triton/codegen/analysis/tiles.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" - -namespace triton { - -namespace codegen{ -namespace transform{ - -void vectorize::run(ir::module &mod) { - ir::builder &builder = mod.get_builder(); - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - if(auto *trans = dynamic_cast(i)){ - ir::value *x = i->get_operand(0); - if(trans->get_perm()[0]->get_value() != 0) - continue; - builder.set_insert_point(i); - ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); - x->replace_all_uses_with(rx); - rx->set_operand(0, x); - } - if(dynamic_cast(i)){ - ir::value *x = i->get_operand(0); - if(params_->nts(x, 0) == 1) - continue; - builder.set_insert_point(i); - ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); - x->replace_all_uses_with(rx); - rx->set_operand(0, x); - } - } -} - -} -} -} diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 0bf85c84f..66c775ac6 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ +// std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 458365a60..00450b547 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -252,11 +252,11 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) //===----------------------------------------------------------------------===// value *builder::create_load(value *ptr, const std::string &name){ - return insert(load_inst::create(ptr, name)); + return insert(unmasked_load_inst::create(ptr, name)); } value *builder::create_store(value *ptr, value *val, const std::string &name){ - return insert(store_inst::create(ptr, val, name)); + return insert(unmasked_store_inst::create(ptr, val, name)); } value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){ @@ -340,10 +340,6 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) { return insert(copy_to_shared_inst::create(arg, name)); } -value *builder::create_vectorize(value *arg, const std::string &name) { - return insert(vectorize_inst::create(arg, name)); -} - value *builder::create_barrier(const std::string &name) { return insert(barrier_inst::create(ctx_, name)); } diff --git a/lib/ir/cfg.cc b/lib/ir/cfg.cc deleted file mode 100644 index 5b19849d4..000000000 --- a/lib/ir/cfg.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include -#include "triton/ir/cfg.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/function.h" - -namespace triton{ -namespace ir{ - -std::vector cfg::reverse_post_order(function* fn) { - std::stack stack; - std::set visited; - std::vector result; - // initialize stack - for(ir::basic_block* block: fn->blocks()) - if(block->get_predecessors().empty()) - stack.push(block); - // DFS - while(!stack.empty()) { - basic_block* current = stack.top(); - stack.pop(); - result.push_back(current); - visited.insert(current); - for(basic_block* succ: current->get_successors()) - if(visited.find(succ) == visited.end()) - stack.push(succ); - } - return std::move(result); -} - -} -} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index acecc08b5..e89367536 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -12,8 +12,9 @@ namespace ir{ // instruction classes //===----------------------------------------------------------------------===// -instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next) - : user(ty, num_ops, name) { +instruction::instruction(type *ty, value_id_t ity, unsigned num_ops, + const std::string &name, instruction *next) + : user(ty, num_ops, name), id_(ity) { if(next){ basic_block *block = next->get_parent(); assert(block && "Next instruction is not in a basic block!"); @@ -35,17 +36,12 @@ bool instruction::has_tile_result_or_op() { return result; } - -// result reference -result_reference::result_reference(instruction *ref, unsigned arg_id, const std::string &name) - : value(ref->get_type(), name), arg_id_(arg_id){ } - //===----------------------------------------------------------------------===// // phi_node classes //===----------------------------------------------------------------------===// phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next) - : instruction(ty, 0, 1, name, next) { + : instruction(ty, INST_PHI, 0, name, next) { blocks_.reserve(num_reserved); } @@ -131,7 +127,7 @@ bool binary_operator::is_int_add_sub() const { binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) - : instruction(ty, 2, 1, name, next), op_(op){ + : instruction(ty, INST_BINOP, 2, name, next), op_(op){ set_operand(0, lhs); set_operand(1, rhs); } @@ -164,6 +160,8 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name // cmp_inst classes //===----------------------------------------------------------------------===// + + // cmp_inst std::string cmp_inst::repr_impl() const { switch (pred_) { @@ -197,8 +195,8 @@ std::string cmp_inst::repr_impl() const { } } -cmp_inst::cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next) - : instruction(ty, 2, 1, name, next), pred_(pred) { +cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next) + : instruction(ty, id, 2, name, next), pred_(pred) { set_operand(0, lhs); set_operand(1, rhs); } @@ -219,7 +217,12 @@ bool cmp_inst::is_int_predicate(cmp_pred_t pred) { return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE; } + // icmp_inst +icmp_inst::icmp_inst(type *ty, cmp_pred_t pred, + value *lhs, value *rhs, const std::string &name, instruction *next) + : cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ } + icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_int_predicate(pred)); type *res_ty = make_cmp_result_type(lhs->get_type()); @@ -227,6 +230,10 @@ icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std: } // fcmp_inst +fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred, + value *lhs, value *rhs, const std::string &name, instruction *next) + : cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ } + fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ assert(is_fp_predicate(pred)); type *res_ty = make_cmp_result_type(lhs->get_type()); @@ -237,8 +244,8 @@ fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std: // unary_inst classes //===----------------------------------------------------------------------===// -unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next) - : instruction(ty, 1, 1, name, next) { +unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next) + : instruction(ty, id, 1, name, next) { set_operand(0, v); } @@ -309,7 +316,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, // return_inst return_inst::return_inst(context &ctx, value *ret_val, instruction *next) - : terminator_inst(type::get_void_ty(ctx), ret_val!=nullptr, 0, "", next){ + : terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ if(ret_val) set_operand(0, ret_val); } @@ -332,13 +339,13 @@ branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block * // uncond_branch_inst uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next) - : branch_inst(type::get_void_ty(dst->get_context()), 1, 0, "", next){ + : branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){ set_operand(0, dst); } // cond_branch_inst cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next) - : branch_inst(type::get_void_ty(if_dst->get_context()), 3, 0, "", next){ + : branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){ assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!"); set_operand(0, if_dst); set_operand(1, else_dst); @@ -351,7 +358,7 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v //===----------------------------------------------------------------------===// getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector &idx, const std::string &name, instruction *next) - : instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next), + : instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next), source_elt_ty(pointee_ty), res_elt_ty(get_indexed_type(pointee_ty, idx)){ // sanity check @@ -414,8 +421,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vectorget_type()), id, num_ops, name, next) { } // load @@ -427,19 +439,21 @@ type *load_inst::get_pointee_type(type *ty) { return pointee_ty; } -load_inst::load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next) - : io_inst(get_pointee_type(ptr->get_type()), 1 + num_extra_ops, 1, name, next) { +// unmasked_load +unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next) + : load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) { set_operand(0, ptr); } -load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) { - return new load_inst(ptr, 0, name, next); +unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) { + return new unmasked_load_inst(ptr, name, next); } // masked load masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, const std::string &name, instruction *next) - : load_inst(ptr, 2, name, next) { + : load_inst(ptr, INST_MASKED_LOAD, 3, name, next) { + set_operand(0, ptr); set_operand(1, mask); set_operand(2, false_value); } @@ -450,23 +464,29 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false } -// store -store_inst::store_inst(value *ptr, value *val, unsigned num_extra_ops, - const std::string &name, instruction *next) - : io_inst(type::get_void_ty(ptr->get_type()->get_context()), 2 + num_extra_ops, 1, name, next) { +store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) + : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next) +{ } + +// unmasked_store +unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, + const std::string &name, instruction *next) + : store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) { set_operand(0, ptr); set_operand(1, val); } -store_inst* store_inst::create(value *ptr, value *val, - const std::string &name, instruction *next) { - return new store_inst(ptr, val, 0, name, next); +unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, + const std::string &name, instruction *next) { + return new unmasked_store_inst(ptr, val, name, next); } // masked store masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, const std::string &name, instruction *next) - : store_inst(ptr, val, 1, name, next) { + : store_inst(ptr, INST_MASKED_STORE, 3, name, next) { + set_operand(0, ptr); + set_operand(1, val); set_operand(2, mask); } @@ -477,15 +497,16 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask // retile_inst classes //===----------------------------------------------------------------------===// -retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes, +retile_inst::retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) - : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { } + : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { } + // reshape instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { - return new reshape_inst(arg, shapes, name, next); + return new reshape_inst(arg, INST_RESHAPE, shapes, name, next); } @@ -493,20 +514,20 @@ instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes, instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { - return new splat_inst(arg, shapes, name, next); + return new splat_inst(arg, INST_SPLAT, shapes, name, next); } // broadcast instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) { - return new broadcast_inst(arg, shapes, name, next); + return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next); } // downcast instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) { - return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next); + return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next); } //===----------------------------------------------------------------------===// @@ -515,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next) - : builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) { + : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) { set_operand(0, A); set_operand(1, B); set_operand(2, C); @@ -578,7 +599,7 @@ std::vector trans_inst::init_perm(ir::type* ty, const std::vector } trans_inst::trans_inst(value *arg, const std::vector& perm, const std::string &name, instruction *next) - : builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) { + : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) { // sanity check perm_ = init_perm(arg->get_type(), perm); //auto size = arg->get_type()->get_tile_shapes().size(); @@ -599,7 +620,7 @@ const std::vector trans_inst::get_perm() const { //===----------------------------------------------------------------------===// sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next) - : builtin_inst(arg->get_type(), 1, 1, name, next){ + : builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){ set_operand(0, arg); } @@ -621,7 +642,7 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) { } reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(get_res_type(arg, axis), 1, 1, name, next), + : builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next), axis_(axis){ set_operand(0, arg); } @@ -636,7 +657,7 @@ instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &n //===----------------------------------------------------------------------===// select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) - : builtin_inst(if_value->get_type(), 3, 1, name, next){ + : builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){ set_operand(0, pred); set_operand(1, if_value); set_operand(2, else_value); @@ -652,7 +673,7 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value // get_program_id get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(ty, 0, 1, name, next), axis_(axis){ + : builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){ } @@ -662,7 +683,7 @@ instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std: // get_num_program get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next) - : builtin_inst(ty, 0, 1, name, next), axis_(axis){ + : builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){ } @@ -674,7 +695,7 @@ instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std // atomic cas atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) - : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) { + : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) { set_operand(0, ptr); set_operand(1, cmp); set_operand(2, val); @@ -687,7 +708,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s // atomic exch atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next) - : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) { + : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) { set_operand(0, ptr); set_operand(1, val); } @@ -699,7 +720,7 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string // atomic add atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next) - : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) { + : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) { set_operand(0, ptr); set_operand(1, val); } @@ -714,18 +735,13 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string & // copy to shared copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name, instruction *next) { - return new copy_to_shared_inst(arg->get_type(), arg, name, next); -} - -// vectorize -vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, instruction *next) { - return new vectorize_inst(arg->get_type(), arg, name, next); + return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next); } // barrier barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next) - : instruction(type::get_void_ty(ctx), 0, 0, name, next) { } + : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { } barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { return new barrier_inst(ctx, name, next); @@ -734,7 +750,7 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru // nv_dynamic_program_idx make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) - : instruction(ty, 0, 1, name, next) { } + : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { return new make_range_dyn(ty, name, next); @@ -757,7 +773,7 @@ make_range_sta* make_range_sta::get(make_range* range) { // make_range make_range::make_range(type *ty, constant_int *first, constant_int *last) - : instruction(ty, 0), first_(first), last_(last){ } + : instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ } make_range *make_range::create(constant_int *first, constant_int *last) { assert(first->get_type()->is_integer_ty()); diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc new file mode 100644 index 000000000..7baf5df14 --- /dev/null +++ b/lib/ir/utils.cc @@ -0,0 +1,54 @@ +#include +#include +#include "triton/ir/utils.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/function.h" +#include "triton/ir/module.h" + +namespace triton{ +namespace ir{ + +std::vector cfg::reverse_post_order(function* fn) { + std::stack stack; + std::set visited; + std::vector result; + // initialize stack + for(ir::basic_block* block: fn->blocks()) + if(block->get_predecessors().empty()) + stack.push(block); + // DFS + while(!stack.empty()) { + basic_block* current = stack.top(); + stack.pop(); + result.push_back(current); + visited.insert(current); + for(basic_block* succ: current->get_successors()) + if(visited.find(succ) == visited.end()) + stack.push(succ); + } + return std::move(result); +} + +void for_each_instruction(module &mod, const std::function &do_work) { + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: cfg::reverse_post_order(fn)) + for(ir::instruction *i: block->get_inst_list()) + do_work(i); +} + +void for_each_value(module &mod, const std::function &do_work) { + std::set seen; + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: cfg::reverse_post_order(fn)) + for(ir::instruction *i: block->get_inst_list()){ + for(ir::value *op: i->ops()){ + if(seen.insert(op).second) + do_work(op); + } + if(seen.insert(i).second) + do_work(i); + } +} + +} +}