diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h index 0f8aea7b4..3aef03a8d 100644 --- a/include/triton/codegen/analysis/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -49,7 +49,7 @@ struct buffer_t { class liveness { private: typedef std::map indices_map_t; - typedef std::map intervals_map_t; + typedef std::map intervals_map_t; typedef std::map has_storage_map_t; typedef ir::value* node_t; typedef std::map > graph_t; @@ -63,24 +63,26 @@ public: private: - void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id); + void connected_components(node_t x, std::set &nodes, graph_t &graph, buffer_t *buffer); void extract_double_bufferable(ir::instruction *i); void extract_buffers(ir::instruction *i); void get_parents(ir::instruction *i, std::vector& res); void make_graph(ir::instruction *i); + bool do_pad(ir::value *x); public: liveness(tiles *t): tiles_(t){ } + // padding + unsigned get_pad(ir::value *v) const { return pad_.at(v); } // buffer size - unsigned is_ld_padded(ir::value *x); unsigned num_bytes(ir::value *x); // accessors const intervals_map_t& intervals() const { return intervals_; } - segment get_interval(buffer_t v) const { return intervals_.at(v); } + segment get_interval(buffer_t* v) const { return intervals_.at(v); } // buffers - buffer_t get_buffer(ir::value *v) const { return groups_.at(v); } - std::vector get_values(buffer_t x) const { return values_.at(x); } + buffer_t* get_buffer(ir::value *v) const { return groups_.at(v); } + std::vector get_values(buffer_t* x) const { return values_.at(x); } // double-buffering bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); } double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); } @@ -95,12 +97,14 @@ private: indices_map_t indices; intervals_map_t intervals_; std::map double_; + std::map pad_; std::map> parents_; // graph std::set nodes_; graph_t graph_; - std::map 
groups_; - std::map> values_; + std::vector buffers_; + std::map groups_; + std::map> values_; }; } diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index bc236ff22..29241f1c3 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -89,7 +89,7 @@ private: public: - shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr); + shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector& perm = {}); void set_vector_size(unsigned vector_size); void set_return_mode(bool return_vector); void set_value(indices_t, Value *); @@ -97,8 +97,9 @@ public: Value* get_value(indices_t idx); Value* get_pointer() { return ptr_; } Value* get_offset() { return offset_; } + const std::vector& get_perm() { return perm_; } const std::vector& get_order() { return order_; } - static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& order, indices_t idx); + static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx); private: Value *ptr_; @@ -108,6 +109,7 @@ private: std::map ptr_cache_; unsigned vector_size_; std::vector order_; + std::vector perm_; }; // Distribtued tile diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 5cf107be3..e254f6d38 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -135,7 +135,7 @@ public: value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); value *create_atomic_add(value *ptr, value *val, const std::string &name = ""); value *create_dot(value *A, value *B, value *C, const std::string &name = ""); - value *create_trans(value *A, const std::vector &perm = {}, const std::string &name = ""); + value *create_trans(value *A, const std::vector &perm = {}, const 
std::string &name = ""); value *create_sqrt(value *A, const std::string &name = ""); value *create_reduce(value *A, unsigned axis, const std::string &name = ""); value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index bafc1c2c3..9298ccbe0 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -591,7 +591,7 @@ public: private: dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); - std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); } + std::string repr_impl() const { return "dot"; } public: static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); @@ -599,13 +599,7 @@ public: static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - bool is_a_trans() { return AT_ == Trans; } - bool is_b_trans() { return BT_ == Trans; } _TRITON_DEFINE_CLONE(dot_inst) - -private: - TransT AT_; - TransT BT_; }; //class outer_inst: public builtin_inst { @@ -617,20 +611,20 @@ private: class trans_inst: public builtin_inst { public: - ir::type* get_res_ty(ir::type* in, std::vector perm); - std::vector init_perm(ir::type* ty, const std::vector& perm); + ir::type* get_res_ty(ir::type* in, std::vector perm); + std::vector init_perm(ir::type* ty, const std::vector& perm); private: - trans_inst(value *arg, const std::vector& perm, const std::string& name, instruction* next); + trans_inst(value *arg, const std::vector& perm, const std::string& 
name, instruction* next); std::string repr_impl() const { return "trans"; } public: - static instruction* create(value *arg, const std::vector& perm = {}, const std::string &name = "", instruction *next = nullptr); - const std::vector get_perm() const; + static instruction* create(value *arg, const std::vector &perm = {}, const std::string &name = "", instruction *next = nullptr); + const std::vector get_perm() const; _TRITON_DEFINE_CLONE(trans_inst) private: - std::vector perm_; + std::vector perm_; }; class sqrt_inst: public builtin_inst { diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index e31799b59..ef57e7a4f 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -487,9 +487,6 @@ void align::populate(ir::value *v) { populate_is_constant(v); populate_starting_multiple(v); populate_max_contiguous(v); -// std::cout << v->get_name() << std::endl; -// if(max_contiguous_[v].size() == 2) -// std::cout << max_contiguous_[v][0] << " " << max_contiguous_[v][1] << std::endl; } void align::run(ir::module &mod) { diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 21087e680..91ca0868f 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -21,22 +21,22 @@ void allocation::run(ir::module &mod) { using std::min; typedef std::multimap triples_map_type; - std::vector I; + std::vector I; for(auto x: liveness_->intervals()) I.push_back(x.first); - std::vector J = I; + std::vector J = I; triples_map_type H; H.insert({0, segment{0, INT_MAX}}); - std::vector V; - std::map starts; + std::vector V; + std::map starts; while(!J.empty()){ auto h_it = H.begin(); unsigned w = h_it->first; segment xh = h_it->second; H.erase(h_it); - auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){ + auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t* JJ){ segment xj = liveness_->get_interval(JJ); bool res = xj.intersect(xh); for(auto val: H) @@ -44,7 +44,7 
@@ void allocation::run(ir::module &mod) { return res; }); if(j_it != J.end()){ - unsigned size = j_it->size; + unsigned size = (*j_it)->size; segment xj = liveness_->get_interval(*j_it); starts[*j_it] = w; H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); @@ -58,14 +58,14 @@ void allocation::run(ir::module &mod) { } // Build interference graph - std::map> interferences; - for(buffer_t x: V) - for(buffer_t y: V){ - if(x.id == y.id) + std::map> interferences; + for(buffer_t* x: V) + for(buffer_t* y: V){ + if(x->id == y->id) continue; unsigned X0 = starts[x], Y0 = starts[y]; - unsigned NX = x.size; - unsigned NY = y.size; + unsigned NX = x->size; + unsigned NY = y->size; segment XS = {X0, X0 + NX}; segment YS = {Y0, Y0 + NY}; if(liveness_->get_interval(x).intersect(liveness_->get_interval(y)) @@ -74,17 +74,17 @@ void allocation::run(ir::module &mod) { } // Initialize colors - std::map colors; - for(buffer_t X: V) - colors[X] = (X.id==V[0].id)?0:-1; + std::map colors; + for(buffer_t* X: V) + colors[X] = (X->id==V[0]->id)?0:-1; // First-fit graph coloring std::vector available(V.size()); - for(buffer_t x: V){ + for(buffer_t* x: V){ // Non-neighboring colors are available std::fill(available.begin(), available.end(), true); - for(buffer_t Y: interferences[x]){ + for(buffer_t* Y: interferences[x]){ int color = colors[Y]; if(color >= 0) available[color] = false; @@ -95,25 +95,24 @@ void allocation::run(ir::module &mod) { } // Finalize allocation - for(buffer_t x: V){ + for(buffer_t* x: V){ unsigned Adj = 0; - for(buffer_t y: interferences[x]) - Adj = std::max(Adj, starts[y] + y.size); + for(buffer_t* y: interferences[x]) + Adj = std::max(Adj, starts[y] + y->size); // create offsets for(ir::value *v: liveness_->get_values(x)){ offsets_[v] = starts[x] + colors[x] * Adj; if(liveness_->has_double(v)){ auto info = liveness_->get_double(v); - offsets_[info.latch] = offsets_[v] + x.size / 2; + offsets_[info.latch] = offsets_[v] + x->size / 2; } } } // 
Save maximum size of induced memory space allocated_size_ = 0; - for(auto &x: offsets_){ - allocated_size_ = std::max(allocated_size_, x.second + liveness_->get_buffer(x.first).size); - } + for(buffer_t* x: V) + allocated_size_ = std::max(allocated_size_, starts[x] + x->size); } } diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 2c152f439..16614b8a7 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -74,7 +74,7 @@ void axes::update_graph_trans(ir::instruction *i) { auto perm = trans->get_perm(); // add edge between axis perm[d] and axis d for(unsigned d = 0; d < perm.size(); d++) - add_constraint({i, perm[d]->get_value()}, {op, d}); + add_constraint({i, perm[d]}, {op, d}); } void axes::update_graph_broadcast(ir::instruction *i) { diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 13b456cae..ace03a07a 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -58,6 +58,18 @@ void liveness::make_graph(ir::instruction *i) { graph_[i].insert(latch); graph_[latch].insert(i); } + if(i->get_id() == ir::INST_PHI){ + ir::phi_node* phi = (ir::phi_node*)i; + for(ir::value* op: phi->ops()){ + auto* iop = dynamic_cast(op); + if(!iop || storage_info.at(iop->get_id()).first != SHARED) + continue; + nodes_.insert(phi); + nodes_.insert(op); + graph_[phi].insert(op); + graph_[op].insert(phi); + } + } if(i->get_id() == ir::INST_TRANS){ nodes_.insert(i); nodes_.insert(i->get_operand(0)); @@ -67,39 +79,63 @@ void liveness::make_graph(ir::instruction *i) { } // connected components -void liveness::connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned group_id) { - buffer_t buffer{group_id, num_bytes(x)}; +void liveness::connected_components(node_t x, std::set &nodes, graph_t &graph, buffer_t* buffer) { groups_[x] = buffer; values_[buffer].push_back(x); if(nodes.find(x) != nodes.end()){ nodes.erase(x); for(const node_t &y: graph[x]) - 
connected_components(y, nodes, graph, buffer); } } -unsigned liveness::is_ld_padded(ir::value *x) { - if(auto *trans = dynamic_cast(x)){ - if(trans->get_perm()[0]->get_value() != 0) - return 4; +bool liveness::do_pad(ir::value *x) { + // alignment for matrix product + if(auto* dot = dynamic_cast(x)) { + auto order = tiles_->order(x); + // a + ir::value *a = dot->get_operand(0); + size_t previous_a = pad_[a]; + bool a_trans = dynamic_cast(a); + bool a_row = order[0] == 1; + if(tiles_->hmma(x) == HMMA_A_ROW) + pad_[a] = 16; + else if(tiles_->hmma(x) == HMMA_A_COL) + pad_[a] = 8; + else if(a_trans ^ a_row) + pad_[a] = 4; + else + pad_[a] = 0; + // b + ir::value *b = dot->get_operand(1); + size_t previous_b = pad_[b]; + bool b_trans = dynamic_cast(b); + bool b_col = order[0] == 0; + if(tiles_->hmma(x) == HMMA_B_COL) + pad_[b] = 16; + else if(tiles_->hmma(x) == HMMA_B_ROW) + pad_[b] = 8; + else if(b_trans ^ b_col) + pad_[b] = 4; + else + pad_[b] = 0; + return previous_a != pad_[a] || previous_b != pad_[b]; } - auto order = tiles_->order(x); - bool is_col_major = order[0] == 0; - if(tiles_->hmma(x) == HMMA_A_ROW) - return is_col_major ? 16 : 16; - if(tiles_->hmma(x) == HMMA_A_COL) - return is_col_major ? 8 : 8; - if(tiles_->hmma(x) == HMMA_B_COL) - return is_col_major ? 16 : 16; - if(tiles_->hmma(x) == HMMA_B_ROW) - return is_col_major ?
8 : 8; + // padding for phi-nodes if(auto* phi = dynamic_cast(x)) { - unsigned result = 0; - for(unsigned i = 0; i < phi->get_num_incoming(); i++) - result = std::max(result, is_ld_padded(phi->get_incoming_value(i))); - return result; + bool has_changed = false; + for(unsigned i = 0; i < phi->get_num_incoming(); i++){ + ir::value* op = phi->get_operand(i); + size_t previous = pad_[op]; + pad_[op] = std::max(pad_[op], pad_[phi]); + has_changed |= previous != pad_[op]; + } + return has_changed; } - return 0; + // default -- no padding + size_t previous = pad_[x]; + pad_[x] = std::max<size_t>(previous, 0); + return pad_[x] != previous; } unsigned liveness::num_bytes(ir::value *x) { @@ -120,7 +156,7 @@ unsigned liveness::num_bytes(ir::value *x) { return num_elements * num_bytes * depth; } unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; - unsigned pad = is_ld_padded(x); + unsigned pad = pad_.at(x); if(pad > 0){ unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; num_bytes += pad * num_bytes / ld; @@ -134,6 +170,7 @@ unsigned liveness::num_bytes(ir::value *x) { void liveness::run(ir::module &mod) { double_.clear(); indices.clear(); + pad_.clear(); intervals_.clear(); parents_.clear(); @@ -142,6 +179,15 @@ void liveness::run(ir::module &mod) { this->extract_double_bufferable(i); }); + // Padding information + bool has_changed; + do{ + has_changed = false; + ir::for_each_value(mod, [this, &has_changed](ir::value* v){ + has_changed |= this->do_pad(v); + }); + }while(has_changed); + // Create buffer dependency graph ir::for_each_instruction(mod, [this](ir::instruction* i) { this->make_graph(i); @@ -150,7 +196,10 @@ void liveness::run(ir::module &mod) { // connected components unsigned group_id = 0; while(!nodes_.empty()){ - connected_components(*nodes_.begin(), nodes_, graph_, group_id++); + buffer_t* buffer = new buffer_t{group_id++}; + connected_components(*nodes_.begin(), nodes_, 
graph_, buffer); + for(ir::value *v: values_.at(buffer)) + buffer->size = std::max(buffer->size, num_bytes(v)); } // Assigns index to each instruction diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index 7f19df276..0bd317f8f 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -40,7 +40,7 @@ bool is_hmma_a_col(ir::value* v) { for(ir::user *u: v->get_users()) if(is_hmma_c(u)){ ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(0)) && !dot->is_a_trans()) + if((v == dot->get_operand(0))) return true; } } @@ -49,7 +49,7 @@ bool is_hmma_a_row(ir::value* v) { for(ir::user *u: v->get_users()) if(is_hmma_c(u)){ ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(0)) && dot->is_a_trans()) + if((v == dot->get_operand(0))) return true; } } @@ -58,7 +58,7 @@ bool is_hmma_b_col(ir::value* v) { for(ir::user *u: v->get_users()) if(is_hmma_c(u)){ ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(1)) && !dot->is_b_trans()) + if((v == dot->get_operand(1))) return true; } } @@ -67,7 +67,7 @@ bool is_hmma_b_row(ir::value* v) { for(ir::user *u: v->get_users()) if(is_hmma_c(u)){ ir::dot_inst* dot = (ir::dot_inst*)u; - if((v == dot->get_operand(1)) && dot->is_b_trans()) + if((v == dot->get_operand(1))) return true; } } @@ -170,6 +170,7 @@ void tiles::init_scanline_tile(ir::value *i) { unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_threads *= mts_[axes_->get_id(i, d)]; +// std::cout << num_threads << " " << effective_num_threads << std::endl; if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } @@ -219,7 +220,7 @@ void tiles::run(ir::module &) { largest_[i] = *std::max_element(values.begin(), values.end(), cmp); } - // find out the order of a group + // find out the layout ordering of a group for(size_t i = 0; i < num_groups; i++){ std::set io; for(ir::value* v: 
layout_->values(i)) @@ -239,11 +240,6 @@ void tiles::run(ir::module &) { order_[i] = order; } for(size_t i = 0; i < num_groups; i++){ - bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW || - hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW; - if(!is_hmma_op) - continue; - // extract copies to shared memory std::vector cts; for(ir::value* v: layout_->values(i)) if(auto *x = dynamic_cast(v)) diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 60facdae4..c355f9d2f 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -146,26 +146,30 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_ } -Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& order, indices_t idx) { +Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx) { + // strides + std::vector strides(order.size()); + strides[order[0]] = builder.getInt32(1); + for(size_t i = 1; i < idx.size(); i++) + strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]])); + // result Value *result = builder.getInt32(0); - result = builder.CreateAdd(result, idx[order[0]]); - Value *ld = builder.getInt32(shapes[order[0]]); - for(size_t i = 1; i < idx.size(); i++) { - result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld)); - if(i < idx.size() - 1){ - ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]])); - } - } + for(size_t i = 0; i < strides.size(); i++) + result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i])); return result; } -shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): - tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){ +shared_tile::shared_tile(Type *ty, const 
shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector& perm): + tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){ return_vector_ = false; + if(perm_.empty()){ + perm_.resize(shapes.size()); + std::iota(perm_.begin(), perm_.end(), 0); + } } void shared_tile::set_value(indices_t idx, Value *value) { - Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx)); + Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx)); unsigned addr_space = ptr->getType()->getPointerAddressSpace(); ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); builder_.CreateStore(value, ptr); @@ -196,7 +200,7 @@ Value* shared_tile::get_value(indices_t idx) { // if(isa(non_cst_idx.front())){ // builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); // } - base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx)); + base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx)); if(vector_size_ > 1){ Type *vec_ty = VectorType::get(ty, vector_size); Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); @@ -204,7 +208,7 @@ Value* shared_tile::get_value(indices_t idx) { } // builder_.SetInsertPoint(store); } - Value *offset = shared_offset(builder_, shapes_, order_, cst_idx); + Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx); Value *div = offset; if(vector_size_ > 1) div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); @@ -725,7 +729,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh return; auto order = tiles_->order(v); auto shapes = v->get_type()->get_tile_shapes(); - unsigned pad = liveness_->is_ld_padded(v); + unsigned pad = liveness_->get_pad(v); if(pad > 0) shapes[order[0]] += pad; Type* ty = 
llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); @@ -923,7 +927,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, write_idx.insert(write_idx.begin() + axis, lane); // shared memory write pointer - Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx); + Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), write_idx); Value *write_ptr = builder.CreateGEP(base_ptr, write_offset); // initialize shared memory @@ -936,7 +940,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, indices_t current(write_idx.size(), builder.getInt32(0)); current[axis] = builder.getInt32(i); // shared memory offset - Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current); + Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), current); Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i)); read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0)); // shared memory read pointer @@ -952,7 +956,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, // result is on the first lane of shared memory indices_t final = write_idx; final[axis] = builder.getInt32(0); - Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final); + Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), {0, 1}, op_tile->get_order(), final); Value *read_ptr = builder.CreateGEP(base_ptr, read_offset); tgt_->add_barrier(module, builder); result = builder.CreateLoad(read_ptr); @@ -1041,11 +1045,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) 
{ shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0)); - auto in_order = in->get_order(); - std::vector order; - for(auto p: x->get_perm()) - order.push_back(in_order[p->get_value()]); - shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset()); + shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), builder, in->get_offset(), x->get_perm()); tmap_[x] = out; } @@ -1082,8 +1082,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn auto ord_a = tiles_->order(dot->get_operand(0)); auto ord_b = tiles_->order(dot->get_operand(1)); - bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1; - bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1; + bool is_a_row = ord_a[ord_a.size() - 2] == 1; + bool is_b_row = ord_b[ord_b.size() - 2] == 1; if(is_a_row){ offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); @@ -1125,10 +1125,6 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)}; indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)}; - if(dot->is_a_trans()) - std::swap(idx_a[0], idx_a[1]); - if(!dot->is_b_trans()) - std::swap(idx_b[0], idx_b[1]); idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); Value *ha = TA->get_value(idx_a); @@ -1188,10 +1184,6 @@ void selection::lower_scanline_dot(ir::dot_inst *dot, LLVMContext &ctx, Function // input indices indices_t a_idx = {idx[0], builder.getInt32(K)}; indices_t b_idx = {builder.getInt32(K), idx[1]}; - if(dot->is_a_trans()) - std::swap(a_idx[0], a_idx[1]); - if(dot->is_b_trans()) - std::swap(b_idx[0], b_idx[1]); // add 
batching dimension for(size_t i = 2; i < idx.size(); i++){ a_idx.insert(a_idx.end(), idx[i]); @@ -1217,10 +1209,8 @@ void selection::lower_outer_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *f Value *res = TD->get_value(idx); indices_t a_idx = {idx[0], builder.getInt32(0)}; indices_t b_idx = {builder.getInt32(0), idx[1]}; - if(dot->is_a_trans()) - std::swap(a_idx[0], a_idx[1]); - if(dot->is_b_trans()) - std::swap(b_idx[0], b_idx[1]); + std::swap(a_idx[0], a_idx[1]); + std::swap(b_idx[0], b_idx[1]); Value *a = TA->get_value(a_idx); Value *b = TB->get_value(b_idx); if(a->getType() != c_ty) @@ -1243,7 +1233,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), ctx); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); auto A_shapes = A->get_type()->get_tile_shapes(); - size_t red_axis = dot->is_a_trans() ? 0 : 1; + size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; if(NK != 1) { @@ -1552,8 +1542,8 @@ void selection::run(ir::module &src, Module &dst) { offset->addIncoming(next_offset, llvm_inc_block); } else { - unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block); + unsigned num_bytes = inst->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + offset->addIncoming(dst_builder.getInt32(liveness_->get_buffer(inst)->size / (2*num_bytes)), llvm_inc_block); } ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index aee19110f..ee5821da4 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -38,8 +38,8 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){ return; if(alloc_->has_offset(v)){ unsigned offset = alloc_->offset(v); - unsigned num_bytes = liveness_->num_bytes(v); 
- res.push_back(interval_t(offset, offset + num_bytes)); + unsigned size = liveness_->get_buffer(v)->size; + res.push_back(interval_t(offset, offset + size)); } } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index ca67ecf5a..73b8ff27f 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -8,37 +8,8 @@ namespace codegen{ namespace transform{ -inline bool is_trans(ir::value *v){ - auto *x = dynamic_cast(v); - if(!x) - return false; - std::vector perm = x->get_perm(); - std::vector ref; - ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context()); - for(size_t i = 0; i < perm.size(); i++) - ref.push_back(ir::constant_int::get(int32_ty, i)); - std::swap(ref[0], ref[1]); - // true is perm == ref - return std::equal(perm.begin(), perm.end(), ref.begin()); -} - -inline bool is_hmma(ir::value *v){ - bool result = false; - if(auto *x = dynamic_cast(v)){ - ir::value *a = x->get_operand(0); - ir::type *a_ty = a->get_type(); - ir::value *b = x->get_operand(1); - ir::type *b_ty = b->get_type(); - // inputs have to be FP16 - result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty(); -// reduction has to be multiple of 4 -// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0); - } - return result; -} - ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder, - const std::vector& perm) { + const std::vector& perm) { if(auto phi = dynamic_cast(value)) { // transpose operands std::vector incs; @@ -106,9 +77,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ ir::value *a = dot->get_operand(0); ir::value *b = dot->get_operand(1); builder.set_insert_point(add); - ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, - dot->is_a_trans(), dot->is_b_trans(), - dot->get_name())); + ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name())); 
add->replace_all_uses_with(new_dot); return true; } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index f29c830f4..30881d087 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// std::cout << source << std::endl; +// std::cout << source << std::endl; cu_context::context_switcher ctx(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index db2080a4d..caf22348f 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -322,7 +322,7 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name return insert(dot_inst::create_nn(A, B, C, name)); } -value *builder::create_trans(value *A, const std::vector& perm, const std::string &name) { +value *builder::create_trans(value *A, const std::vector& perm, const std::string &name) { return insert(trans_inst::create(A, perm, name)); } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index e89367536..4fdfa797d 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -536,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next) - : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) { + : builtin_inst(C->get_type(), INST_DOT, 3, name, next) { set_operand(0, A); set_operand(1, B); set_operand(2, C); @@ -574,31 +574,30 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C, // trans instructions 
//===----------------------------------------------------------------------===// -ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector perm) { +ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector perm) { // get argument shapes ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes(); // permutate argument shapes perm = init_perm(ty, perm); ir::tile_type::tile_shapes_t res_shapes = arg_shapes; for(size_t i = 0; i < perm.size(); i++) - res_shapes[i] = arg_shapes[perm[i]->get_value()]; + res_shapes[i] = arg_shapes[perm[i]]; // construct type return tile_type::get(ty->get_scalar_ty(), res_shapes); } -std::vector trans_inst::init_perm(ir::type* ty, const std::vector& perm) { +std::vector trans_inst::init_perm(ir::type* ty, const std::vector& perm) { if(!perm.empty()) return perm; auto size = ty->get_tile_shapes().size(); - ir::type* int32_ty = type::get_int32_ty(ty->get_context()); - std::vector result; - result.push_back(ir::constant_int::get(int32_ty, size - 1)); + std::vector result; + result.push_back(size - 1); for(size_t i = 0; i < size - 1; i++) - result.push_back(ir::constant_int::get(int32_ty, i)); + result.push_back(i); return result; } -trans_inst::trans_inst(value *arg, const std::vector& perm, const std::string &name, instruction *next) +trans_inst::trans_inst(value *arg, const std::vector &perm, const std::string &name, instruction *next) : builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) { // sanity check perm_ = init_perm(arg->get_type(), perm); @@ -607,11 +606,11 @@ trans_inst::trans_inst(value *arg, const std::vector& perm, const set_operand(0, arg); } -instruction* trans_inst::create(value *arg, const std::vector &perm, const std::string &name, instruction *next) { +instruction* trans_inst::create(value *arg, const std::vector &perm, const std::string &name, instruction *next) { return new trans_inst(arg, perm, name, next); } -const std::vector trans_inst::get_perm() const { +const std::vector 
trans_inst::get_perm() const { return perm_; } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index f628f9171..e9f5f8921 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -229,6 +229,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c reassociate.run(module); dce.run(module); cts.run(module); +// ir::print(module, std::cout); liveness.run(module); allocation.run(module); if(allocation.allocated_size() > context->device()->max_shared_memory()) diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 14384bbe8..45541e247 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ - typedef half_float::half NumericT; - std::string ty = "half"; - cublasDataType_t cuty = CUDA_R_16F; + typedef float NumericT; + std::string ty = "float"; + cublasDataType_t cuty = CUDA_R_32F; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); // leading dimensions @@ -45,9 +45,9 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i opt.defines.push_back({"TYPE", {ty}}); opt.defines.push_back({"AT", {AT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}}); - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); - opt.defines.push_back({"TK", {"16"}}); + opt.defines.push_back({"TM", {"64"}}); + opt.defines.push_back({"TN", {"64"}}); + opt.defines.push_back({"TK", {"8"}}); opt.num_warps = {4}; // create function rt::function function(src::dot, opt); @@ -79,10 +79,9 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(auto x: std::vector>{{false, true}, - {true, false}, {true, true}}){ + for(auto x: std::vector>{{false, false}}){ std::vector tmp = { - config_t{x[0], x[1], 4096, 4096, 4096} + config_t{x[0], x[1], 2048, 2048, 2048} // config_t{x[0], 
x[1], 16, 2048, 2048}, // config_t{x[0], x[1], 32, 2048, 2048}, // config_t{x[0], x[1], 64, 2048, 2048}, diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index ff80ad6ae..2cc3fa290 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -54,12 +54,12 @@ void dot(TYPE * A, TYPE * B, TYPE * C, TYPE a[SHAPE_A] = *pa; TYPE b[SHAPE_B] = *pb; // reduction loop - for(int k = K; k > 0; k-= TK){ + for(int k = K; k > TK; k-= TK){ c += USEA @ USEB; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; - a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0; - b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0; + a = *pa; + b = *pb; } // epilogue int rxc[TM] = ridx * TM + 0 ... TM; diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index b440a1c07..9b3ee06f0 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -139,13 +139,13 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(bool AT: std::array{false}) - for(bool BT: std::array{false}) for(int TM: std::vector{32, 64}) for(int TN: std::vector{32, 64}) - for(int TK: std::vector{16, 32}) - for(int nwarps: std::vector{1, 2, 4, 8}){ - configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); + for(int TK: std::vector{8}) + for(int nwarps: std::vector{1, 2, 4, 8}) + for(bool AT: std::array{false, true}) + for(bool BT: std::array{false, true}){ + configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); } // does the work dtype_t dtype;