From ab1afbf0825c94640a870f458513aeaf218c732e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 28 Jun 2019 17:04:07 -0700 Subject: [PATCH] more performance optimizations --- examples/cpp/shift.cpp | 8 +++--- include/triton/dnn/shift.h | 1 + lib/dnn/shift.cpp | 58 +++++++++++++++++++++----------------- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 90aeaa595..f8d0b3ed9 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -14,9 +14,9 @@ int main() { triton::jit jit(context); // initialization int32_t R = 3, S = 3; - int32_t BS = 4, F = 128; + int32_t BS = 4, F = 512; int32_t H = 32, W = 32; - int32_t C = 128; + int32_t C = 512; // random shifts std::vector shift_h(C); std::vector shift_w(C); @@ -68,12 +68,12 @@ int main() { // shift std::vector params = { - 4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8 + 32, 2, 128, 16, 2, 128, 16, 8, 2, 2, 4, 2, 8, 8 }; std::ostringstream oss; shift.src(oss); std::string src = oss.str(); -// jit.autotune("shift", src.c_str(), benchmark); + jit.autotune("shift", src.c_str(), benchmark); jit.add_module("shift", src.c_str(), params); triton::driver::kernel* kernel = jit.get_function("shift"); triton::jit::launch_information info = jit.get_launch_info("shift"); diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index cec282d34..99a173112 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -106,6 +106,7 @@ public: } private: + int32_t MAX_C_; // image size int32_t NB_; int32_t NC_; diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 0eae63ddc..330fd9ec8 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -28,6 +28,8 @@ shift::shift(int B, int NC, shift_h_(shift_h), shift_w_(shift_w), a_ty_(a_ty), b_ty_(b_ty), ty_(ty), bias_(bias) { + // max number of channels + MAX_C_ = 1024; // equivalent matmul M_ = NB_*AH_*AW_; N_ = NF_; @@ -52,16 +54,12 @@ void shift::build_deltas() { }; // allocate look-up table size_t TK = 8; - h_deltas_ = std::vector(512, 0); + h_deltas_.resize(MAX_C_); // populate look-up table - for(unsigned c = 0; c < TK; c++){ - h_deltas_[c] = offset(c); // init (shift) - h_deltas_[c + 256] = c*ld_a_[0]; // init (no shift) - } - for(unsigned c = 0; c < NC_; c++){ - h_deltas_[TK + c] = offset(c + TK) - offset(c); // deltas (shift) - h_deltas_[TK + c + 256] = TK*ld_a_[0]; // deltas (shift) - } + for(unsigned c = 0; c < TK; c++) + h_deltas_[c] = offset(c); + for(unsigned c = 0; c < NC_; c++) + h_deltas_[TK + c] = offset(c + TK) - offset(c); } size_t shift::a_size(){ @@ -102,11 +100,12 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel, kernel->setArg(3, M_); kernel->setArg(4, N_); kernel->setArg(5, K_); - kernel->setArg(6, NB_); - kernel->setArg(7, AH_); - kernel->setArg(8, AW_); - kernel->setArg(9, BH_); - kernel->setArg(10, BW_); + kernel->setArg(6, NB_*AH_*AW_); + kernel->setArg(7, NB_); + kernel->setArg(8, AH_); + kernel->setArg(9, AW_); + kernel->setArg(10, BH_); + kernel->setArg(11, BW_); // dry run std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; stream->enqueue(kernel, grid, {nthreads, 1, 1}); @@ -119,19 +118,19 @@ const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {8}; -__constant__ int32* delta = alloc_const int32[512]; +__constant__ int32* delta = alloc_const int32[)" << MAX_C_ << R"(]; -void shift(restrict read_only align(16) fp32 *a, - restrict read_only align(16) fp32 *b, +void shift(restrict read_only align(16) )" << a_ty_ << R"( *a, + restrict read_only align(16) )" << b_ty_ << R"( *b, fp32 *c, int32 M, int32 N, int32 K, + int32 lda, int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) { int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 C[TM, TN] = 0; - fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; int32 pad_h = AR/2; int32 pad_w = AS/2; int32 rawhc[TM] = rxa / ABS; @@ -140,17 +139,24 @@ void shift(restrict read_only align(16) fp32 *a, int32 rah[TM] = rahc % AH; int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); - int1 mask[TM] = maskh && maskw; - int32 offd[TM] = mask ? 0 : 256; - __constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis]; - fp32* pa[TM, TK] = a + rxa[:, newaxis] + (*pd); - for(int32 k = K; k > 0; k = k - TK){ - fp32 a[TM, TK] = *pa; - fp32 b[TN, TK] = *pb; + int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis]; + __constant__ int32* pd[TK] = delta + rka; + int32 d[TK] = *pd; + int32 offa1[TK] = rka*lda; + int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :]; + )" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc; + )" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; + )" << a_ty_ << R"( a[TM, TK] = *pa; + )" << b_ty_ << R"( b[TN, TK] = *pb; + for(int32 k = K; k > TK; k = k - TK){ C = dot(a, trans(b), C); pb = pb + TK*N; pd = pd + TK; - pa = pa + (*pd); + d = *pd; + inc = mask ? d[newaxis, :] : TK*lda; + pa = pa + inc; + a = *pa; + b = *pb; } int32 rxc[TM] = get_global_range[TM](0); int32 ryc[TN] = get_global_range[TN](1);