#include #include #include #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/tiles.h" #include "triton/codegen/analysis/layout.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/context_impl.h" #include "triton/ir/constant.h" #include "triton/driver/device.h" namespace triton{ namespace codegen{ namespace analysis{ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout): num_warps_(num_warps), align_(align), axes_(axes), layout_(layout) { } bool is_hmma(ir::value *v){ bool result = false; if(auto *x = dynamic_cast(v)){ ir::value *a = x->get_operand(0); ir::type *a_ty = a->get_type(); ir::value *b = x->get_operand(1); ir::type *b_ty = b->get_type(); result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty(); } return result; } bool tiles::hmma(ir::value *value) { return hmma_.at(layout_->id(value)); } int tiles::mts(ir::value *value, unsigned ax) { return mts_.at(axes_->get(value, ax)); } int tiles::nts(ir::value *value, unsigned ax) { return nts_.at(axes_->get(value, ax)); } int tiles::fpw(ir::value *value, unsigned ax) { return fpw_.at(axes_->get(value, ax)); } int tiles::wpt(ir::value *value, unsigned ax) { return wpt_.at(axes_->get(value, ax)); } std::vector tiles::order(ir::value *v) { auto ret = order_[layout_->id(v)]; return ret; } const std::map& tiles::largest() { return largest_; } unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); } void tiles::init_hmma_tile(ir::value *i) { auto ord = order(i); auto shapes = i->get_type()->get_tile_shapes(); unsigned shape_0 = shapes[ord[0]]; unsigned shape_1 = shapes[ord[1]]; /* fragments per warp */ // try to make things as square as possible to maximize data re-use std::vector fpw = {1, 1, 1}; std::vector fpw_nm1; unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); do { fpw_nm1 = fpw; if(fpw[0]*fpw[1] < num_fragments) fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); if(fpw[0]*fpw[1] < num_fragments) fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); }while(fpw_nm1 != fpw); // store parameters for(unsigned d = 0; d < shapes.size(); d++) fpw_[axes_->get(i, d)] = fpw[d]; /* warps per tile */ // try to make things as square as possible to maximize data re-use std::vector wpt = {1, 1, 1}; std::vector wpt_nm1; do{ wpt_nm1 = wpt; if(wpt[0] * wpt[1] * wpt[2] < num_warps_) wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); if(wpt[0] * wpt[1] * wpt[2] < num_warps_) wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); }while(wpt_nm1 != wpt); // store parameters for(unsigned d = 0; d < shapes.size(); d++) wpt_[axes_->get(i, d)] = wpt[d]; /* sanity check */ unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_warps *= wpt_[axes_->get(i, d)]; if(num_warps_ != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); } void tiles::init_scanline_tile(ir::value *i) { auto ord = order(i); auto shapes = i->get_type()->get_tile_shapes(); unsigned size = i->get_type()->get_tile_num_elements(); unsigned ld = ord[0]; unsigned num_threads = num_warps_*32; unsigned current = num_threads; nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4); mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]); current = current / mts_[axes_->get(i, ld)]; for(size_t d = 1; d < shapes.size(); d++){ ld = ord[d]; nts_[axes_->get(i, ld)] = 1; mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]); current = current / mts_[axes_->get(i, ld)]; } /* sanity check */ unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_threads *= mts_[axes_->get(i, d)]; if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } void extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); if(i && i->get_pointer_operand() == v) result.insert(i); } } void tiles::run(ir::module &) { hmma_.clear(); largest_.clear(); size_t num_groups = layout_->get_num_groups(); // helpers auto rank = [](ir::value* v) { ir::type *ty = v->get_type(); size_t ret = 0; if(ty->is_tile_ty()) for(int s: ty->get_tile_shapes()) ret += s > 1; return ret; }; // find out which groups require hmma layout for(size_t i = 0; i < num_groups; i++) { const auto& values = layout_->values(i); hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma); } // find out which value is the largest in each group for(size_t i = 0; i < num_groups; i++) { const auto& values = layout_->values(i); auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; largest_[i] = *std::max_element(values.begin(), values.end(), cmp); } // find out the order of a group for(size_t i = 0; i < num_groups; i++){ std::set io; for(ir::value* v: layout_->values(i)) extract_io_use(v, io); auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) { return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand()); }; auto it = std::max_element(io.begin(), io.end(), cmp); std::vector order(rank(largest_[i])); std::iota(order.begin(), order.end(), 0); if(it != io.end()) { auto max_contiguous = align_->contiguous((*it)->get_pointer_operand()); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); } order_[i] = order; std::cout << "order: " << order[0] << " " << order[1] << std::endl; } // tiling parameters for(auto x: largest_){ ir::value *i = x.second; if(!i->get_type()->is_tile_ty()) continue; /* HMMA parameters*/ if(hmma_[x.first]) init_hmma_tile(i); else init_scanline_tile(i); } } } } }