diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index c0b0ae52d..c19d43e2e 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -9,7 +9,7 @@ int main() { bool AT = false; - bool BT = false; + bool BT = true; // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // matrix multiplication parameters diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 9824bcea4..88fe7ef3d 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -25,6 +25,7 @@ def run_dot(): # Test hresult = np.dot(ha.T, hb.T).T dif = np.abs(result - hresult) + np.savetxt('dif.dat', dif, '%2.4f') print(hresult) print(result) print("dif: %f" % np.max(dif)) diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index f83eea5c7..72ca66ad1 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -538,8 +538,11 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id /* intra warp offset */ // offset of quad in pair - Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_)); - Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_)); + Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), + builder.getInt32(fpw_0 * pack_size_0_)); + Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), + builder.getInt32(fpw_1 * pack_size_1_)); + // Quad pair id Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); @@ -559,15 +562,17 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id // a offset offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a)); offset_a_k_ = builder.CreateAnd(u_thread_id, _3); -// // b offsets + // b offsets offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b)); offset_b_k_ = builder.CreateAnd(u_thread_id, _3); + // c offsets Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_); Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2), builder.CreateAdd(warp_offset_j, pair_b_off)); + /* indices */ // i indices std::vector idx_i; @@ -1026,7 +1031,25 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 " + Value *offset_a_i = offset_a_i_; + Value *offset_a_k = offset_a_k_; + Value *offset_b_j = offset_b_j_; + Value *offset_b_k = offset_b_k_; + + Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0); + if(dot->is_a_trans()){ + offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); + offset_a_k = builder.getInt32(0); + } + if(!dot->is_b_trans()){ + offset_b_j = builder.CreateAdd(offset_b_j, builder.CreateURem(u_thread_id, builder.getInt32(4))); + offset_b_k = builder.getInt32(0); + } + + std::string op_a = dot->is_a_trans() ? "row" : "col"; + std::string op_b = dot->is_b_trans() ? "row" : "col"; + + InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 " "{$0, $1, $2, $3, $4, $5, $6, $7}, " "{$8, $9}, " "{$10, $11}, " @@ -1046,10 +1069,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){ for(unsigned K = 0; K < NK; K += 4){ Value *_K = builder.getInt32(K); - Value *current_offset_a_i = builder.CreateAdd(offset_a_i_, builder.getInt32(pack_i*stride_rep_i*pack_size_0_)); - Value *current_offset_b_i = builder.CreateAdd(offset_b_j_, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); - Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k_, _K)}); - Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k_, _K)}); + Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_)); + Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); + indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)}; + indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)}; + if(dot->is_a_trans()) + std::swap(idx_a[0], idx_a[1]); + if(!dot->is_b_trans()) + std::swap(idx_b[0], idx_b[1]); + Value *ha = TA->get_value(idx_a); + Value *hb = TB->get_value(idx_b); for(unsigned ii = 0; ii < pack_size_0_; ii++) for(unsigned jj = 0; jj < pack_size_1_; jj++){ Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)); diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp index 469524b07..eb65b224f 100644 --- a/lib/codegen/shmem_allocation.cpp +++ b/lib/codegen/shmem_allocation.cpp @@ -15,9 +15,22 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) { if(dynamic_cast(x)) return 4; for(ir::user* user: x->get_users()) - if(dynamic_cast(user)) - if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){ - return 16; + if(auto dot = dynamic_cast(user)){ + bool is_hmma = params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C; + bool is_op_0 = x == dot->get_operand(0); + bool is_op_1 = x == dot->get_operand(1); + if(is_hmma && is_op_0){ + if(dot->is_a_trans()) + return 20; + else + return 16; + } + if(is_hmma && is_op_1){ + if(!dot->is_b_trans()) + return 20; + else + return 16; + } } if(auto* phi = dynamic_cast(x)) { unsigned result = 0; diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 7baf54fc8..293ebf053 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -221,7 +221,7 @@ void tune::run(ir::module &mod) { } else { ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); - ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4); + ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); } } @@ -235,7 +235,7 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 2)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 4)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){