#include #include #include #include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/layout.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/utils.h" namespace triton{ namespace codegen{ namespace analysis{ // constructor layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps) : axes_(axes), align_(align), num_warps_(num_warps) { } // get group id unsigned layout::layout_of(ir::value *value) const { return groups_.at(value); } // get values const std::vector& layout::values_of(unsigned id) const { return values_.at(id); } // get number of groups size_t layout::num_layouts() const { return values_.size(); } // connect two values void layout::connect(ir::value *x, ir::value *y) { if(x == y) return; if(!x->get_type()->is_tile_ty()) return; if(!y->get_type()->is_tile_ty()) return; std::vector x_axes = axes_->get(x); std::vector y_axes = axes_->get(y); std::set sx_axes(x_axes.begin(), x_axes.end()); std::set sy_axes(y_axes.begin(), y_axes.end()); std::set common; std::set_intersection(sx_axes.begin(), sx_axes.end(), sy_axes.begin(), sy_axes.end(), std::inserter(common, common.begin())); if(!common.empty()) graph_.add_edge(x, y); } // make graph void layout::make_graph(ir::instruction *i) { for(ir::value* opx: i->ops()) for(ir::value* opy: i->ops()){ connect(i, opx); connect(opx, opy); } } // hmma bool is_hmma_c(ir::value *v){ bool result = false; if(auto *x = dynamic_cast(v)){ ir::value *a = x->get_operand(0); ir::type *a_ty = a->get_type(); ir::value *b = x->get_operand(1); ir::type *b_ty = b->get_type(); result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty(); } return result; } const layout_t &layout::get(ir::value *v) const { return layouts_.at(groups_.at(v)); } std::map& layout::get_all() { return layouts_; } void extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); if(i && i->get_pointer_operand() == v) result.insert(v); } } inline bool is_trans(ir::value *v) { if(dynamic_cast(v)) { return true; } if(auto *phi = dynamic_cast(v)) { bool result = true; for(ir::value *op: phi->ops()) result = result && is_trans(op); return result; } return false; } inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); } void layout::init_hmma_tile(layout_t& layout) { auto ord = layout.order; auto shapes = layout.shapes; unsigned shape_0 = shapes[ord[0]]; unsigned shape_1 = shapes[ord[1]]; /* fragments per warp */ // try to make things as square as possible to maximize data re-use std::vector fpw = {1, 1, 1}; std::vector fpw_nm1; unsigned num_fragments = std::min((shape_0/8)*(shape_1/8), 4); do { fpw_nm1 = fpw; if(fpw[0]*fpw[1] < num_fragments) fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8); if(fpw[0]*fpw[1] < num_fragments) fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8); }while(fpw_nm1 != fpw); // store parameters for(unsigned d = 0; d < shapes.size(); d++) layout.fpw[d] = fpw[d]; /* warps per tile */ // try to make things as square as possible to maximize data re-use std::vector wpt = {1, 1, 1}; std::vector wpt_nm1; do{ wpt_nm1 = wpt; if(wpt[0] * wpt[1] * wpt[2] < num_warps_) wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8)); if(wpt[0] * wpt[1] * wpt[2] < num_warps_) wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8)); }while(wpt_nm1 != wpt); // store parameters for(unsigned d = 0; d < shapes.size(); d++) layout.wpt[d] = wpt[d]; /* sanity check */ unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_warps *= layout.wpt[d]; if(num_warps_ != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); } void layout::init_scanline_tile(layout_t& layout) { auto ord = layout.order; auto shapes = layout.shapes; unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); unsigned ld = ord[0]; unsigned num_threads = num_warps_*32; unsigned current = num_threads; layout.nts[ld] = clamp(size / num_threads, 1, 4); layout.mts[ld] = clamp(current, 1, shapes[ld] / layout.nts[ld]); current = current / layout.mts[ld]; for(size_t d = 1; d < shapes.size(); d++){ ld = ord[d]; layout.nts[ld] = 1; layout.mts[ld] = clamp(current, 1, shapes[ld]); current = current / layout.mts[ld]; } /* sanity check */ unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_threads *= layout.mts[d]; if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } void layout::run(ir::module &mod) { // make graph graph_.clear(); ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); }); // connected components graph_.connected_components(&values_, &groups_); // create layouts for(const auto& x: values_) { bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c); // type layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE; } /* ---- TO CLEAN ---- */ size_t num_groups = num_layouts(); // helpers auto rank = [this](ir::value* v) { int ret = 0; for(int s: v->get_type()->get_tile_shapes()) ret += s > 1; return ret; }; // find out axes for each layout for(const auto& x: values_) { auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp); layouts_[x.first].axes = axes_->get(largest); layouts_[x.first].shapes = largest->get_type()->get_tile_shapes(); } // find out the layout ordering of a group for(const auto& x: values_) { std::set ptr; for(ir::value* v: x.second) extract_io_use(v, ptr); size_t rank = layouts_[x.first].axes.size(); std::vector order(rank); std::iota(order.begin(), order.end(), 0); for(ir::value *v: ptr){ auto max_contiguous = align_->contiguous(v); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); } layouts_[x.first].order = order; } // matrix multiplication optimizations for(size_t i = 0; i < num_groups; i++){ std::vector dots; for(ir::value* v: values_of(i)) if(auto *x = dynamic_cast(v)) dots.push_back(x); for(ir::dot_inst* dot: dots){ ir::value* a = dot->get_operand(0); ir::value* b = dot->get_operand(1); if(get(dot).type == HMMA_884){ auto a_val = values_of(layout_of(a)); auto b_val = values_of(layout_of(b)); for(ir::value *v: a_val) if(auto *cts = dynamic_cast(v)) layouts_[layout_of(a)].order = layouts_[layout_of(cts->get_operand(0))].order; for(ir::value *v: b_val) if(auto *cts = dynamic_cast(v)) layouts_[layout_of(b)].order = layouts_[layout_of(cts->get_operand(0))].order; } else{ std::vector col = {0, 1}; std::vector row = {1, 0}; layouts_[layout_of(a)].order = is_trans(a) ? row : col; layouts_[layout_of(b)].order = is_trans(b) ? col : row; } } } // tiling parameters for(auto& x: layouts_){ /* HMMA parameters*/ if(x.second.type == HMMA_884) init_hmma_tile(x.second); else init_scanline_tile(x.second); } } } } }