diff --git a/include/triton/codegen/analysis/grid.h b/include/triton/codegen/analysis/grid.h index 467ba9fff..50a8c578a 100644 --- a/include/triton/codegen/analysis/grid.h +++ b/include/triton/codegen/analysis/grid.h @@ -50,16 +50,15 @@ private: public: grids(size_t num_warps, transform::coalesce* reorder); - unsigned get_param_group(ir::value *value, unsigned ax); fragment_t get_fragment(ir::value *value, unsigned ax); void copy(ir::value *dst, ir::value *src); void run(ir::module &mod); - unsigned get_num_threads(); + unsigned get_param_group(ir::value *value, unsigned ax); const std::vector get_grids() const { return grids_; } - int get_mts(ir::value *value, unsigned ax); - int get_nts(ir::value *value, unsigned ax); - int get_fpw(ir::value *value, unsigned ax); - int get_wpt(ir::value *value, unsigned ax); + int mts(ir::value *value, unsigned ax); + int nts(ir::value *value, unsigned ax); + int fpw(ir::value *value, unsigned ax); + int wpt(ir::value *value, unsigned ax); private: diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 3efe0a256..b4d2e3344 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -157,7 +157,11 @@ private: void create_grids(std::vector &grids, std::map &references, ir::function *fn); + void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr); + void create_distributed_tile(ir::value *v, Builder &builder); void create_tile(ir::value *v, Builder &builder, std::set &seen, Value *sh_mem_ptr); + void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); + void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id); void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr); @@ -195,8 +199,8 @@ private: public: - selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt) - : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt){ } + selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt, unsigned num_warps) + : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ } void run(ir::module &src, Module &dst); @@ -215,6 +219,7 @@ private: Value *offset_b_j_, *offset_b_k_; unsigned num_packs_0_, num_packs_1_; unsigned pack_size_0_, pack_size_1_; + unsigned num_warps_; }; } diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc index 2d9b494c4..cf5a718cd 100644 --- a/lib/codegen/analysis/grid.cc +++ b/lib/codegen/analysis/grid.cc @@ -156,8 +156,8 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){ return STRIDED_SCAN; } -void grids::connected_components(node_t x, const std::vector& ptr_vec, const std::vector& maps, std::set &nodes, graph_t &graph, unsigned group_id) -{ +void grids::connected_components(node_t x, const std::vector& ptr_vec, const std::vector& maps, + std::set &nodes, graph_t &graph, unsigned group_id) { groups_[x.first].insert({x.second, group_id}); if(nodes.find(x) != nodes.end()){ nodes.erase(x); @@ -190,22 +190,18 @@ void grids::copy(ir::value *dst, ir::value *src) { void grids::run(ir::module &mod) { - ir::context &ctx = mod.get_context(); - // Create metaparameters + // Create tiling parameters for(ir::function *fn: mod.get_function_list()){ - // Build constraints graph for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) if(i->has_tile_result_or_op()) init_c_graph(i); - // Build phi constraints for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) if(i->has_tile_result_or_op()) init_c_phi(i); - // Layout parameters unsigned group_id = 0; for(auto x: nodes_) @@ -231,7 +227,7 @@ void grids::run(ir::module &mod) { } - unsigned num_threads = get_num_threads(); + unsigned num_threads = num_warps_*32; auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); }; for(ir::value *i: grids_){ @@ -242,10 +238,8 @@ void grids::run(ir::module &mod) { unsigned size = i->get_type()->get_tile_num_elements(); /* HMMA parameters*/ if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){ - unsigned shape_0 = shapes[order[0]]; unsigned shape_1 = shapes[order[1]]; - /* fragments per warp */ // try to make things as square as possible to maximize data re-use std::vector fpw = {1, 1, 1}; @@ -261,7 +255,6 @@ void grids::run(ir::module &mod) { // store parameters for(unsigned d = 0; d < shapes.size(); d++) *fpw_[i][d] = fpw[d]; - /* warps per tile */ // try to make things as square as possible to maximize data re-use std::vector wpt = {1, 1, 1}; @@ -276,15 +269,12 @@ void grids::run(ir::module &mod) { // store parameters for(unsigned d = 0; d < shapes.size(); d++) *wpt_[i][d] = wpt[d]; - /* sanity check */ unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_warps *= *wpt_[i][d]; - if(num_warps_ != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); - } /* Scan-line */ @@ -356,24 +346,19 @@ void grids::create_grids(std::vector &grids, grids.push_back(ref.second); } - -unsigned grids::get_num_threads() { - return num_warps_*32; -} - -int grids::get_mts(ir::value *value, unsigned ax) { +int grids::mts(ir::value *value, unsigned ax) { return *mts_.at(value).at(ax); } -int grids::get_nts(ir::value *value, unsigned ax) { +int grids::nts(ir::value *value, unsigned ax) { return *nts_.at(value).at(ax); } -int grids::get_fpw(ir::value *value, unsigned ax) { +int grids::fpw(ir::value *value, unsigned ax) { return *fpw_.at(value).at(ax); } -int grids::get_wpt(ir::value *value, unsigned ax) { +int grids::wpt(ir::value *value, unsigned ax) { return *wpt_.at(value).at(ax); } diff --git a/lib/codegen/analysis/memalloc.cc b/lib/codegen/analysis/memalloc.cc index be81b68e2..631b8f663 100644 --- a/lib/codegen/analysis/memalloc.cc +++ b/lib/codegen/analysis/memalloc.cc @@ -57,9 +57,9 @@ unsigned memalloc::get_num_bytes(ir::value *x) { num_elements *= x; size_t depth; if(params_->get_fragment(x, 0) == grids::HMMA_FRAGMENT_C) - depth = params_->get_wpt(op, axis); + depth = params_->wpt(op, axis); else - depth = params_->get_mts(op, axis); + depth = params_->mts(op, axis); return num_elements * num_bytes * depth; } unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 8b6588386..4cff99890 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -571,145 +571,154 @@ inline void to_warps(const std::vector &bs, const std::vector &builder, Value *u_thread_id, Value *u_warp_id) { +void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { auto order = reorder_->get_order(v); const auto& shapes = v->get_type()->get_tile_shapes(); size_t dim = shapes.size(); - if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN){ - std::vector contiguous(dim); - std::vector block_size(dim); - std::vector warp_size(dim); - std::vector n_warps(dim); - for(unsigned i = 0; i < shapes.size(); i++){ - contiguous[i] = params_->get_nts(v, i); - block_size[i] = params_->get_mts(v, i); - } - to_warps(block_size, order, n_warps, warp_size); - std::vector thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder); - std::vector warp_id = delinearize(u_warp_id, order, n_warps, builder); - // Create axes - for(unsigned k = 0; k < dim; k++) { - std::string str_k = std::to_string(k); - Value *warp_size_k = builder.getInt32(warp_size[k]); - Value *contiguous_k = builder.getInt32(contiguous[k]); - Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k)); - Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k); - unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k]; - unsigned per_thread = contiguous[k] * shapes[k] / per_block; - std::vector idx_list(per_thread); - for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; - idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); - } - axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; - } + std::vector contiguous(dim); + std::vector block_size(dim); + std::vector warp_size(dim); + std::vector n_warps(dim); + for(unsigned i = 0; i < shapes.size(); i++){ + contiguous[i] = params_->nts(v, i); + block_size[i] = params_->mts(v, i); } - else { - if(shapes.size() > 3) - throw std::runtime_error("unsupported"); - bool is_batched = shapes.size() >= 3; - - Value *_1 = builder.getInt32(1); - Value *_2 = builder.getInt32(2); - Value *_3 = builder.getInt32(3); - Value *_4 = builder.getInt32(4); - Value *_16 = builder.getInt32(16); - - // fragments per warp - unsigned fpw_0 = params_->get_fpw(v, 0); - unsigned fpw_1 = params_->get_fpw(v, 1); - unsigned fpw_2 = is_batched ? params_->get_fpw(v, 2) : 1; - // warps per tile - unsigned wpt_0 = params_->get_wpt(v, 0); - unsigned wpt_1 = params_->get_wpt(v, 1); - unsigned wpt_2 = is_batched ? params_->get_wpt(v, 2) : 1; - // hmma warp tile size - unsigned hmma_wts_0 = fpw_0 * 8; - unsigned hmma_wts_1 = fpw_1 * 8; - unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; - // hmma block tile size - unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; - unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; - unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; - // number of repetition - unsigned num_rep_0 = shapes[0] / hmma_bts_0; - unsigned num_rep_1 = shapes[1] / hmma_bts_1; - unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; - // size of each pack (interleaving) - pack_size_0_ = std::min(num_rep_0, 1); - pack_size_1_ = std::min(num_rep_1, 1); - // number of packs (interleaving) - num_packs_0_ = num_rep_0 / pack_size_0_; - num_packs_1_ = num_rep_1 / pack_size_1_; - - /* intra warp offset */ - // offset of quad in pair - Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), - builder.getInt32(fpw_0 * pack_size_0_)); - Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), - builder.getInt32(fpw_1 * pack_size_1_)); - - // Quad pair id - Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); - Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); - pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0)); - pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)); - pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1)); - // Quad pair offset - Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_)); - Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_)); - - /* inter warp offset */ - Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0)); - Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0)); - Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1)); - Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1)); - Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_)); - Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_)); - - /* offsets */ - // a offset - offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a)); - offset_a_k_ = builder.CreateAnd(u_thread_id, _3); - // b offsets - offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b)); - offset_b_k_ = builder.CreateAnd(u_thread_id, _3); - - // c offsets - Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_); - Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2), - builder.CreateAdd(warp_offset_j, pair_b_off)); - - /* indices */ - // i indices - std::vector idx_i; - for(unsigned pack = 0; pack < num_packs_0_; pack++) - for(unsigned ii = 0; ii < pack_size_0_; ii++) - for(unsigned i = 0; i < 2; i++){ - idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); + to_warps(block_size, order, n_warps, warp_size); + std::vector thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder); + std::vector warp_id = delinearize(u_warp_id, order, n_warps, builder); + // Create axes + for(unsigned k = 0; k < dim; k++) { + std::string str_k = std::to_string(k); + Value *warp_size_k = builder.getInt32(warp_size[k]); + Value *contiguous_k = builder.getInt32(contiguous[k]); + Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k)); + Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k); + unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k]; + unsigned per_thread = contiguous[k] * shapes[k] / per_block; + std::vector idx_list(per_thread); + for(unsigned n = 0 ; n < per_thread; n++){ + unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; + idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); } - // j indices - std::vector idx_j; - for(unsigned pack = 0; pack < num_packs_1_; pack++) - for(unsigned jj = 0; jj < pack_size_1_; jj++) - for(unsigned j = 0; j < 2; j++){ - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); - } - // z indices - std::vector idx_z; - for(unsigned pack = 0; pack < num_rep_2; pack++) - idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2))); - - - /* axes */ - axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; - axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; - if(is_batched) - axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; + axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id}; } } +void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { +// auto order = reorder_->get_order(v); + const auto& shapes = v->get_type()->get_tile_shapes(); + if(shapes.size() > 3) + throw std::runtime_error("unsupported"); + bool is_batched = shapes.size() >= 3; + + Value *_1 = builder.getInt32(1); + Value *_2 = builder.getInt32(2); + Value *_3 = builder.getInt32(3); + Value *_4 = builder.getInt32(4); + Value *_16 = builder.getInt32(16); + + // fragments per warp + unsigned fpw_0 = params_->fpw(v, 0); + unsigned fpw_1 = params_->fpw(v, 1); + unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1; + // warps per tile + unsigned wpt_0 = params_->wpt(v, 0); + unsigned wpt_1 = params_->wpt(v, 1); + unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1; + // hmma warp tile size + unsigned hmma_wts_0 = fpw_0 * 8; + unsigned hmma_wts_1 = fpw_1 * 8; + unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; + // hmma block tile size + unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; + unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; + unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; + // number of repetition + unsigned num_rep_0 = shapes[0] / hmma_bts_0; + unsigned num_rep_1 = shapes[1] / hmma_bts_1; + unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; + // size of each pack (interleaving) + pack_size_0_ = std::min(num_rep_0, 1); + pack_size_1_ = std::min(num_rep_1, 1); + // number of packs (interleaving) + num_packs_0_ = num_rep_0 / pack_size_0_; + num_packs_1_ = num_rep_1 / pack_size_1_; + + /* intra warp offset */ + // offset of quad in pair + Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), + builder.getInt32(fpw_0 * pack_size_0_)); + Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), + builder.getInt32(fpw_1 * pack_size_1_)); + + // Quad pair id + Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); + Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); + pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0)); + pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)); + pair_b_id = builder.CreateURem(pair_b_id, builder.getInt32(fpw_1)); + // Quad pair offset + Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_)); + Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_)); + + /* inter warp offset */ + Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0)); + Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0)); + Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1)); + Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1)); + Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_)); + Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_)); + + /* offsets */ + // a offset + offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a)); + offset_a_k_ = builder.CreateAnd(u_thread_id, _3); + // b offsets + offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b)); + offset_b_k_ = builder.CreateAnd(u_thread_id, _3); + + // c offsets + Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_); + Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2), + builder.CreateAdd(warp_offset_j, pair_b_off)); + + /* indices */ + // i indices + std::vector idx_i; + for(unsigned pack = 0; pack < num_packs_0_; pack++) + for(unsigned ii = 0; ii < pack_size_0_; ii++) + for(unsigned i = 0; i < 2; i++){ + idx_i.push_back(builder.CreateAdd(offset_c_i, builder.getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); + } + // j indices + std::vector idx_j; + for(unsigned pack = 0; pack < num_packs_1_; pack++) + for(unsigned jj = 0; jj < pack_size_1_; jj++) + for(unsigned j = 0; j < 2; j++){ + idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); + idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); + } + // z indices + std::vector idx_z; + for(unsigned pack = 0; pack < num_rep_2; pack++) + idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2))); + + + /* axes */ + axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0}; + axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1}; + if(is_batched) + axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2}; +} + + +void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) { + if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN) + init_strided_scan_axes(v, builder, u_thread_id, u_warp_id); + else + init_hmma_axes(v, builder, u_thread_id, u_warp_id); +} + bool static inline has_phi_user(ir::value *v) { for(ir::user *usr: v->get_users()){ if(dynamic_cast(usr)) @@ -717,94 +726,97 @@ bool static inline has_phi_user(ir::value *v) { } return false; } + +void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { + auto shapes = v->get_type()->get_tile_shapes(); + unsigned pad = alloc_->is_ld_padded(v); + if(pad > 0) + shapes[0] += pad; + Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); + // shared copy + PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); + // phi-node (double-buffering) + if(auto *phi = dynamic_cast(v)) { + BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()]; + unsigned id_pre = 0, id_loop = 1; + if(phi->get_incoming_block(0) == phi->get_parent()) + std::swap(id_pre, id_loop); + if(parent->empty()) + builder.SetInsertPoint(parent); + else + builder.SetInsertPoint(&*parent->getFirstInsertionPt()); + PHINode *ptr = builder.CreatePHI(ptr_ty, 2); + PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); + // next pointer + Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi))); + pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); + Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); + tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); + for(unsigned i = 0; i < phi->get_num_incoming(); i++) { + ir::basic_block* inc_block = phi->get_incoming_block(i); + ir::value* inc_value = phi->get_incoming_value(i); + ir::instruction* terminator = inc_block->get_inst_list().back(); + bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); + tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)}); + } + } + else { + if(!has_phi_user(v)){ + size_t offset = alloc_->get_offset(v); + Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); + ptr = builder.CreateBitCast(ptr, ptr_ty); + tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); + } + } +} + +void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { + Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); + const auto &shapes = v->get_type()->get_tile_shapes(); + std::vector axes(shapes.size()); + for(size_t d = 0; d < shapes.size(); d++){ + if(shapes[d] > 1){ + unsigned x = params_->get_param_group(v, d); + axes[d] = axes_.at(x); + } + else{ + axes[d].contiguous = 1; + axes[d].values = {builder.getInt32(0)}; + } + } + bool vectorize = dynamic_cast(v); + distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); + bool is_inserted = tmap_.insert({v, T}).second; + // constant range + if(is_inserted && dynamic_cast(v)){ + T->for_each([&](indices_t idx){ + assert(idx.size() == 1); + T->set_value(idx, idx[0]); + }); + } + if(is_inserted && dynamic_cast(v)){ + T->for_each([&](indices_t idx){ + assert(idx.size() == 1); + BinaryOperator *bin_add = dyn_cast(idx[0]); + assert(bin_add); + Value *res = bin_add->getOperand(1); + assert(isa(res)); + T->set_value(idx, res); + }); + } +} + void selection::create_tile(ir::value *v, IRBuilder<> &builder, std::set &seen, Value *sh_mem_ptr) { if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) return; if(auto *user = dynamic_cast(v)) - for(ir::value *op: user->ops()){ + for(ir::value *op: user->ops()) create_tile(op, builder, seen, sh_mem_ptr); - } - LLVMContext &ctx = builder.getContext(); - auto shapes = v->get_type()->get_tile_shapes(); - unsigned pad = alloc_->is_ld_padded(v); - if(pad > 0) - shapes[0] += pad; - Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx); - // create shared tile - if(buffer_info_->is_shared(v) && !dynamic_cast(v)){ - // shared copy - PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); - // phi-node (double-buffering) - if(auto *phi = dynamic_cast(v)) { - BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()]; - unsigned id_pre = 0, id_loop = 1; - if(phi->get_incoming_block(0) == phi->get_parent()) - std::swap(id_pre, id_loop); - if(parent->empty()) - builder.SetInsertPoint(parent); - else - builder.SetInsertPoint(&*parent->getFirstInsertionPt()); - PHINode *ptr = builder.CreatePHI(ptr_ty, 2); - PHINode *offset = builder.CreatePHI(builder.getInt32Ty(), 2); - // next pointer - Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi))); - pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); - Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); - tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); - for(unsigned i = 0; i < phi->get_num_incoming(); i++) { - ir::basic_block* inc_block = phi->get_incoming_block(i); - ir::value* inc_value = phi->get_incoming_value(i); - ir::instruction* terminator = inc_block->get_inst_list().back(); - bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); - tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)}); - } - } - else { - if(!has_phi_user(v)){ - size_t offset = alloc_->get_offset(v); - Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); - ptr = builder.CreateBitCast(ptr, ptr_ty); - tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); - } - } - } - // create distributed tile - else { - const auto &shapes = v->get_type()->get_tile_shapes(); - std::vector axes(shapes.size()); - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] > 1){ - unsigned x = params_->get_param_group(v, d); - axes[d] = axes_.at(x); - } - else{ - axes[d].contiguous = 1; - axes[d].values = {builder.getInt32(0)}; - } - } - bool vectorize = dynamic_cast(v); - distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); - bool is_inserted = tmap_.insert({v, T}).second; - // constant range - if(is_inserted && dynamic_cast(v)){ - T->for_each([&](indices_t idx){ - assert(idx.size() == 1); - T->set_value(idx, idx[0]); - }); - } - if(is_inserted && dynamic_cast(v)){ - T->for_each([&](indices_t idx){ - assert(idx.size() == 1); - BinaryOperator *bin_add = dyn_cast(idx[0]); - assert(bin_add); - Value *res = bin_add->getOperand(1); - assert(isa(res)); - T->set_value(idx, res); - }); - } - - } + if(buffer_info_->is_shared(v) && !dynamic_cast(v)) + create_shared_tile(v, builder, sh_mem_ptr); + else + create_distributed_tile(v, builder); } void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){ @@ -908,7 +920,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, tgt_->add_barrier(module, builder); builder.CreateStore(result, write_ptr); // build result - unsigned depth = params_->get_wpt(op, axis); + unsigned depth = params_->wpt(op, axis); for(unsigned i = depth/2; i > 0; i >>= 1){ // current indices indices_t current(write_idx.size(), builder.getInt32(0)); @@ -1075,12 +1087,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); - unsigned fpw_0 = params_->get_fpw(dot, 0); - unsigned fpw_1 = params_->get_fpw(dot, 1); + unsigned fpw_0 = params_->fpw(dot, 0); + unsigned fpw_1 = params_->fpw(dot, 1); unsigned wts_0 = fpw_0 * 8; unsigned wts_1 = fpw_1 * 8; - unsigned wpt_0 = params_->get_wpt(dot, 0); - unsigned wpt_1 = params_->get_wpt(dot, 1); + unsigned wpt_0 = params_->wpt(dot, 0); + unsigned wpt_1 = params_->wpt(dot, 1); unsigned stride_rep_i = wpt_0 * wts_0; unsigned stride_rep_j = wpt_1 * wts_1; unsigned num_rep_i = shapes[0] / stride_rep_i; @@ -1457,7 +1469,7 @@ void selection::run(ir::module &src, Module &dst) { Metadata *md_args[] = { ValueAsMetadata::get(dst_fn), MDString::get(dst_ctx, "maxntidx"), - ValueAsMetadata::get(dst_builder.getInt32(params_->get_num_threads())) + ValueAsMetadata::get(dst_builder.getInt32(num_warps_*32)) }; dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args)); diff --git a/lib/codegen/transform/vectorize.cc b/lib/codegen/transform/vectorize.cc index e7e329c02..4d1b88541 100644 --- a/lib/codegen/transform/vectorize.cc +++ b/lib/codegen/transform/vectorize.cc @@ -27,7 +27,7 @@ void vectorize::run(ir::module &mod) { } if(dynamic_cast(i)){ ir::value *x = i->get_operand(0); - if(params_->get_nts(x, 0) == 1) + if(params_->nts(x, 0) == 1) continue; builder.set_insert_point(i); ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 3dd7c1507..016d5c879 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -205,7 +205,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::dce dce; codegen::transform::peephole peephole; codegen::transform::reassociate reassociate(&alignment_info, &grids); - codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &reorder, target.get()); + codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &reorder, target.get(), opt.num_warps); // run passes peephole.run(module); dce.run(module);