From 484e3871cfa96747eaa004f694dbcd3af92b95ce Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 20 Jul 2019 22:05:16 -0700 Subject: [PATCH] [dnn/shift] added base pointer for a, b --- examples/cpp/dot.cpp | 8 ++++---- examples/cpp/shift.cpp | 8 ++++---- lib/codegen/alignment_info.cpp | 6 ++++-- lib/codegen/selection.cpp | 2 +- lib/dnn/shift.cpp | 23 +++++++++++++---------- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 09483116e..0eac9c046 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -8,15 +8,15 @@ int main() { - bool AT = true; - bool BT = false; + bool AT = false; + bool BT = true; typedef float T; std::string ty = "fp16"; size_t dt_nbytes = sizeof(T); // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // matrix multiplication parameters - int32_t M = 4096, N = 4096, K = 4096; + int32_t M = 65536, N = 2048, K = 2048; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); @@ -37,7 +37,7 @@ int main() { stream->write(dc, true, 0, hc); stream->synchronize(); triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4); - gemm.enqueue(stream, {da, db, dc}, false); + gemm.enqueue(stream, {da, db, dc}, true); // stream->read(dc, true, 0, hc); // gemm.cpu_ref(rc, ha, hb); // for(size_t i = 0; i < M*N; i++) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 739b35117..3dabddfe2 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -14,13 +14,13 @@ int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); - auto op = triton::dnn::shift::WGRAD; + auto op = triton::dnn::shift::FPROP; // initialization int32_t R = 3, S = 3; - int32_t B = 16, F = 4096; - int32_t H = 16, W = 16; - int32_t C = 4096; + int32_t B = 64, F = 2048; + int32_t H = 32, W = 32; + int32_t C = 2048; // random shifts std::vector shift_h(C); diff --git a/lib/codegen/alignment_info.cpp b/lib/codegen/alignment_info.cpp index b7e0b3641..ccd9778d1 100644 --- a/lib/codegen/alignment_info.cpp +++ b/lib/codegen/alignment_info.cpp @@ -59,8 +59,9 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) { cst_info rhs = populate_is_constant(rhs_op); if(lhs.num_cst==0 && rhs.value && x->is_int_div()){ unsigned max_contiguous = populate_max_contiguous(lhs_op); - unsigned starting_multiple = populate_starting_multiple(lhs_op); - return cache({gcd(max_contiguous, rhs.value) - (starting_multiple % rhs.value), 0}); + // todo might not be entirely true + unsigned num_constants = gcd(max_contiguous, rhs.value); + return cache({num_constants, 0}); } return cache({std::min(lhs.num_cst, rhs.num_cst), 0}); } @@ -300,6 +301,7 @@ void alignment_info::run(ir::module &mod) { for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i: block->get_inst_list()){ populate_max_contiguous(i); +// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl; } } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index b5cd54a8b..4e4741658 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1148,7 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & unsigned max_contiguous = axis_info_->get_max_contiguous(ptr); unsigned alignment = std::min(starting_multiple, max_contiguous); unsigned vector_size = std::min(result->axis(0).contiguous, alignment); - vector_size = result->axis(0).contiguous; +// vector_size = result->axis(0).contiguous; // vector_size = 1; std::map packets; distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand()); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 844c982e7..adc36740c 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -354,11 +354,6 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, fp32 acc[TM, TN] = 0; int32 pad_h = BH / 2; int32 pad_w = BW / 2;)"; -if(op_ == WGRAD){ - result += R"( - - )"; -} /* A offsets */ if(op_ == FPROP){ @@ -408,13 +403,21 @@ if(op_ == WGRAD){ rbh = rbh * )" + stride_h + R"(; int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)"; + int32 offb1[TK, TN] = offkb[:, newaxis]; + )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0; + )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift; + )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1; + )" + b_ty_ + "* pb[" + BS + R"(] = pb_base + offb1;)"; +} +else{ + result += R"( + )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; + )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;)"; } /* Main loop */ +/* Increment A pointers */ result += R"( - )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; - )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1; int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; @@ -440,7 +443,7 @@ if(op_ == WGRAD){ rka = rka + TK;)" + compute_bhw("ra", "TK", "rka") + R"( offxa = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - pa = A + offa0 + offxa[:, newaxis];)"; + pa = pa_base + offxa[:, newaxis];)"; } result += R"( @checka a = *pa;)"; @@ -453,7 +456,7 @@ if(op_ == WGRAD){ rbw = rbw * )" + stride_w + R"(; rbh = rbh * )" + stride_h + R"(; offkb = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; - pb = B + offb0 + offkb[:, newaxis] + shift;)"; + pb = pb_base + offkb[:, newaxis];)"; } if(op_ == FPROP){ result += R"(