diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index e0eee3a38..074bfb27c 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -50,31 +50,31 @@ public: virtual void visit_layout_shared(layout_shared_t*) = 0; }; +class layout_hmma_884_t; +class layout_scanline_t; +class layout_shared_t; + struct layout_t { layout_t(layout_type_t _type, const std::vector& _axes, const std::vector &_shapes, const std::vector &_values, ir::type *_ty, - size_t _id, analysis::align* align); - + // visitor virtual void accept(layout_visitor* vst) = 0; + // downcast + layout_hmma_884_t* to_hmma884(); + layout_scanline_t* to_scanline(); + layout_shared_t* to_shared(); + layout_type_t type; std::vector axes; std::vector shapes; std::vector values; std::vector order; - size_t id; - size_t size; - std::shared_ptr double_buffer; ir::type *ty; - size_t pad; - std::vector mts; - std::vector nts; - std::vector fpw; - std::vector wpt; }; struct layout_hmma_884_t: public layout_t { @@ -83,9 +83,11 @@ struct layout_hmma_884_t: public layout_t { const std::vector& _shapes, const std::vector &_values, ir::type *_ty, - size_t _id, analysis::align* align); void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); } + + std::vector fpw; + std::vector wpt; }; struct layout_scanline_t: public layout_t { @@ -94,9 +96,11 @@ struct layout_scanline_t: public layout_t { const std::vector& _shapes, const std::vector &values, ir::type *_ty, - size_t _id, analysis::align* align); void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); } + + std::vector mts; + std::vector nts; }; struct layout_shared_t: public layout_t { @@ -105,9 +109,11 @@ struct layout_shared_t: public layout_t { const std::vector& _shapes, const std::vector &values, ir::type *ty, - size_t _id, analysis::align* align); void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } + + std::shared_ptr double_buffer; + size_t size; }; @@ -126,18 +132,6 @@ private: void create(size_t id, const std::vector& values); -// size_t shared_tmp_req(ir::instruction* i) { -// switch(i->get_id()) { -// case ir::INST_REDUCE: { -// ir::reduce_inst *red = (ir::reduce_inst*)i; -// ir::type *ty = red->get_type(); - - -// } -// default: return 0; -// } -// } - public: // constructor layout(analysis::axes *axes, analysis::align *align, size_t num_warps); @@ -146,8 +140,8 @@ public: unsigned layout_of(ir::value *value) const; const std::vector& values_of(unsigned id) const; size_t num_layouts() const; - const layout_t* get(size_t id) const; - const layout_t* get(ir::value *v) const; + layout_t* get(size_t id); + layout_t* get(ir::value *v); std::map &get_all(); size_t tmp(ir::instruction* i); diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index 4d5fa3e91..e0158dc8a 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -42,14 +42,14 @@ struct segment { class liveness { private: - typedef std::map intervals_map_t; + typedef std::map intervals_map_t; public: // constructor liveness(layout *l): layouts_(l){ } // accessors const intervals_map_t& get() const { return intervals_; } - segment get(layout_t* v) const { return intervals_.at(v); } + segment get(layout_shared_t* v) const { return intervals_.at(v); } // run void run(ir::module &mod); diff --git a/include/triton/codegen/selection/machine_layout.h b/include/triton/codegen/selection/machine_layout.h index a3b453995..5ea34f3f3 100644 --- a/include/triton/codegen/selection/machine_layout.h +++ b/include/triton/codegen/selection/machine_layout.h @@ -71,7 +71,8 @@ public: class machine_layout_shared_t: public machine_layout_t { public: - machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout, + machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, + analysis::layout_shared_t* layout, std::map& vmap, std::map& tmap); @@ -82,7 +83,7 @@ public: target *tgt_; analysis::allocation* alloc_; Value *&sh_mem_ptr_; - analysis::layout_t* layout_; + analysis::layout_shared_t* layout_; std::map& vmap_; std::map& tmap_; diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index b92b5bd44..0cff27640 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -15,22 +15,22 @@ void allocation::run(ir::module &mod) { using std::min; typedef std::multimap triples_map_type; - std::vector I; + std::vector I; for(auto x: liveness_->get()) I.push_back(x.first); - std::vector J = I; + std::vector J = I; triples_map_type H; H.insert({0, segment{0, INT_MAX}}); - std::vector V; - std::map starts; + std::vector V; + std::map starts; while(!J.empty()){ auto h_it = H.begin(); unsigned w = h_it->first; segment xh = h_it->second; H.erase(h_it); - auto j_it = std::find_if(J.begin(), J.end(), [&](layout_t* JJ){ + auto j_it = std::find_if(J.begin(), J.end(), [&](layout_shared_t* JJ){ segment xj = liveness_->get(JJ); bool res = xj.intersect(xh); for(auto val: H) @@ -52,10 +52,10 @@ void allocation::run(ir::module &mod) { } // Build interference graph - std::map> interferences; - for(layout_t* x: V) - for(layout_t* y: V){ - if(x->id == y->id) + std::map> interferences; + for(layout_shared_t* x: V) + for(layout_shared_t* y: V){ + if(x == y) continue; unsigned X0 = starts[x], Y0 = starts[y]; unsigned NX = x->size; @@ -68,17 +68,17 @@ void allocation::run(ir::module &mod) { } // Initialize colors - std::map colors; - for(layout_t* X: V) - colors[X] = (X->id==V[0]->id)?0:-1; + std::map colors; + for(layout_shared_t* X: V) + colors[X] = (X==V[0])?0:-1; // First-fit graph coloring std::vector available(V.size()); - for(layout_t* x: V){ + for(layout_shared_t* x: V){ // Non-neighboring colors are available std::fill(available.begin(), available.end(), true); - for(layout_t* Y: interferences[x]){ + for(layout_shared_t* Y: interferences[x]){ int color = colors[Y]; if(color >= 0) available[color] = false; @@ -89,16 +89,16 @@ void allocation::run(ir::module &mod) { } // Finalize allocation - for(layout_t* x: V){ + for(layout_shared_t* x: V){ unsigned Adj = 0; - for(layout_t* y: interferences[x]) + for(layout_shared_t* y: interferences[x]) Adj = std::max(Adj, starts[y] + y->size); offsets_[x] = starts[x] + colors[x] * Adj; } // Save maximum size of induced memory space allocated_size_ = 0; - for(layout_t* x: V) + for(layout_shared_t* x: V) allocated_size_ = std::max(allocated_size_, starts[x] + x->size); } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 6d7c2dc9c..2136d4162 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -75,11 +75,11 @@ bool is_hmma_c(ir::value *v){ return result; } -const layout_t* layout::get(size_t id) const { +layout_t* layout::get(size_t id) { return layouts_.at(id); } -const layout_t* layout::get(ir::value *v) const { +layout_t* layout::get(ir::value *v) { return layouts_.at(groups_.at(v)); } @@ -140,8 +140,7 @@ layout_t::layout_t(layout_type_t _type, const std::vector &_axes, const std::vector &_shapes, const std::vector &_values, ir::type *_ty, - size_t _id, - analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), id(_id), ty(_ty) { + analysis::align* align): type(_type), axes(_axes), shapes(_shapes), values(_values), ty(_ty) { // io pointer std::set ptr; for(ir::value* v: values) @@ -159,6 +158,21 @@ layout_t::layout_t(layout_type_t _type, } } +// downcast +layout_hmma_884_t* layout_t::to_hmma884() { + assert(type == HMMA_884); + return static_cast(this); +} + +layout_scanline_t* layout_t::to_scanline() { + assert(type == SCANLINE); + return static_cast(this); +} + +layout_shared_t* layout_t::to_shared() { + assert(type == SHARED); + return static_cast(this); +} inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); @@ -167,8 +181,8 @@ inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, const std::vector& _axes, const std::vector& _shapes, - const std::vector &values, ir::type *_ty, size_t _id, - analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) { + const std::vector &values, ir::type *_ty, + analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, align) { unsigned shape_0 = shapes[0]; unsigned shape_1 = shapes[1]; /* fragments per warp */ @@ -210,8 +224,7 @@ layout_scanline_t::layout_scanline_t(size_t num_warps, const std::vector& _axes, const std::vector& _shapes, const std::vector &values, ir::type *_ty, - size_t _id, - analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, _id, align){ + analysis::align* align): layout_t(SCANLINE, _axes, _shapes, values, _ty, align){ unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies()); unsigned num_threads = num_warps * 32; nts.resize(shapes.size()); @@ -295,8 +308,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg, const std::vector& _shapes, const std::vector &values, ir::type *ty, - size_t _id, - analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, _id, align) { + analysis::align* align): layout_t(SHARED, _axes, _shapes, values, ty, align) { size = 0; @@ -335,7 +347,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg, // else // order = row; // padding - pad = 0; + size_t pad = 0; if(hmma_dot_a){ bool row = is_trans(hmma_dot_a) ^ order[0] != 0; pad = 24 - shapes[row ? 0 : 1] % 32; @@ -375,15 +387,15 @@ void layout::create(size_t id, const std::vector& values) { }); // type if(it_hmma_c != values.end()) - layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_hmma_884_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); else if(it_cts != values.end()){ ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts; ir::value *arg = cts->get_operand(0); create(groups_.at(arg), values_.at(groups_.at(arg))); - layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_shared_t(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); } else - layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_scanline_t(num_warps_, axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); } void layout::run(ir::module &mod) { @@ -410,18 +422,18 @@ void layout::run(ir::module &mod) { // shape auto shapes = arg->get_type()->get_tile_shapes(); unsigned shape_ax = shapes[axis]; - const layout_t *layout = get(arg); + layout_scanline_t *layout = get(arg)->to_scanline(); unsigned per_thread = layout->nts[axis]; unsigned depth = shape_ax / per_thread; shapes[axis] = depth; // create layout - layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); tmp_[red] = id; } if(auto *recoalasce = dynamic_cast(i)){ ir::value *val = recoalasce->get_operand(0); - const layout_t* in_layout = get(val); - const layout_t* out_layout = get(i); + layout_t* in_layout = get(val); + layout_t* out_layout = get(i); if(in_layout->type != HMMA_884) return; id++; @@ -431,14 +443,14 @@ void layout::run(ir::module &mod) { shape[ld] = in_shape[ld]; for(size_t k = 0; k < in_shape.size(); k++) if(k != ld) - shape[k] = 4*in_layout->fpw[k]*in_layout->wpt[k]; + shape[k] = 4*in_layout->to_hmma884()->fpw[k]*in_layout->to_hmma884()->wpt[k]; // create layout - layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); tmp_[recoalasce] = id; } if(auto *atom = dynamic_cast(i)){ id++; - layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), id, align_); + layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); tmp_[atom] = id; } }); diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 382f8ef6c..a4bb41f5e 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -27,9 +27,9 @@ void liveness::run(ir::module &mod) { // create live intervals for(auto &x: layouts_->get_all()) { - layout_t* layout = x.second; - if(layout->type != SHARED) + if(x.second->type != SHARED) continue; + layout_shared_t* layout = x.second->to_shared(); // users std::set users; for(ir::value *v: layout->values){ diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 4d4fe0b11..5cf964915 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -655,13 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile * "{$8, $9}, " "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); + analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884(); - unsigned fpw_0 = layouts_->get(dot)->fpw.at(0); - unsigned fpw_1 = layouts_->get(dot)->fpw.at(1); + unsigned fpw_0 = layout->fpw.at(0); + unsigned fpw_1 = layout->fpw.at(1); unsigned wts_0 = fpw_0 * 8; unsigned wts_1 = fpw_1 * 8; - unsigned wpt_0 = layouts_->get(dot)->wpt.at(0); - unsigned wpt_1 = layouts_->get(dot)->wpt.at(1); + unsigned wpt_0 = layout->wpt.at(0); + unsigned wpt_1 = layout->wpt.at(1); unsigned stride_rep_i = wpt_0 * wts_0; unsigned stride_rep_j = wpt_1 * wts_1; unsigned num_rep_i = shapes[0] / stride_rep_i; @@ -925,8 +926,8 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { // pointer to temporary shared memory Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_); // layouts - const analysis::layout_t* in_layout = layouts_->get(op); - const analysis::layout_t* out_layout = layouts_->get(rc); + analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884(); + analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline(); // machine tiles distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op)); distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc)); @@ -1026,14 +1027,14 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned vector_size = 1; - auto x_order = layouts_->get(cts)->order; ir::value *arg = cts->get_operand(0); - auto arg_order = layouts_->get(arg)->order; + analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared(); + analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline(); + auto out_order = out_layout->order; + auto in_order = in_layout->order; // tiles - if(x_order == arg_order){ - size_t ld = arg_order[0]; - vector_size = layouts_->get(arg)->nts.at(ld); - } + if(out_order == in_order) + vector_size = in_layout->nts.at(in_order[0]); std::map packets; for_each(arg, [&](indices_t idx){ diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc index 2d02e7b1f..d1ea9fa0f 100644 --- a/lib/codegen/selection/machine_layout.cc +++ b/lib/codegen/selection/machine_layout.cc @@ -72,7 +72,7 @@ inline int32_t ceil(int32_t num, int32_t div){ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, - Value *&sh_mem_ptr, analysis::layout_t *layout, + Value *&sh_mem_ptr, analysis::layout_shared_t *layout, std::map& vmap, std::map& tmap) : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { @@ -132,7 +132,10 @@ machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder tile *machine_layout_distributed_t::create(ir::value *v) { Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext()); const auto &shapes = v->get_type()->get_tile_shapes(); - std::vector axes(shapes.size()); + size_t rank = shapes.size(); + std::vector axes(rank); + std::vector order(rank); + // compute axes for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d] > 1){ unsigned x = a_axes_->get(v, d); @@ -143,7 +146,22 @@ tile *machine_layout_distributed_t::create(ir::value *v) { axes[d].values = {builder_->getInt32(0)}; } } - return new distributed_tile(ty, shapes, layout_->order, axes, *builder_); + // compute order + std::iota(order.begin(), order.end(), 0); + auto cmp = [&](int x, int y) { + unsigned axx = a_axes_->get(v, x); + unsigned axy = a_axes_->get(v, y); + auto itx = std::find(layout_->axes.begin(), layout_->axes.end(), axx); + auto ity = std::find(layout_->axes.begin(), layout_->axes.end(), axy); + size_t posx = std::distance(layout_->axes.begin(), itx); + size_t posy = std::distance(layout_->axes.begin(), ity); + if(posx < rank && posy < rank) + return layout_->order[posx] < layout_->order[posy]; + return false; + }; + std::sort(order.begin(), order.end(), cmp); + + return new distributed_tile(ty, shapes, order, axes, *builder_); } machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc index a94661b90..c70ba85b0 100644 --- a/lib/codegen/selection/machine_value.cc +++ b/lib/codegen/selection/machine_value.cc @@ -11,13 +11,6 @@ using namespace llvm; /* Distributed Tile */ void distributed_tile::init_indices() { std::vector id(axes_.size(), 0); - // create iteration order - std::vector order(id.size()); - std::iota(order.begin(), order.end(), 0); - auto cmp = [&](int x, int y) { - return order_[x] < order_[y]; - }; - std::sort(order.begin(), order.end(), cmp); // build size_t k = 0; while(true) { @@ -28,12 +21,12 @@ void distributed_tile::init_indices() { indices_[current] = sz; values_[current] = nullptr; ordered_indices_.push_back(current); - id[order[0]]++; - while(id[order[k]] == axes_[order[k]].values.size()){ + id[order_[0]]++; + while(id[order_[k]] == axes_[order_[k]].values.size()){ if(k == id.size() - 1) return; - id[order[k++]] = 0; - id[order[k]]++; + id[order_[k++]] = 0; + id[order_[k]]++; } k = 0; } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 8cb48f7df..1d9aef055 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -37,7 +37,7 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){ return; if(alloc_->has_offset(layouts_->get(v))){ unsigned offset = alloc_->offset(layouts_->get(v)); - unsigned size = layouts_->get(v)->size; + unsigned size = layouts_->get(v)->to_shared()->size; res.push_back(interval_t(offset, offset + size)); } } @@ -119,12 +119,14 @@ void membar::run(ir::module &mod) { // without needing synchronization std::set safe_war; for(const auto& x: layouts_->get_all()){ - if(x.second->double_buffer){ - auto info = *x.second->double_buffer; - for(ir::value *v: x.second->values) - if(v != info.phi) - safe_war.insert(v); - } + if(x.second->type != analysis::SHARED) + continue; + analysis::layout_shared_t* layout = x.second->to_shared(); + if(!layout->double_buffer) + continue; + for(ir::value *v: layout->values) + if(v != layout->double_buffer->phi) + safe_war.insert(v); } diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 79718a232..d118b95be 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -10,11 +10,11 @@ int main() { typedef std::tuple, bool, bool, int, int, int> config_t; std::vector configs; for(auto ord: std::vector>{{1, 0}}) - for(auto x: std::vector>{{false, false}}){ + for(auto x: std::vector>{{false, false}, {true, false}}){ std::vector tmp = { -// config_t{ord, x[0], x[1], 512, 512, 512}, -// config_t{ord, x[0], x[1], 1024, 1024, 1024}, - config_t{ord, x[0], x[1], 127008, 768, 576}, + config_t{ord, x[0], x[1], 512, 512, 512}, + config_t{ord, x[0], x[1], 2048, 2048, 2048}, +// config_t{ord, x[0], x[1], 127008, 768, 576}, // config_t{ord, x[0], x[1], 8192, 8192, 8192} // config_t{ord, x[0], x[1], 16, 2048, 2048}, // config_t{ord, x[0], x[1], 32, 2048, 2048}, @@ -36,7 +36,7 @@ int main() { for(const auto& c: configs){ std::tie(ord, AT, BT, M, N, K) = c; std::cout << "// " << c ; - for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) + for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord)) std::cout << ", " << perf << std::flush; std::cout << std::endl; }