From a3bf3a1804a5bc28b5d62e6df517ca7606b7571e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 24 Sep 2019 19:35:46 -0400 Subject: [PATCH] [codegen] more hmma row-major handling --- include/triton/codegen/selection.h | 7 ++-- lib/codegen/analysis/allocation.cc | 2 +- lib/codegen/analysis/tiles.cc | 15 +++++++++ lib/codegen/selection.cc | 54 ++++++++++++++++++------------ lib/driver/module.cc | 2 +- lib/runtime/function.cc | 3 +- tests/bench/dot.cc | 5 ++- tests/common/src/copy.h | 2 +- 8 files changed, 60 insertions(+), 30 deletions(-) diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index df34f2987..bc236ff22 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -89,7 +89,7 @@ private: public: - shared_tile(Type* ty, const shapes_t &shapes, Value* ptr, Builder &builder, Value* offset = nullptr); + shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr); void set_vector_size(unsigned vector_size); void set_return_mode(bool return_vector); void set_value(indices_t, Value *); @@ -97,7 +97,8 @@ public: Value* get_value(indices_t idx); Value* get_pointer() { return ptr_; } Value* get_offset() { return offset_; } - static Value* shared_offset(Builder& builder, const shapes_t& shapes, indices_t idx); + const std::vector& get_order() { return order_; } + static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& order, indices_t idx); private: Value *ptr_; @@ -106,6 +107,7 @@ private: Value *offset_; std::map ptr_cache_; unsigned vector_size_; + std::vector order_; }; // Distribtued tile @@ -123,6 +125,7 @@ public: distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder, bool vectorize); void set_value(indices_t idx, Value *v); Value* get_value(indices_t idx); + const std::vector& get_order() { return order_; } unsigned get_linear_index(indices_t idx); indices_t get_ordered_indices(unsigned id); void for_each(std::function fn); diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 98813b4c0..8ff77eb25 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -57,7 +57,7 @@ unsigned allocation::num_bytes(ir::value *x) { unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8; unsigned pad = is_ld_padded(x); if(pad > 0){ - unsigned ld = x->get_type()->get_tile_shapes()[0]; + unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]]; num_bytes += pad * num_bytes / ld; } if(liveness_->has_double(x)) diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index e48795cec..7f19df276 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -218,6 +218,7 @@ void tiles::run(ir::module &) { auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); }; largest_[i] = *std::max_element(values.begin(), values.end(), cmp); } + // find out the order of a group for(size_t i = 0; i < num_groups; i++){ std::set io; @@ -237,6 +238,20 @@ void tiles::run(ir::module &) { } order_[i] = order; } + for(size_t i = 0; i < num_groups; i++){ + bool is_hmma_op = hmma_[i] == HMMA_A_COL || hmma_[i] == HMMA_A_ROW || + hmma_[i] == HMMA_B_COL || hmma_[i] == HMMA_B_ROW; + if(!is_hmma_op) + continue; + // extract copies to shared memory + std::vector cts; + for(ir::value* v: layout_->values(i)) + if(auto *x = dynamic_cast(v)) + cts.push_back(x); + if(cts.empty()) + continue; + order_[i] = order(cts[0]->get_operand(0)); + } // tiling parameters for(auto x: largest_){ ir::value *i = x.second; diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 1bc356723..a20fbaa60 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -146,26 +146,26 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_ } -Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) { +Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& order, indices_t idx) { Value *result = builder.getInt32(0); - result = builder.CreateAdd(result, idx[0]); - Value *ld = builder.getInt32(shapes[0]); + result = builder.CreateAdd(result, idx[order[0]]); + Value *ld = builder.getInt32(shapes[order[0]]); for(size_t i = 1; i < idx.size(); i++) { - result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld)); + result = builder.CreateAdd(result, builder.CreateMul(idx[order[i]], ld)); if(i < idx.size() - 1){ - ld = builder.CreateMul(ld, builder.getInt32(shapes[i])); + ld = builder.CreateMul(ld, builder.getInt32(shapes[order[i]])); } } return result; } -shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): - tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){ +shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset): + tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){ return_vector_ = false; } void shared_tile::set_value(indices_t idx, Value *value) { - Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx)); + Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, idx)); unsigned addr_space = ptr->getType()->getPointerAddressSpace(); ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); builder_.CreateStore(value, ptr); @@ -196,7 +196,7 @@ Value* shared_tile::get_value(indices_t idx) { // if(isa(non_cst_idx.front())){ // builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); // } - base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx)); + base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, order_, non_cst_idx)); if(vector_size_ > 1){ Type *vec_ty = VectorType::get(ty, vector_size); Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); @@ -204,7 +204,7 @@ Value* shared_tile::get_value(indices_t idx) { } // builder_.SetInsertPoint(store); } - Value *offset = shared_offset(builder_, shapes_, cst_idx); + Value *offset = shared_offset(builder_, shapes_, order_, cst_idx); Value *div = offset; if(vector_size_ > 1) div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); @@ -721,10 +721,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id * ------------------- */ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) { + if(tmap_.find(v) != tmap_.end()) + return; + auto order = tiles_->order(v); auto shapes = v->get_type()->get_tile_shapes(); unsigned pad = alloc_->is_ld_padded(v); if(pad > 0) - shapes[0] += pad; + shapes[order[0]] += pad; Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext()); // shared copy PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); @@ -744,15 +747,15 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->offset(v))); pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); - tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); - tmap_.insert({v, new shared_tile(ty, shapes, pre_ptr, builder)}); - tmap_.insert({info.latch, new shared_tile(ty, shapes, next_ptr, builder)}); + tmap_.insert({phi, new shared_tile(ty, shapes, order, ptr, builder, offset)}); + tmap_.insert({v, new shared_tile(ty, shapes, order, pre_ptr, builder)}); + tmap_.insert({info.latch, new shared_tile(ty, shapes, order, next_ptr, builder)}); } else { size_t offset = alloc_->offset(v); Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); ptr = builder.CreateBitCast(ptr, ptr_ty); - tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); + tmap_.insert({v, new shared_tile(ty, shapes, order, ptr, builder)}); } } @@ -920,7 +923,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, write_idx.insert(write_idx.begin() + axis, lane); // shared memory write pointer - Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx); + Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), write_idx); Value *write_ptr = builder.CreateGEP(base_ptr, write_offset); // initialize shared memory @@ -933,7 +936,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, indices_t current(write_idx.size(), builder.getInt32(0)); current[axis] = builder.getInt32(i); // shared memory offset - Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current); + Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), current); Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i)); read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0)); // shared memory read pointer @@ -949,7 +952,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, // result is on the first lane of shared memory indices_t final = write_idx; final[axis] = builder.getInt32(0); - Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final); + Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), op_tile->get_order(), final); Value *read_ptr = builder.CreateGEP(base_ptr, read_offset); tgt_->add_barrier(module, builder); result = builder.CreateLoad(read_ptr); @@ -1077,17 +1080,24 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn Value *offset_b_k = offset_b_k_; Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0); - if(dot->is_a_trans()){ + + auto ord_a = tiles_->order(dot->get_operand(0)); + auto ord_b = tiles_->order(dot->get_operand(1)); + + bool is_a_row = dot->is_a_trans() ^ ord_a[ord_a.size() - 2] == 1; + bool is_b_row = dot->is_b_trans() ^ ord_b[ord_b.size() - 2] == 1; + + if(is_a_row){ offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); offset_a_k = builder.getInt32(0); } - if(!dot->is_b_trans()){ + if(!is_b_row){ offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4))); offset_b_k = builder.getInt32(0); } - std::string op_a = dot->is_a_trans() ? "row" : "col"; - std::string op_b = dot->is_b_trans() ? "row" : "col"; + std::string op_a = is_a_row ? "row" : "col"; + std::string op_b = is_b_row ? "row" : "col"; InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 " "{$0, $1, $2, $3, $4, $5, $6, $7}, " diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 66c775ac6..f29c830f4 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -242,7 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr ll cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ // std::cout << source << std::endl; - cu_context::context_switcher ctx_switch(*context); + cu_context::context_switcher ctx(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; unsigned int errbufsize = 8096; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 86c0a3f8f..d1b342e45 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -222,13 +222,11 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c axes.run(module); layouts.run(module); coalesce.run(module); -// ir::print(module, std::cout); dce.run(module); align.run(module); dce.run(module); tiles.run(module); reassociate.run(module); - peephole.run(module); dce.run(module); cts.run(module); liveness.run(module); @@ -242,6 +240,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c layouts.run(module); align.run(module); tiles.run(module); +// ir::print(module, std::cout); selection.run(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 4f6c989e9..74043d8e5 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -79,7 +79,10 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(auto x: std::vector>{{false, false}}){ + for(auto x: std::vector>{{false, false}, + {false, true}, + {true, false}, + {true, true}}){ std::vector tmp = { config_t{x[0], x[1], 4096, 4096, 4096} // config_t{x[0], x[1], 16, 2048, 2048}, diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h index c6263d4bb..f45f7a5cd 100644 --- a/tests/common/src/copy.h +++ b/tests/common/src/copy.h @@ -59,7 +59,7 @@ void copy3d(TYPE * X __noalias __readonly __aligned(16), } )"; - const char* copy_nd[] = {copy1d, copy2d, copy3d}; + const char* copy_nd[] = {copy1d, copy2d, copy3d}; }