diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h new file mode 100644 index 000000000..f625c4193 --- /dev/null +++ b/include/triton/codegen/analysis/axes.h @@ -0,0 +1,49 @@ +#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_ +#define _TRITON_CODEGEN_ANALYSIS_AXES_H_ + +#include +#include +#include +#include + +namespace triton{ + +namespace ir{ + class value; + class module; + class instruction; +} + +namespace codegen{ +namespace analysis{ + +class axes { + typedef std::pair node_t; + typedef std::map > graph_t; + +private: + void add_constraint(node_t x, node_t y); + void init_c_phi(ir::instruction *i); + void init_c_graph(ir::instruction *v); + void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id); + +public: + axes(); + void run(ir::module &mod); + unsigned get(ir::value *value, unsigned ax); + bool has(ir::value *value, unsigned ax); + +private: + // constraints graph + graph_t dependencies_; + std::set nodes_; + // parameter groups + std::map> groups_; +}; + +} +} + +} + +#endif diff --git a/include/triton/codegen/analysis/grid.h b/include/triton/codegen/analysis/grid.h deleted file mode 100644 index 465011b83..000000000 --- a/include/triton/codegen/analysis/grid.h +++ /dev/null @@ -1,89 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_TUNE_H -#define TDL_INCLUDE_IR_CODEGEN_TUNE_H - -#include -#include -#include -#include - -namespace triton{ - -namespace ir{ - class value; - class module; - class instruction; - class function; - class metaparameter; - class constant_int; -} - -namespace codegen{ - -namespace transform{ -class coalesce; -} - -namespace analysis{ - -class grids { - typedef std::pair node_t; - typedef std::map > graph_t; - typedef std::shared_ptr param_ptr_t; - typedef std::map> param_map_t; - -public: - enum fragment_t{ - STRIDED_SCAN, - HMMA_FRAGMENT_C - }; - -private: - void add_constraint(node_t x, node_t y); - void init_c_phi(ir::instruction *i); - void init_c_graph(ir::instruction *v); - fragment_t get_fragmentation_type(node_t x, graph_t &graph); - void connected_components(node_t x, const std::vector& params, const std::vector& maps, std::set &nodes, graph_t &graph, unsigned group_id); - void create_grids(std::vector &grids, - std::map &references, - ir::function *fn); - - -public: - grids(size_t num_warps, transform::coalesce* coalesce); - void run(ir::module &mod); - const std::vector get() const { return grids_; } - fragment_t fragment_of(ir::value *value, unsigned ax); - unsigned group_of(ir::value *value, unsigned ax); - int mts(ir::value *value, unsigned ax); - int nts(ir::value *value, unsigned ax); - int fpw(ir::value *value, unsigned ax); - int wpt(ir::value *value, unsigned ax); - void copy(ir::value *dst, ir::value *src); - -private: - - transform::coalesce* coalesce_; - // number of warps - size_t num_warps_; - // grids - std::vector grids_; - // grid parameters - param_map_t fpw_; - param_map_t wpt_; - param_map_t mts_; - param_map_t nts_; - // constraints graph - graph_t dependencies_; - std::set nodes_; - // fragments - std::map fragments_; - // parameter groups - std::map> groups_; -}; - - -} -} -} - -#endif diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h new file mode 100644 index 000000000..3bc1f2f6a --- /dev/null +++ b/include/triton/codegen/analysis/layout.h @@ -0,0 +1,57 @@ +#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_ +#define _TRITON_CODEGEN_ANALYSIS_GRID_H_ + +#include +#include +#include +#include + +namespace triton{ + +namespace ir{ + class value; + class module; + class instruction; +} + +namespace codegen{ +namespace analysis{ + +class axes; + +class layout { + typedef ir::value* node_t; + typedef std::map > graph_t; + +private: + // connected components + void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned id); + // list the axes of the given value + std::set axes_of(ir::value *value); + +public: + // constructor + layout(analysis::axes *axes); + // run the passes + void run(ir::module &mod); + // get the layout ID of the given value + unsigned id(ir::value *value) const; + // get the values associates with the given ID + const std::vector& values(unsigned id) const; + // get number of groups + size_t get_num_groups() const; + +private: + analysis::axes* axes_; + graph_t dependencies_; + std::set nodes_; + std::map groups_; + std::map> values_; +}; + +} +} + +} + +#endif diff --git a/include/triton/codegen/analysis/memalloc.h b/include/triton/codegen/analysis/memalloc.h index 91cf89123..f50d00b22 100644 --- a/include/triton/codegen/analysis/memalloc.h +++ b/include/triton/codegen/analysis/memalloc.h @@ -15,15 +15,15 @@ namespace ir{ namespace codegen{ namespace analysis{ -class grids; +class tiles; class liveness; class meminfo; class memalloc { public: - memalloc(liveness *live, meminfo *buffer_info, grids *params) - : liveness_(live), buffer_info_(buffer_info), params_(params){ } + memalloc(liveness *live, meminfo *buffer_info, tiles *params) + : liveness_(live), buffer_info_(buffer_info), tiles_(params){ } // utilities unsigned num_bytes(ir::value *x); unsigned is_ld_padded(ir::value* x); @@ -40,7 +40,7 @@ private: // dependences liveness *liveness_; meminfo *buffer_info_; - grids *params_; + tiles *tiles_; }; } diff --git a/include/triton/codegen/analysis/tiles.h b/include/triton/codegen/analysis/tiles.h new file mode 100644 index 000000000..a9387cb5c --- /dev/null +++ b/include/triton/codegen/analysis/tiles.h @@ -0,0 +1,68 @@ +#ifndef _TRITON_CODEGEN_ANALYSIS_TILES_H_ +#define _TRITON_CODEGEN_ANALYSIS_TILES_H_ + +#include +#include +#include +#include + +namespace triton{ + +namespace ir{ + class value; + class module; + class instruction; + class function; + class metaparameter; + class constant_int; +} + +namespace codegen{ + +namespace transform{ +class coalesce; +} + +namespace analysis{ + +class axes; +class layout; + +class tiles { + typedef std::map> param_map_t; +private: + void init_hmma_tile(ir::value *i); + void init_scanline_tile(ir::value *i); + +public: + tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout); + void run(ir::module &mod); + bool hmma(ir::value *value); + int mts(ir::value *value, unsigned ax); + int nts(ir::value *value, unsigned ax); + int fpw(ir::value *value, unsigned ax); + int wpt(ir::value *value, unsigned ax); + const std::map& largest(); + +private: + // dependencies + analysis::layout* layout_; + analysis::axes* axes_; + transform::coalesce* coalesce_; + // number of warps + size_t num_warps_; + // tile properties + std::map hmma_; + std::map largest_; + std::map fpw_; + std::map wpt_; + std::map mts_; + std::map nts_; +}; + + +} +} +} + +#endif diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index b4d2e3344..ba92843a4 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -43,10 +43,12 @@ namespace triton{ namespace codegen{ namespace analysis{ -class grids; +class tiles; class align; class memalloc; class meminfo; +class axes; +class layout; } namespace transform{ @@ -199,8 +201,12 @@ private: public: - selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt, unsigned num_warps) - : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ } + selection(analysis::memalloc *alloc, analysis::tiles *tiles, analysis::meminfo *buffer_info, + analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, + transform::coalesce* reorder, target *tgt, unsigned num_warps) + : alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info), + alignment_(alignment), a_axes_(axes), layouts_(layouts), + reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ } void run(ir::module &src, Module &dst); @@ -208,7 +214,9 @@ private: vmap_t vmap_; tmap_t tmap_; analysis::memalloc *alloc_; - analysis::grids *params_; + analysis::tiles *tiles_; + analysis::axes *a_axes_; + analysis::layout *layouts_; analysis::meminfo *buffer_info_; analysis::align *alignment_; transform::coalesce *reorder_; diff --git a/include/triton/codegen/transform/reassociate.h b/include/triton/codegen/transform/reassociate.h index 318884755..d7e33c9a2 100644 --- a/include/triton/codegen/transform/reassociate.h +++ b/include/triton/codegen/transform/reassociate.h @@ -19,7 +19,7 @@ class getelementptr_inst; namespace codegen{ namespace analysis{ -class grids; +class tiles; class align; } @@ -37,11 +37,10 @@ private: ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map &offsets); public: - reassociate(analysis::align* align, analysis::grids *params); + reassociate(analysis::align* align); void run(ir::module& module); private: - analysis::grids* params_; analysis::align* align_; }; diff --git a/include/triton/codegen/transform/vectorize.h b/include/triton/codegen/transform/vectorize.h index bf08eb46f..0a6571b61 100644 --- a/include/triton/codegen/transform/vectorize.h +++ b/include/triton/codegen/transform/vectorize.h @@ -10,18 +10,18 @@ namespace ir { namespace codegen{ namespace analysis{ - class grids; + class tiles; } namespace transform{ class vectorize { public: - vectorize(analysis::grids *params): params_(params){} + vectorize(analysis::tiles *params): params_(params){} void run(ir::module &mod); private: - analysis::grids *params_; + analysis::tiles *params_; }; } diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index aee2ecc42..4b67d9e94 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -63,6 +63,7 @@ public: unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; const tile_shapes_t& get_tile_shapes() const; + const size_t get_tile_rank() const; unsigned get_tile_num_elements() const; type *get_tile_element_ty() const; unsigned get_pointer_address_space() const; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 96ec35ef7..9d04cad78 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -11,7 +11,7 @@ // codegen #include "triton/codegen/selection.h" #include "triton/codegen/target.h" -#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/meminfo.h" @@ -45,7 +45,7 @@ class translation_unit; namespace codegen{ namespace analysis{ -class grids; +class tiles; } } diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc new file mode 100644 index 000000000..99fc59234 --- /dev/null +++ b/lib/codegen/analysis/axes.cc @@ -0,0 +1,166 @@ +#include "triton/codegen/analysis/axes.h" +#include "triton/ir/instructions.h" +#include "triton/ir/type.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/context_impl.h" +#include "triton/ir/constant.h" +#include "triton/driver/device.h" + + + +namespace triton{ +namespace codegen{ +namespace analysis{ + +axes::axes() {} + +void axes::add_constraint(node_t x, node_t y) { + size_t shape_x = 1; + size_t shape_y = 1; + if(x.first->get_type()->is_tile_ty()) + shape_x = x.first->get_type()->get_tile_shapes()[x.second]; + if(y.first->get_type()->is_tile_ty()) + shape_y = y.first->get_type()->get_tile_shapes()[y.second]; + if(shape_x == 1 && shape_y == 1) + return; + dependencies_[x].insert(y); + dependencies_[y].insert(x); + nodes_.insert(x); + nodes_.insert(y); +} + +void axes::init_c_graph(ir::instruction *v) { + // Reference shape + ir::type::tile_shapes_t shapes; + if(auto *store = dynamic_cast(v)) + shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); + else if(auto *atom = dynamic_cast(v)) + shapes = atom->get_operand(0)->get_type()->get_tile_shapes(); + else if(dynamic_cast(v)) + return; + else if(dynamic_cast(v)) + return; + else if(auto *reduce = dynamic_cast(v)) { + unsigned axis = reduce->get_axis(); + ir::value *arg = reduce->get_operand(0); + auto in_shapes = arg->get_type()->get_tile_shapes(); + unsigned current = 0; + for(unsigned i = 0; i < in_shapes.size(); i++){ + if(i == axis) + continue; + add_constraint({reduce, current++}, {arg, i}); + } + return; + } + else + shapes = v->get_type()->get_tile_shapes(); + // Reshape + if(dynamic_cast(v)) { + ir::value *op = v->get_operand(0); + auto op_shapes = op->get_type()->get_tile_shapes(); + unsigned current = 0; + bool is_skewed = false; + for(unsigned i = 0; i < shapes.size(); i ++){ + if(shapes[i] == 1){ + add_constraint({v, i}, {v, i}); + } + else if(!is_skewed && + shapes[i] == op_shapes[current]) + add_constraint({v, i}, {op, current++}); + else{ + is_skewed = true; + add_constraint({v, i}, {v, i}); + } + } + } + // Splat + else if(dynamic_cast(v)){ + return; + } + // Trans + else if(auto *x = dynamic_cast(v)){ + ir::value *op = v->get_operand(0); + auto perm = x->get_perm(); + for(unsigned i = 0; i < perm.size(); i++) + add_constraint({v, perm[i]->get_value()}, {op, i}); + } + // Broadcast + else if(dynamic_cast(v)){ + ir::value *op = v->get_operand(0); + ir::type *op_ty = op->get_type(); + const auto& op_shapes = op_ty->get_tile_shapes(); + for(unsigned i = 0; i < shapes.size(); i ++){ + if(op_shapes[i] == shapes[i] && v != op) + add_constraint({v, i}, {op, i}); + } + } + // Matrix multiplication + else if(dynamic_cast(v)){ + ir::value *A = v->get_operand(0); + ir::value *B = v->get_operand(1); + ir::value *D = v->get_operand(2); + for(unsigned i = 0; i < shapes.size(); i++) + add_constraint({v, i}, {D, i}); + for(unsigned i = 2; i < shapes.size(); i++){ + add_constraint({v, i}, {A, i}); + add_constraint({v, i}, {B, i}); + } + } + // Element-wise + else if(dynamic_cast(v)) { + for(unsigned i = 0; i < shapes.size(); i ++){ + std::vector ops = v->ops(); + for(ir::value* op: ops) + add_constraint({v, i}, {op, i}); + } + } +} + +void axes::connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id) { + groups_[x.first].insert({x.second, group_id}); + if(nodes.find(x) != nodes.end()){ + nodes.erase(x); + for(const node_t &y: graph[x]) + connected_components(y, nodes, graph, group_id); + } +} + +unsigned axes::get(ir::value *value, unsigned ax) { + unsigned result = groups_.at(value).at(ax); + return result; +} + +bool axes::has(ir::value *value, unsigned ax) { + auto it = groups_.find(value); + if(it == groups_.end()) + return false; + auto iit = it->second.find(ax); + if(iit == it->second.end()) + return false; + return true; +} + + +void axes::run(ir::module &mod) { + nodes_.clear(); + dependencies_.clear(); + groups_.clear(); + // Create graph + for(ir::function *fn: mod.get_function_list()){ + // Build constraints graph + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i : block->get_inst_list()) + if(i->has_tile_result_or_op()) + init_c_graph(i); + } + // Axes + unsigned group_id = 0; + while(!nodes_.empty()) + connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); +} + +} +} + +} diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc deleted file mode 100644 index 1e6de0de4..000000000 --- a/lib/codegen/analysis/grid.cc +++ /dev/null @@ -1,367 +0,0 @@ -#include -#include -#include "triton/codegen/transform/coalesce.h" -#include "triton/codegen/analysis/grid.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/context_impl.h" -#include "triton/ir/constant.h" -#include "triton/driver/device.h" - - - -namespace triton{ -namespace codegen{ -namespace analysis{ - -grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), coalesce_(reorder) -{ } - -bool is_hmma(ir::value *v){ - bool result = false; - if(auto *x = dynamic_cast(v)){ - ir::value *a = x->get_operand(0); - ir::type *a_ty = a->get_type(); - ir::value *b = x->get_operand(1); - ir::type *b_ty = b->get_type(); - // inputs have to be FP16 - result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty(); - // reduction has to be multiple of 4: TODO - } - return result; -} - -void grids::add_constraint(node_t x, node_t y) { - dependencies_[x].insert(y); - dependencies_[y].insert(x); - nodes_.insert(x); - nodes_.insert(y); -} - -void grids::init_c_phi(ir::instruction *v) { - // Phi Nodes: all the incoming value share the result layout - if(auto *phi = dynamic_cast(v)) - for(ir::value *op: phi->ops()) - for(unsigned k = 0; k < phi->get_type()->get_tile_shapes().size(); k++) - if(dependencies_.find({op, k}) != dependencies_.end() - || dependencies_.find({phi, k}) != dependencies_.end()){ - add_constraint({phi, k}, {op, k}); - } -} - -void grids::init_c_graph(ir::instruction *v) { - // Reference shape - ir::type::tile_shapes_t shapes; - if(auto *store = dynamic_cast(v)) - shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); - else if(auto *atom = dynamic_cast(v)) - shapes = atom->get_operand(0)->get_type()->get_tile_shapes(); - else if(dynamic_cast(v)) - return; - else if(dynamic_cast(v)) - return; - else if(auto *reduce = dynamic_cast(v)) { - unsigned axis = reduce->get_axis(); - ir::value *arg = reduce->get_operand(0); - auto in_shapes = arg->get_type()->get_tile_shapes(); - unsigned current = 0; - for(unsigned i = 0; i < in_shapes.size(); i++){ - if(i == axis) - continue; - add_constraint({reduce, current++}, {arg, i}); - } - return; - } - else - shapes = v->get_type()->get_tile_shapes(); - // Reshape - if(dynamic_cast(v)) { - ir::value *op = v->get_operand(0); - auto op_shapes = op->get_type()->get_tile_shapes(); - unsigned current = 0; - bool is_skewed = false; - for(unsigned i = 0; i < shapes.size(); i ++){ - if(shapes[i] == 1){ - add_constraint({v, i}, {v, i}); - } - else if(!is_skewed && - shapes[i] == op_shapes[current]) - add_constraint({v, i}, {op, current++}); - else{ - is_skewed = true; - add_constraint({v, i}, {v, i}); - } - } - } - // Splat - else if(dynamic_cast(v)){ - return; - } - // Trans - else if(auto *x = dynamic_cast(v)){ - ir::value *op = v->get_operand(0); - auto perm = x->get_perm(); - for(unsigned i = 0; i < perm.size(); i++) - add_constraint({v, perm[i]->get_value()}, {op, i}); - } - // Broadcast - else if(dynamic_cast(v)){ - ir::value *op = v->get_operand(0); - ir::type *op_ty = op->get_type(); - const auto& op_shapes = op_ty->get_tile_shapes(); - for(unsigned i = 0; i < shapes.size(); i ++){ - if(op_shapes[i] == shapes[i] && v != op) - add_constraint({v, i}, {op, i}); - } - } - // Matrix multiplication - else if(dynamic_cast(v)){ - ir::value *A = v->get_operand(0); - ir::value *B = v->get_operand(1); - ir::value *D = v->get_operand(2); - for(unsigned i = 0; i < shapes.size(); i++) - add_constraint({v, i}, {D, i}); - for(unsigned i = 2; i < shapes.size(); i++){ - add_constraint({v, i}, {A, i}); - add_constraint({v, i}, {B, i}); - } - } - // Element-wise - else if(dynamic_cast(v)) { - for(unsigned i = 0; i < shapes.size(); i ++){ - std::vector ops = v->ops(); - for(ir::value* op: ops) - add_constraint({v, i}, {op, i}); - } - } -} - -grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){ - std::list work; - std::set seen; - work.push_back(x); - while(!work.empty()){ - node_t current = work.back(); - if(is_hmma(current.first)) - return HMMA_FRAGMENT_C; - work.pop_back(); - seen.insert(current); - for(node_t y: graph[current]){ - if(seen.find(y) == seen.end()) - work.push_back(y); - } - } - return STRIDED_SCAN; -} - -void grids::connected_components(node_t x, const std::vector& ptr_vec, const std::vector& maps, - std::set &nodes, graph_t &graph, unsigned group_id) { - groups_[x.first].insert({x.second, group_id}); - if(nodes.find(x) != nodes.end()){ - nodes.erase(x); - for(unsigned i = 0; i < ptr_vec.size(); i++) - (*maps[i])[x.first][x.second] = ptr_vec[i]; - for(const node_t &y: graph[x]) - connected_components(y, ptr_vec, maps, nodes, graph, group_id); - } -} - -unsigned grids::group_of(ir::value *value, unsigned ax) { - unsigned result = groups_.at(value).at(ax); - return result; -} - -grids::fragment_t grids::fragment_of(ir::value *value, unsigned ax) { - return fragments_.at({value, ax}); -} - - -//TODO: This shouldn't exist! -void grids::copy(ir::value *dst, ir::value *src) { - mts_[dst] = mts_[src]; - nts_[dst] = nts_[src]; - fpw_[dst] = fpw_[src]; - wpt_[dst] = wpt_[src]; - groups_[dst] = groups_[src]; - fragments_[{dst, 0}] = fragments_[{src, 0}]; -} - - -void grids::run(ir::module &mod) { - // Create tiling parameters - for(ir::function *fn: mod.get_function_list()){ - // Build constraints graph - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i : block->get_inst_list()) - if(i->has_tile_result_or_op()) - init_c_graph(i); - // Build phi constraints - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i : block->get_inst_list()) - if(i->has_tile_result_or_op()) - init_c_phi(i); - // Layout parameters - unsigned group_id = 0; - for(auto x: nodes_) - fragments_[x] = get_fragmentation_type(x, dependencies_); - while(!nodes_.empty()) { - node_t node = *nodes_.begin(); - if(fragments_[node] == STRIDED_SCAN) { - param_ptr_t nts(new int(-1)); - param_ptr_t mts(new int(-1)); - connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++); - } - else { - param_ptr_t fpw(new int(-1)); - param_ptr_t wpt(new int(-1)); - connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++); - } - } - } - - for(ir::function *fn: mod.get_function_list()){ - std::map references; - create_grids(grids_, references, fn); - } - - - unsigned num_threads = num_warps_*32; - auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); }; - - for(ir::value *i: grids_){ - if(!i->get_type()->is_tile_ty()) - continue; - auto order = coalesce_->get_order(i); - auto shapes = i->get_type()->get_tile_shapes(); - unsigned size = i->get_type()->get_tile_num_elements(); - /* HMMA parameters*/ - if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){ - unsigned shape_0 = shapes[order[0]]; - unsigned shape_1 = shapes[order[1]]; - /* fragments per warp */ - // try to make things as square as possible to maximize data re-use - std::vector fpw = {1, 1, 1}; - std::vector fpw_nm1; - unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); - do { - fpw_nm1 = fpw; - if(fpw[0]*fpw[1] < num_fragments) - fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); - if(fpw[0]*fpw[1] < num_fragments) - fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); - }while(fpw_nm1 != fpw); - // store parameters - for(unsigned d = 0; d < shapes.size(); d++) - *fpw_[i][d] = fpw[d]; - /* warps per tile */ - // try to make things as square as possible to maximize data re-use - std::vector wpt = {1, 1, 1}; - std::vector wpt_nm1; - do{ - wpt_nm1 = wpt; - if(wpt[0] * wpt[1] * wpt[2] < num_warps_) - wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); - if(wpt[0] * wpt[1] * wpt[2] < num_warps_) - wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); - }while(wpt_nm1 != wpt); - // store parameters - for(unsigned d = 0; d < shapes.size(); d++) - *wpt_[i][d] = wpt[d]; - /* sanity check */ - unsigned effective_num_warps = 1; - for(size_t d = 0; d < shapes.size(); d++) - effective_num_warps *= *wpt_[i][d]; - if(num_warps_ != effective_num_warps) - throw std::runtime_error("cannot create a kernel with this amount of warps"); - } - - /* Scan-line */ - else{ - unsigned ld = order[0]; - unsigned current = num_threads; - *nts_[i][ld] = clamp(size / num_threads, 1, 4); - *mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]); - current = current / *mts_[i][ld]; - for(size_t d = 1; d < shapes.size(); d++){ - ld = order[d]; - *nts_[i][ld] = 1; - *mts_[i][ld] = clamp(current, 1, shapes[ld]); - current = current / *mts_[i][ld]; - } - /* sanity check */ - unsigned effective_num_threads = 1; - for(size_t d = 0; d < shapes.size(); d++) - effective_num_threads *= *mts_[i][d]; - if(num_threads != effective_num_threads) - throw std::runtime_error("cannot create a kernel with this amount of warps"); - } - } - -} - - -void grids::create_grids(std::vector &grids, - std::map &references, - ir::function *fn) { - // get number of dimensions greater than 1 - auto get_tile_gt1_dim = [&](ir::value *v){ - unsigned result = 0; - for(auto shape: v->get_type()->get_tile_shapes()) { - result += (shape > 1)? shape : 0; - } - return result; - }; - // bind references - std::set seen; - std::function bind_references = [&](ir::value *v) - { - // skip - if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) - return; - // recurse - if(auto *user = dynamic_cast(v)) - for(ir::value *op: user->ops()) - bind_references(op); - // bind - const auto& shapes = v->get_type()->get_tile_shapes(); - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] == 1) - continue; - unsigned x = group_of(v, d); - ir::value *&r = references[x]; - if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r)) - r = v; - } - }; - - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()) - bind_references(i); - - // create grid - for(auto &ref: references) - if(std::find(grids.begin(), grids.end(), ref.second) == grids.end()) - grids.push_back(ref.second); -} - -int grids::mts(ir::value *value, unsigned ax) { - return *mts_.at(value).at(ax); -} - -int grids::nts(ir::value *value, unsigned ax) { - return *nts_.at(value).at(ax); -} - -int grids::fpw(ir::value *value, unsigned ax) { - return *fpw_.at(value).at(ax); -} - -int grids::wpt(ir::value *value, unsigned ax) { - return *wpt_.at(value).at(ax); -} - -} -} -} diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc new file mode 100644 index 000000000..a6eade0b2 --- /dev/null +++ b/lib/codegen/analysis/layout.cc @@ -0,0 +1,96 @@ +#include +#include +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/ir/function.h" +#include "triton/ir/module.h" + +namespace triton{ +namespace codegen{ +namespace analysis{ + + +// axes +std::set layout::axes_of(ir::value *value) { + auto ty = value->get_type(); + // rank of value + size_t rank = 0; + if(ty->is_tile_ty()) + rank = ty->get_tile_rank(); + // create result + std::set result; + for(size_t d = 0; d < rank; d++){ + if(axes_->has(value, d)) + result.insert(axes_->get(value, d)); + } + return result; +} + +// connected components +void layout::connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id) { + groups_[x] = group_id; + values_[group_id].push_back(x); + if(nodes.find(x) != nodes.end()){ + nodes.erase(x); + for(const node_t &y: graph[x]) + connected_components(y, nodes, graph, group_id); + } +} + +// constructor +layout::layout(analysis::axes *axes) + : axes_(axes) { } + +// get group id +unsigned layout::id(ir::value *value) const +{ return groups_.at(value); } + +// get values +const std::vector& layout::values(unsigned id) const +{ return values_.at(id); } + +// get number of groups +size_t layout::get_num_groups() const +{ return values_.size(); } + +// run +void layout::run(ir::module &mod) { + nodes_.clear(); + dependencies_.clear(); + groups_.clear(); + values_.clear(); + // Create graph + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i : block->get_inst_list()) { + // skip scalars + if(!i->get_type()->is_tile_ty()) + continue; + // add an edge between i and the operands that share an axis + std::set i_axes = axes_of(i); + nodes_.insert(i); + for(ir::value* op: i->ops()){ + if(!op->get_type()->is_tile_ty()) + continue; + nodes_.insert(op); + std::set op_axes = axes_of(op); + std::set common; + std::set_intersection(i_axes.begin(), i_axes.end(), + op_axes.begin(), op_axes.end(), + std::inserter(common, common.begin())); + if(!common.empty() || !op->get_type()->is_tile_ty()){ + dependencies_[i].insert(op); + dependencies_[op].insert(i); + } + } + } + // Grids + unsigned group_id = 0; + while(!nodes_.empty()){ + connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); + } +} + +} +} +} diff --git a/lib/codegen/analysis/memalloc.cc b/lib/codegen/analysis/memalloc.cc index 866f9c7a4..7f80824e3 100644 --- a/lib/codegen/analysis/memalloc.cc +++ b/lib/codegen/analysis/memalloc.cc @@ -2,7 +2,7 @@ #include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/meminfo.h" -#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/ir/basic_block.h" #include "triton/ir/type.h" #include "triton/ir/value.h" @@ -20,7 +20,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) { } for(ir::user* user: x->get_users()) if(auto dot = dynamic_cast(user)){ - bool is_hmma = params_->fragment_of(user, 0) == grids::HMMA_FRAGMENT_C; + bool is_hmma = tiles_->hmma(user); bool is_op_0 = x == dot->get_operand(0); bool is_op_1 = x == dot->get_operand(1); if(is_hmma && is_op_0){ @@ -56,10 +56,10 @@ unsigned memalloc::num_bytes(ir::value *x) { for(auto x: shapes) num_elements *= x; size_t depth; - if(params_->fragment_of(x, 0) == grids::HMMA_FRAGMENT_C) - depth = params_->wpt(op, axis); + if(tiles_->hmma(x)) + depth = tiles_->wpt(op, axis); else - depth = params_->mts(op, axis); + depth = tiles_->mts(op, axis); return num_elements * num_bytes * depth; } unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc new file mode 100644 index 000000000..7b8505ff7 --- /dev/null +++ b/lib/codegen/analysis/tiles.cc @@ -0,0 +1,176 @@ +#include +#include +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/tiles.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/codegen/transform/coalesce.h" +#include "triton/ir/instructions.h" +#include "triton/ir/type.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/context_impl.h" +#include "triton/ir/constant.h" +#include "triton/driver/device.h" + + + +namespace triton{ +namespace codegen{ +namespace analysis{ + +tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout): + num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout) +{ } + +bool is_hmma(ir::value *v){ + bool result = false; + if(auto *x = dynamic_cast(v)){ + ir::value *a = x->get_operand(0); + ir::type *a_ty = a->get_type(); + ir::value *b = x->get_operand(1); + ir::type *b_ty = b->get_type(); + result = a_ty->get_scalar_ty()->is_half_ty() && + b_ty->get_scalar_ty()->is_half_ty(); + } + return result; +} + + + +bool tiles::hmma(ir::value *value) { + return hmma_.at(layout_->id(value)); +} + +int tiles::mts(ir::value *value, unsigned ax) { + return mts_.at(axes_->get(value, ax)); +} + +int tiles::nts(ir::value *value, unsigned ax) { + return nts_.at(axes_->get(value, ax)); +} + +int tiles::fpw(ir::value *value, unsigned ax) { + return fpw_.at(axes_->get(value, ax)); +} + +int tiles::wpt(ir::value *value, unsigned ax) { + return wpt_.at(axes_->get(value, ax)); +} + +const std::map& tiles::largest() { + return largest_; +} + + +unsigned clamp(unsigned x, unsigned lo, unsigned hi) { + return std::min(std::max(x, lo), hi); +} + + +void tiles::init_hmma_tile(ir::value *i) { + auto order = coalesce_->get_order(i); + auto shapes = i->get_type()->get_tile_shapes(); + unsigned shape_0 = shapes[order[0]]; + unsigned shape_1 = shapes[order[1]]; + /* fragments per warp */ + // try to make things as square as possible to maximize data re-use + std::vector fpw = {1, 1, 1}; + std::vector fpw_nm1; + unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); + do { + fpw_nm1 = fpw; + if(fpw[0]*fpw[1] < num_fragments) + fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); + if(fpw[0]*fpw[1] < num_fragments) + fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); + }while(fpw_nm1 != fpw); + // store parameters + for(unsigned d = 0; d < shapes.size(); d++) + fpw_[axes_->get(i, d)] = fpw[d]; + /* warps per tile */ + // try to make things as square as possible to maximize data re-use + std::vector wpt = {1, 1, 1}; + std::vector wpt_nm1; + do{ + wpt_nm1 = wpt; + if(wpt[0] * wpt[1] * wpt[2] < num_warps_) + wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); + if(wpt[0] * wpt[1] * wpt[2] < num_warps_) + wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); + }while(wpt_nm1 != wpt); + // store parameters + for(unsigned d = 0; d < shapes.size(); d++) + wpt_[axes_->get(i, d)] = wpt[d]; + /* sanity check */ + unsigned effective_num_warps = 1; + for(size_t d = 0; d < shapes.size(); d++) + effective_num_warps *= wpt_[axes_->get(i, d)]; + if(num_warps_ != effective_num_warps) + throw std::runtime_error("cannot create a kernel with this amount of warps"); +} + +void tiles::init_scanline_tile(ir::value *i) { + auto order = coalesce_->get_order(i); + auto shapes = i->get_type()->get_tile_shapes(); + unsigned size = i->get_type()->get_tile_num_elements(); + unsigned ld = order[0]; + unsigned num_threads = num_warps_*32; + unsigned current = num_threads; + nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4); + mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]); + current = current / mts_[axes_->get(i, ld)]; + for(size_t d = 1; d < shapes.size(); d++){ + ld = order[d]; + nts_[axes_->get(i, ld)] = 1; + mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]); + current = current / mts_[axes_->get(i, ld)]; + } + /* sanity check */ + unsigned effective_num_threads = 1; + for(size_t d = 0; d < shapes.size(); d++) + effective_num_threads *= mts_[axes_->get(i, d)]; + if(num_threads != effective_num_threads) + throw std::runtime_error("cannot create a kernel with this amount of warps"); +} + +void tiles::run(ir::module &) { + hmma_.clear(); + largest_.clear(); + size_t num_groups = layout_->get_num_groups(); + // find out which groups require hmma layout + for(size_t i = 0; i < num_groups; i++) { + const auto& values = layout_->values(i); + hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma); + } + // find out which value is the largest in each group +// std::vector axes; + for(size_t i = 0; i < num_groups; i++) { + const auto& values = layout_->values(i); + auto rank = [](ir::value* v) { + ir::type *ty = v->get_type(); + size_t ret = 0; + if(ty->is_tile_ty()) + for(int s: ty->get_tile_shapes()) + ret += s > 1; + return ret; + }; + auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; + largest_[i] = *std::max_element(values.begin(), values.end(), cmp); + } + + // tiling parameters + for(auto x: largest_){ + ir::value *i = x.second; + if(!i->get_type()->is_tile_ty()) + continue; + /* HMMA parameters*/ + if(hmma_[x.first]) + init_hmma_tile(i); + else + init_scanline_tile(i); + } +} + +} +} +} diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index e78228070..62b68e78a 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -1,6 +1,8 @@ #include "triton/codegen/selection.h" #include "triton/codegen/target.h" -#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" @@ -584,8 +586,8 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value std::vector warp_size(dim); std::vector n_warps(dim); for(unsigned i = 0; i < shapes.size(); i++){ - contiguous[i] = params_->nts(v, i); - block_size[i] = params_->mts(v, i); + contiguous[i] = tiles_->nts(v, i); + block_size[i] = tiles_->mts(v, i); } to_warps(block_size, order, n_warps, warp_size); std::vector thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder); @@ -604,7 +606,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); } - axes_[params_->group_of(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; + axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; } } @@ -622,13 +624,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre Value *_16 = builder.getInt32(16); // fragments per warp - unsigned fpw_0 = params_->fpw(v, 0); - unsigned fpw_1 = params_->fpw(v, 1); - unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1; + unsigned fpw_0 = tiles_->fpw(v, 0); + unsigned fpw_1 = tiles_->fpw(v, 1); + unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1; // warps per tile - unsigned wpt_0 = params_->wpt(v, 0); - unsigned wpt_1 = params_->wpt(v, 1); - unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1; + unsigned wpt_0 = tiles_->wpt(v, 0); + unsigned wpt_1 = tiles_->wpt(v, 1); + unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1; // hmma warp tile size unsigned hmma_wts_0 = fpw_0 * 8; unsigned hmma_wts_1 = fpw_1 * 8; @@ -709,18 +711,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre /* axes */ - axes_[params_->group_of(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; - axes_[params_->group_of(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; + axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; + axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; if(is_batched) - axes_[params_->group_of(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; + axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; } void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { - if(params_->fragment_of(v, 0) == analysis::grids::STRIDED_SCAN) - init_strided_scan_axes(v, builder, u_thread_id, u_warp_id); - else + if(tiles_->hmma(v)) init_hmma_axes(v, builder, u_thread_id, u_warp_id); + else + init_strided_scan_axes(v, builder, u_thread_id, u_warp_id); } /* ------------------- @@ -780,7 +782,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { std::vector axes(shapes.size()); for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d] > 1){ - unsigned x = params_->group_of(v, d); + unsigned x = a_axes_->get(v, d); axes[d] = axes_.at(x); } else{ @@ -831,8 +833,8 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size); Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); // create grid - for(ir::value* i: params_->get()) - init_axes(i, builder, u_thread_warp_id, u_warp_id); + for(auto x: tiles_->largest()) + init_axes(x.second, builder, u_thread_warp_id, u_warp_id); // create tile std::set seen; for(ir::basic_block *block: fn->blocks()) @@ -915,7 +917,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); for(auto& x: partial) { // current element being computed - Value *lane = axes_.at(params_->group_of(op, axis)).thread_id; + Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id; Value *&result = x.second; indices_t write_idx = x.first; write_idx.insert(write_idx.begin() + axis, lane); @@ -928,7 +930,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, tgt_->add_barrier(module, builder); builder.CreateStore(result, write_ptr); // build result - unsigned depth = params_->wpt(op, axis); + unsigned depth = tiles_->wpt(op, axis); for(unsigned i = depth/2; i > 0; i >>= 1){ // current indices indices_t current(write_idx.size(), builder.getInt32(0)); @@ -1095,12 +1097,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); - unsigned fpw_0 = params_->fpw(dot, 0); - unsigned fpw_1 = params_->fpw(dot, 1); + unsigned fpw_0 = tiles_->fpw(dot, 0); + unsigned fpw_1 = tiles_->fpw(dot, 1); unsigned wts_0 = fpw_0 * 8; unsigned wts_1 = fpw_1 * 8; - unsigned wpt_0 = params_->wpt(dot, 0); - unsigned wpt_1 = params_->wpt(dot, 1); + unsigned wpt_0 = tiles_->wpt(dot, 0); + unsigned wpt_1 = tiles_->wpt(dot, 1); unsigned stride_rep_i = wpt_0 * wts_0; unsigned stride_rep_j = wpt_1 * wts_1; unsigned num_rep_i = shapes[0] / stride_rep_i; @@ -1241,10 +1243,11 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB if(NK != 1) { shared_tile *TA = (shared_tile*)tmap_.at(A); shared_tile *TB = (shared_tile*)tmap_.at(B); - if(params_->fragment_of(dot, 0) == analysis::grids::STRIDED_SCAN) - lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add); - else + if(tiles_->hmma(dot)) lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK); + else + lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add); + } else { distributed_tile *TA = (distributed_tile*)tmap_.at(A); diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index f059aba88..8ca89cda2 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -2,7 +2,6 @@ #include #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/grid.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -90,11 +89,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, new_rhs = builder.create_splat(old_rhs, shapes); new_value = builder.create_add(new_lhs, new_rhs, op->get_name()); } - if(new_value != old_value){ - params_->copy(new_value, old_value); - params_->copy(new_lhs, old_value); - params_->copy(new_rhs, old_value); - } } } @@ -127,11 +121,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, if(is_cst(rrhs)) new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst); } - if(new_value != old_value){ - params_->copy(new_value, old_value); - params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value); - params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value); - } } // extract constant and non-constant @@ -156,8 +145,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, return new_value; } -reassociate::reassociate(analysis::align *align, analysis::grids* params) - : params_(params), align_(align) +reassociate::reassociate(analysis::align *align): align_(align) { } @@ -183,9 +171,6 @@ void reassociate::run(ir::module &mod) { ir::value* static_range = ir::make_range_sta::get(old_range); ir::value* new_range = builder.create_add(dyn_range, static_range); old_range->replace_all_uses_with(new_range); - params_->copy(dyn_range, old_range); - params_->copy(static_range, old_range); - params_->copy(new_range, old_range); } } @@ -214,9 +199,6 @@ void reassociate::run(ir::module &mod) { ir::value* ndyn = builder.create_broadcast(dyn, shapes); ir::value* broadcast = builder.create_broadcast(cst, shapes); ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast}); - params_->copy(ndyn, rt); - params_->copy(nsta, rt); - params_->copy(broadcast, rt); infos[rt] = cst_info{ndyn, nsta}; } } @@ -236,8 +218,6 @@ void reassociate::run(ir::module &mod) { builder.set_insert_point(pz); ir::value *dyn_ptr = builder.create_gep(py, {dyn}); ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta}); - params_->copy(dyn_ptr, pz); - params_->copy(sta_ptr, pz); pz->replace_all_uses_with(sta_ptr); infos[sta_ptr].dyn_ptr = dyn_ptr; infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr; @@ -252,8 +232,6 @@ void reassociate::run(ir::module &mod) { ir::value *off = *pz->idx_begin(); ir::value *pz_dyn = builder.create_gep(dyn, {off}); ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name()); - params_->copy(pz_dyn, pz); - params_->copy(pz_sta, pz); pz->replace_all_uses_with(pz_sta); infos[pz_sta].dyn_ptr = pz_dyn; infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta; @@ -298,12 +276,6 @@ void reassociate::run(ir::module &mod) { ir::value *neg_off = builder.create_neg(off); ir::value *pz_dyn = builder.create_gep(pz, {neg_off}); phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z)); - // copy parameters - params_->copy(pz_dyn, pz); - params_->copy(((ir::instruction*)neg_off)->get_operand(0), off); - params_->copy(neg_off, off); - params_->copy(phi_dyn, phi); - params_->copy(phi_sta, phi); infos[phi_sta].dyn_ptr = phi_dyn; infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta; replaced.insert(phi); diff --git a/lib/codegen/transform/vectorize.cc b/lib/codegen/transform/vectorize.cc index 4d1b88541..ef120f903 100644 --- a/lib/codegen/transform/vectorize.cc +++ b/lib/codegen/transform/vectorize.cc @@ -1,5 +1,5 @@ #include "triton/codegen/transform/vectorize.h" -#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -23,7 +23,6 @@ void vectorize::run(ir::module &mod) { ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); x->replace_all_uses_with(rx); rx->set_operand(0, x); - params_->copy(rx, x); } if(dynamic_cast(i)){ ir::value *x = i->get_operand(0); @@ -33,7 +32,6 @@ void vectorize::run(ir::module &mod) { ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); x->replace_all_uses_with(rx); rx->set_operand(0, x); - params_->copy(rx, x); } } } diff --git a/lib/ir/type.cc b/lib/ir/type.cc index aa3d9aa46..198553b52 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -73,6 +73,10 @@ const type::tile_shapes_t &type::get_tile_shapes() const { return ((tile_type*)this)->get_shapes(); } +const size_t type::get_tile_rank() const { + return get_tile_shapes().size(); +} + unsigned type::get_tile_num_elements() const { const tile_shapes_t& shapes = get_tile_shapes(); unsigned result = 1; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 9c3f99869..d29e9d7ce 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -3,6 +3,9 @@ #include #include #include +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/selection.h" #include "triton/runtime/function.h" #include "triton/codegen/transform/coalesce.h" @@ -192,49 +195,54 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr std::unique_ptr function::make_bin(ir::module &module, driver::context *context, const options_t& opt) { std::unique_ptr target = context->device()->make_target(); - + // generate llvm code + llvm::LLVMContext ctx; + std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); // create passes codegen::analysis::meminfo shmem_info; - codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::align alignment_info; + codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::transform::coalesce coalesce(&alignment_info, &shmem_info); - codegen::analysis::grids grids(opt.num_warps, &coalesce); - codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); + codegen::analysis::axes axes; + codegen::analysis::layout layouts(&axes); + codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts); + codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); - codegen::transform::vectorize vectorize(&grids); + codegen::transform::vectorize vectorize(&tiles); codegen::transform::dce dce; codegen::transform::peephole peephole; - codegen::transform::reassociate reassociate(&alignment_info, &grids); - codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &coalesce, target.get(), opt.num_warps); + codegen::transform::reassociate reassociate(&alignment_info); + codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps); // run passes peephole.run(module); dce.run(module); alignment_info.run(module); - if(target->is_gpu()) - shmem_info.run(module); + shmem_info.run(module); coalesce.run(module); dce.run(module); - grids.run(module); + axes.run(module); + layouts.run(module); + tiles.run(module); alignment_info.run(module); reassociate.run(module); dce.run(module); peephole.run(module); - if(target->is_gpu()){ - shmem_info.run(module); - shmem_liveness.run(module); - shmem_allocation.run(); - if(shmem_allocation.allocated_size() > context->device()->max_shared_memory()) - return std::unique_ptr(); - shmem_barriers.run(module); - } + shmem_info.run(module); + shmem_liveness.run(module); + shmem_allocation.run(); + if(shmem_allocation.allocated_size() > context->device()->max_shared_memory()) + return std::unique_ptr(); + shmem_barriers.run(module); dce.run(module); vectorize.run(module); dce.run(module); alignment_info.run(module); + coalesce.run(module); + dce.run(module); // ir::print(module, std::cout); - // generate llvm code - llvm::LLVMContext ctx; - std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); + axes.run(module); + layouts.run(module); + tiles.run(module); selection.run(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 7f2366ecc..9a59f3ea7 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -45,10 +45,10 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i opt.defines.push_back({"TYPE", {ty}}); opt.defines.push_back({"AT", {AT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}}); - opt.defines.push_back({"TM", {"64", "128"}}); - opt.defines.push_back({"TN", {"64", "128"}}); + opt.defines.push_back({"TM", {"128"}}); + opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TK", {"8"}}); - opt.num_warps = {2, 4, 8}; + opt.num_warps = {4}; // create function rt::function function(src::dot, opt); // benchmark available libraries