diff --git a/include/triton/codegen/analysis/allocation.h b/include/triton/codegen/analysis/allocation.h index 3dfede223..b23f11964 100644 --- a/include/triton/codegen/analysis/allocation.h +++ b/include/triton/codegen/analysis/allocation.h @@ -4,6 +4,7 @@ #include #include #include +#include "triton/codegen/analysis/liveness.h" namespace triton{ @@ -25,10 +26,8 @@ class allocation { public: allocation(liveness *live, tiles *params) : liveness_(live), tiles_(params){ } - // utilities - unsigned num_bytes(ir::value *x); - unsigned is_ld_padded(ir::value* x); // accessors + bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); } unsigned offset(ir::value *x) const { return offsets_.at(x); } unsigned allocated_size() const { return allocated_size_; } // run @@ -36,7 +35,6 @@ public: private: std::map offsets_; - std::map num_bytes_; size_t allocated_size_; // dependences liveness *liveness_; diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index 52ea33cca..0f8aea7b4 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -2,6 +2,8 @@ #define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H #include +#include +#include namespace triton{ @@ -10,6 +12,7 @@ namespace ir{ class phi_node; class function; class module; + class instruction; } namespace codegen{ @@ -17,7 +20,7 @@ namespace analysis{ typedef unsigned slot_index; -class cts; +class tiles; struct segment { slot_index start; @@ -37,21 +40,47 @@ struct double_buffer_info_t { ir::phi_node* phi; }; +struct buffer_t { + unsigned id; + size_t size; + bool operator<(buffer_t other) const { return id < other.id; } +}; + class liveness { private: typedef std::map indices_map_t; - typedef std::map intervals_map_t; + typedef std::map intervals_map_t; typedef std::map has_storage_map_t; + typedef ir::value* node_t; + typedef std::map > graph_t; public: // Intervals iterators using iterator = intervals_map_t::iterator; using const_iterator = intervals_map_t::const_iterator; + + + +private: + void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id); + void extract_double_bufferable(ir::instruction *i); + void extract_buffers(ir::instruction *i); + void get_parents(ir::instruction *i, std::vector& res); + void make_graph(ir::instruction *i); + + public: + liveness(tiles *t): tiles_(t){ } + // buffer size + unsigned is_ld_padded(ir::value *x); + unsigned num_bytes(ir::value *x); // accessors - const intervals_map_t& intervals() const { return intervals_; } - segment get_interval(ir::value* v) const { return intervals_.at(v); } + const intervals_map_t& intervals() const { return intervals_; } + segment get_interval(buffer_t v) const { return intervals_.at(v); } + // buffers + buffer_t get_buffer(ir::value *v) const { return groups_.at(v); } + std::vector get_values(buffer_t x) const { return values_.at(x); } // double-buffering bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); } double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); } @@ -59,10 +88,19 @@ public: void run(ir::module &mod); private: + // analysis + tiles *tiles_; + // stuff has_storage_map_t has_dedicated_storage_; - indices_map_t indices_; + indices_map_t indices; intervals_map_t intervals_; std::map double_; + std::map> parents_; + // graph + std::set nodes_; + graph_t graph_; + std::map groups_; + std::map> values_; }; } diff --git a/include/triton/codegen/instructions.h b/include/triton/codegen/instructions.h index cecd716e0..e3ad9344d 100644 --- a/include/triton/codegen/instructions.h +++ b/include/triton/codegen/instructions.h @@ -56,7 +56,7 @@ static const std::map storage_info = { { ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}}, { ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}}, // array arithmetic - { ir::INST_TRANS, {SHARED, {DISTRIBUTED}}}, // TODO: not necessarily + { ir::INST_TRANS, {SHARED, {SHARED}}}, { ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}}, { ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}}, // terminator diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index fc2a5ce22..21087e680 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -15,79 +15,28 @@ namespace triton{ namespace codegen{ namespace analysis{ -unsigned allocation::is_ld_padded(ir::value *x) { - if(auto *trans = dynamic_cast(x)){ - if(trans->get_perm()[0]->get_value() != 0) - return 4; - } - auto order = tiles_->order(x); - bool is_col_major = order[0] == 0; - if(tiles_->hmma(x) == HMMA_A_ROW) - return is_col_major ? 16 : 8; - if(tiles_->hmma(x) == HMMA_A_COL) - return is_col_major ? 8 : 16; - if(tiles_->hmma(x) == HMMA_B_COL) - return is_col_major ? 16 : 8; - if(tiles_->hmma(x) == HMMA_B_ROW) - return is_col_major ? 8 : 16; - if(auto* phi = dynamic_cast(x)) { - unsigned result = 0; - for(unsigned i = 0; i < phi->get_num_incoming(); i++) - result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); - return result; - } - return 0; -} - -unsigned allocation::num_bytes(ir::value *x) { - if(auto *red = dynamic_cast(x)){ - unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - size_t axis = red->get_axis(); - ir::value *op = red->get_operand(0); - auto shapes = op->get_type()->get_tile_shapes(); - shapes.erase(shapes.begin() + axis); - size_t num_elements = 1; - for(auto x: shapes) - num_elements *= x; - size_t depth; - if(tiles_->hmma(x)) - depth = tiles_->wpt(op, axis); - else - depth = tiles_->mts(op, axis); - return num_elements * num_bytes * depth; - } - unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; - unsigned pad = is_ld_padded(x); - if(pad > 0){ - unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; - num_bytes += pad * num_bytes / ld; - } - if(liveness_->has_double(x)) - num_bytes *= 2; - return num_bytes; -} void allocation::run(ir::module &mod) { using std::max; using std::min; typedef std::multimap triples_map_type; - std::vector I; + std::vector I; for(auto x: liveness_->intervals()) I.push_back(x.first); - std::vector J = I; + std::vector J = I; triples_map_type H; H.insert({0, segment{0, INT_MAX}}); - std::vector V; - std::map starts; + std::vector V; + std::map starts; while(!J.empty()){ auto h_it = H.begin(); unsigned w = h_it->first; segment xh = h_it->second; H.erase(h_it); - auto j_it = std::find_if(J.begin(), J.end(), [&](ir::value *JJ){ + auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){ segment xj = liveness_->get_interval(JJ); bool res = xj.intersect(xh); for(auto val: H) @@ -95,7 +44,7 @@ void allocation::run(ir::module &mod) { return res; }); if(j_it != J.end()){ - unsigned size = num_bytes(*j_it); + unsigned size = j_it->size; segment xj = liveness_->get_interval(*j_it); starts[*j_it] = w; H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); @@ -109,14 +58,14 @@ void allocation::run(ir::module &mod) { } // Build interference graph - std::map> interferences; - for(ir::value *x: V) - for(ir::value *y: V){ - if(x == y) + std::map> interferences; + for(buffer_t x: V) + for(buffer_t y: V){ + if(x.id == y.id) continue; unsigned X0 = starts[x], Y0 = starts[y]; - unsigned NX = num_bytes(x); - unsigned NY = num_bytes(y); + unsigned NX = x.size; + unsigned NY = y.size; segment XS = {X0, X0 + NX}; segment YS = {Y0, Y0 + NY}; if(liveness_->get_interval(x).intersect(liveness_->get_interval(y)) @@ -125,17 +74,17 @@ void allocation::run(ir::module &mod) { } // Initialize colors - std::map colors; - for(ir::value *X: V) - colors[X] = (X==V[0])?0:-1; + std::map colors; + for(buffer_t X: V) + colors[X] = (X.id==V[0].id)?0:-1; // First-fit graph coloring std::vector available(V.size()); - for(ir::value *x: V){ + for(buffer_t x: V){ // Non-neighboring colors are available std::fill(available.begin(), available.end(), true); - for(ir::value *Y: interferences[x]){ + for(buffer_t Y: interferences[x]){ int color = colors[Y]; if(color >= 0) available[color] = false; @@ -146,21 +95,24 @@ void allocation::run(ir::module &mod) { } // Finalize allocation - for(ir::value *x: V){ + for(buffer_t x: V){ unsigned Adj = 0; - for(ir::value *y: interferences[x]) - Adj = std::max(Adj, starts[y] + num_bytes(y)); - offsets_[x] = starts[x] + colors[x] * Adj; - if(liveness_->has_double(x)){ - auto info = liveness_->get_double(x); - offsets_[info.latch] = offsets_[x] + num_bytes(x) / 2; + for(buffer_t y: interferences[x]) + Adj = std::max(Adj, starts[y] + y.size); + // create offsets + for(ir::value *v: liveness_->get_values(x)){ + offsets_[v] = starts[x] + colors[x] * Adj; + if(liveness_->has_double(v)){ + auto info = liveness_->get_double(v); + offsets_[info.latch] = offsets_[v] + x.size / 2; + } } } // Save maximum size of induced memory space allocated_size_ = 0; for(auto &x: offsets_){ - allocated_size_ = std::max(allocated_size_, x.second + num_bytes(x.first)); + allocated_size_ = std::max(allocated_size_, x.second + liveness_->get_buffer(x.first).size); } } diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index f6df78b72..13b456cae 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -1,6 +1,9 @@ #include +#include +#include #include "triton/codegen/instructions.h" #include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/tiles.h" #include "triton/codegen/transform/cts.h" #include "triton/ir/basic_block.h" #include "triton/ir/function.h" @@ -25,7 +28,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ throw std::runtime_error("unreachable"); } -inline void extract_double_bufferable(ir::instruction *i, std::map& result) { +void liveness::extract_double_bufferable(ir::instruction *i) { auto* phi = dynamic_cast(i); if(!phi || phi->get_num_incoming() != 2) return; @@ -42,65 +45,142 @@ inline void extract_double_bufferable(ir::instruction *i, std::mapget_id()).first != SHARED || storage_info.at(i_1->get_id()).first != SHARED) return; if(is_latch_1) - result[value_0] = double_buffer_info_t{value_1, phi}; + double_[value_0] = double_buffer_info_t{value_1, phi}; if(is_latch_0) - result[value_1] = double_buffer_info_t{value_0, phi}; + double_[value_1] = double_buffer_info_t{value_0, phi}; } +void liveness::make_graph(ir::instruction *i) { + if(has_double(i)){ + ir::value *latch = double_[i].latch; + nodes_.insert(i); + nodes_.insert(latch); + graph_[i].insert(latch); + graph_[latch].insert(i); + } + if(i->get_id() == ir::INST_TRANS){ + nodes_.insert(i); + nodes_.insert(i->get_operand(0)); + graph_[i].insert(i->get_operand(0)); + graph_[i->get_operand(0)].insert(i); + } +} + +// connected components +void liveness::connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id) { + buffer_t buffer{group_id, num_bytes(x)}; + groups_[x] = buffer; + values_[buffer].push_back(x); + if(nodes.find(x) != nodes.end()){ + nodes.erase(x); + for(const node_t &y: graph[x]) + connected_components(y, nodes, graph, group_id); + } +} + +unsigned liveness::is_ld_padded(ir::value *x) { + if(auto *trans = dynamic_cast(x)){ + if(trans->get_perm()[0]->get_value() != 0) + return 4; + } + auto order = tiles_->order(x); + bool is_col_major = order[0] == 0; + if(tiles_->hmma(x) == HMMA_A_ROW) + return is_col_major ? 16 : 16; + if(tiles_->hmma(x) == HMMA_A_COL) + return is_col_major ? 8 : 8; + if(tiles_->hmma(x) == HMMA_B_COL) + return is_col_major ? 16 : 16; + if(tiles_->hmma(x) == HMMA_B_ROW) + return is_col_major ? 8 : 8; + if(auto* phi = dynamic_cast(x)) { + unsigned result = 0; + for(unsigned i = 0; i < phi->get_num_incoming(); i++) + result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); + return result; + } + return 0; +} + +unsigned liveness::num_bytes(ir::value *x) { + if(auto *red = dynamic_cast(x)){ + unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + size_t axis = red->get_axis(); + ir::value *op = red->get_operand(0); + auto shapes = op->get_type()->get_tile_shapes(); + shapes.erase(shapes.begin() + axis); + size_t num_elements = 1; + for(auto x: shapes) + num_elements *= x; + size_t depth; + if(tiles_->hmma(x)) + depth = tiles_->wpt(op, axis); + else + depth = tiles_->mts(op, axis); + return num_elements * num_bytes * depth; + } + unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; + unsigned pad = is_ld_padded(x); + if(pad > 0){ + unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; + num_bytes += pad * num_bytes / ld; + } + if(has_double(x)) + num_bytes *= 2; + return num_bytes; +} // Entry point void liveness::run(ir::module &mod) { double_.clear(); - indices_.clear(); + indices.clear(); intervals_.clear(); + parents_.clear(); - // set of pair of values that can be double-buffered + // Create set of pair of values that can be double-buffered ir::for_each_instruction(mod, [this](ir::instruction* i) { - extract_double_bufferable(i, this->double_); + this->extract_double_bufferable(i); }); + // Create buffer dependency graph + ir::for_each_instruction(mod, [this](ir::instruction* i) { + this->make_graph(i); + }); + // connected components + unsigned group_id = 0; + while(!nodes_.empty()){ + connected_components(*nodes_.begin(), nodes_, graph_, group_id++); + } + + // Assigns index to each instruction for(ir::function *fn: mod.get_function_list()){ - // Assigns index to each instruction slot_index index = 0; for(ir::basic_block *block: fn->blocks()) for(ir::instruction *instr: block->get_inst_list()){ index += 1; - indices_.insert({instr, index}); - } - // Liveness analysis - // Creates live intervals - for(auto i: indices_){ - ir::value *v = i.first; - ir::instruction* instr = dynamic_cast(v); - if(!instr) - continue; - if(storage_info.at(instr->get_id()).first != SHARED) - continue; - unsigned start = i.second; - unsigned end = start; - for(ir::value *u: v->get_users()){ - start = std::min(start, indices_.at(u)); - end = std::max(end, indices_.at(u)); - } - intervals_[v] = segment{start, end}; - } - // Double-Buffering - // Arrays are live throughout the end of the loop - auto it = intervals_.begin(); - while(it != intervals_.end()) { - ir::value *x = it->first; - auto dit = double_.find(x); - if(dit != double_.end()) { - ir::value *y = dit->second.latch; - unsigned start = intervals_[x].start; - unsigned end = intervals_[y].end; - intervals_[x] = segment{start, end}; - intervals_.erase(y); - } - it++; + indices.insert({instr, index}); } } + + for(auto x: values_) { + // users + std::set values; + for(ir::value *v: x.second){ + values.insert(v); + for(ir::user *u: v->get_users()) + values.insert(u); + } + // compute intervals + unsigned start = INT32_MAX; + unsigned end = 0; + for(ir::value *u: values){ + start = std::min(start, indices.at(u)); + end = std::max(end, indices.at(u)); + } + intervals_[x.first] = segment{start, end}; + } + } } diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index a20fbaa60..60facdae4 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -725,7 +725,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh return; auto order = tiles_->order(v); auto shapes = v->get_type()->get_tile_shapes(); - unsigned pad = alloc_->is_ld_padded(v); + unsigned pad = liveness_->is_ld_padded(v); if(pad > 0) shapes[order[0]] += pad; Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); @@ -1040,15 +1040,13 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct } void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { - shared_tile* result = (shared_tile*)tmap_.at(x); - distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0)); - auto perm = x->get_perm(); - in->for_each([&](indices_t idx){ - indices_t out_idx(idx.size()); - for(size_t i = 0; i < idx.size(); i++) - out_idx[i] = idx[perm[i]->get_value()]; - result->set_value(out_idx, in->get_value(idx)); - }); + shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0)); + auto in_order = in->get_order(); + std::vector order; + for(auto p: x->get_perm()) + order.push_back(in_order[p->get_value()]); + shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset()); + tmap_[x] = out; } void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder, @@ -1555,7 +1553,7 @@ void selection::run(ir::module &src, Module &dst) { } else { unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - offset->addIncoming(dst_builder.getInt32(alloc_->num_bytes(phi)/(num_bytes)), llvm_inc_block); + offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block); } ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 6ec14bc09..aee19110f 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -36,9 +36,9 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){ auto *i = dynamic_cast(v); if(!i) return; - if(storage_info.at(i->get_id()).first == SHARED){ + if(alloc_->has_offset(v)){ unsigned offset = alloc_->offset(v); - unsigned num_bytes = alloc_->num_bytes(v); + unsigned num_bytes = liveness_->num_bytes(v); res.push_back(interval_t(offset, offset + num_bytes)); } } @@ -97,8 +97,10 @@ std::pairget_double(i); safe_war.insert(i); safe_war.insert(info.latch); + auto *trans = dynamic_cast(info.latch); + if(trans) + safe_war.insert(trans->get_operand(0)); } + if(i->get_id() == ir::INST_TRANS) + safe_war.insert(i); }); for(ir::function *fn: mod.get_function_list()){ @@ -152,9 +159,8 @@ void membar::run(ir::module &mod) { done = (n_inserted_im1 == n_inserted_i); n_inserted_im1 = n_inserted_i; }while(!done); - for(ir::instruction* i: insert_locs){ + for(ir::instruction* i: insert_locs) insert_barrier(i, builder); - } } } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index d490e7bcc..ca67ecf5a 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -84,70 +84,6 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) { return true; } -bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, - ir::value *A, ir::value *B, ir::value *D){ - ir::value *AA = A; - ir::value *BB = B; - if(trans_a){ - AA = ((ir::trans_inst*)A)->get_operand(0); - } - else{ - if(auto *T = dynamic_cast(A)){ - std::vector perm(T->get_perm()); - std::swap(perm[0], perm[1]); - AA = builder.create_trans(T->get_operand(0), perm); - T->replace_all_uses_with(AA); - trans_a = true; - } - } - if(trans_b){ - BB = ((ir::trans_inst*)B)->get_operand(0); - } - else{ - if(auto *T = dynamic_cast(B)){ - std::vector perm(T->get_perm()); - std::swap(perm[0], perm[1]); - BB = builder.create_trans(T->get_operand(0), perm); - T->replace_all_uses_with(BB); - trans_b = true; - } - } - if(!trans_a && !trans_b) - return false; - - ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b)); - dot->replace_all_uses_with(dot_atbt); - - return true; -} - -bool peephole::rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, - ir::value *A, ir::value *B, ir::value *D){ - // dot(op(a), trans(b)) - if(trans_b){ - ir::value* BB = ((ir::trans_inst*)B)->get_operand(0); - ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D)); - dot->replace_all_uses_with(NT); - return true; - } - // dot(op(a), b) - if(!trans_b){ - // create permutations - size_t size = B->get_type()->get_tile_shapes().size(); - std::vector perm(size); - ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context()); - for(size_t i = 0; i < size; i++) - perm[i] = ir::constant_int::get(int32_ty, i); - std::swap(perm[0], perm[1]); - // replace NN -> NT (trans) - ir::value* BB = builder.create_trans(B, perm); - ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D)); - dot->replace_all_uses_with(NT); - return true; - } - return false; -} - bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ // dot(a, b, 0) + c -> dot(a, b, c) auto add = dynamic_cast(value); @@ -176,26 +112,6 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ add->replace_all_uses_with(new_dot); return true; } - - // dot(a, b, c) - auto dot = dynamic_cast(value); - if(!dot) - return false; - builder.set_insert_point(value); - ir::value *A = dot->get_operand(0); - ir::value *B = dot->get_operand(1); - ir::value *D = dot->get_operand(2); - bool trans_a = is_trans(A); - bool trans_b = is_trans(B); - // only consider dot-nn - if(dot->is_a_trans() || dot->is_b_trans()) - return false; - // hmma - if(is_hmma(dot)){ - return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D); - } - else - return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D); } bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index d1b342e45..f628f9171 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -202,11 +202,11 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // create passes codegen::transform::cts cts; codegen::analysis::align align; - codegen::analysis::liveness liveness; codegen::analysis::axes axes; codegen::analysis::layout layouts(&axes); codegen::transform::coalesce coalesce(&align, &layouts); codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); + codegen::analysis::liveness liveness(&tiles); codegen::analysis::allocation allocation(&liveness, &tiles); codegen::transform::membar barriers(&liveness, &allocation); codegen::transform::dce dce; @@ -235,12 +235,10 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c return std::unique_ptr(); barriers.run(module); dce.run(module); - dce.run(module); axes.run(module); layouts.run(module); align.run(module); tiles.run(module); -// ir::print(module, std::cout); selection.run(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 74043d8e5..14384bbe8 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -79,10 +79,8 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(auto x: std::vector>{{false, false}, - {false, true}, - {true, false}, - {true, true}}){ + for(auto x: std::vector>{{false, true}, + {true, false}, {true, true}}){ std::vector tmp = { config_t{x[0], x[1], 4096, 4096, 4096} // config_t{x[0], x[1], 16, 2048, 2048}, diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index 2168b23b6..ff80ad6ae 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -4,15 +4,15 @@ namespace src { R"( #if AT == 1 #define USEA ^a -#define STRIDE_AK 1 -#define STRIDE_AM lda +#define STRIDE_AK lda +#define STRIDE_AM 1 #define BROADCAST_AK :, newaxis #define BROADCAST_AM newaxis, : #define SHAPE_A TK, TM #else #define USEA a -#define STRIDE_AK lda -#define STRIDE_AM 1 +#define STRIDE_AK 1 +#define STRIDE_AM lda #define BROADCAST_AK newaxis, : #define BROADCAST_AM :, newaxis #define SHAPE_A TM, TK @@ -20,15 +20,15 @@ R"( #if BT == 1 #define USEB ^b -#define STRIDE_BK ldb -#define STRIDE_BN 1 +#define STRIDE_BK 1 +#define STRIDE_BN ldb #define BROADCAST_BK newaxis, : #define BROADCAST_BN :, newaxis #define SHAPE_B TN, TK #else #define USEB b -#define STRIDE_BK 1 -#define STRIDE_BN ldb +#define STRIDE_BK ldb +#define STRIDE_BN 1 #define BROADCAST_BK :, newaxis #define BROADCAST_BN newaxis, : #define SHAPE_B TK, TN diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index bb75df10e..b440a1c07 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -139,8 +139,8 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(bool AT: std::array{false, true}) - for(bool BT: std::array{false, true}) + for(bool AT: std::array{false}) + for(bool BT: std::array{false}) for(int TM: std::vector{32, 64}) for(int TN: std::vector{32, 64}) for(int TK: std::vector{16, 32})