diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index fa7714782..90aeaa595 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -68,7 +68,7 @@ int main() { // shift std::vector params = { - 4, 2, 32, 8, 2, 32, 8, 4, 2, 2, 8, 8, 4 + 4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8 }; std::ostringstream oss; shift.src(oss); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 87b158648..0eae63ddc 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -43,34 +43,24 @@ shift::shift(int B, int NC, set_ld(shapes_a_, ld_a_); // build LUTs build_deltas(); - build_masks(); } void shift::build_deltas() { + // compute offset + auto offset = [&](unsigned c) { + return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2]; + }; + // allocate look-up table + size_t TK = 8; h_deltas_ = std::vector(512, 0); - for(unsigned c = 0; c < NC_; c++){ - h_deltas_[c] = c*ld_a_[0]; - h_deltas_[c] += shift_h_[c]*ld_a_[1]; - h_deltas_[c] += shift_w_[c]*ld_a_[2]; + // populate look-up table + for(unsigned c = 0; c < TK; c++){ + h_deltas_[c] = offset(c); // init (shift) + h_deltas_[c + 256] = c*ld_a_[0]; // init (no shift) } for(unsigned c = 0; c < NC_; c++){ - h_deltas_[c + 256] = c*ld_a_[0]; - } -} - -void shift::build_masks() { - size_t S0 = NC_; - size_t S1 = BH_; - size_t S2 = BW_; - h_masks_.resize(S0*S1*S2); - for(size_t ph = 0; ph < S1; ++ph) - for(size_t pw = 0; pw < S2; ++pw){ - int32_t* ptr = &h_masks_[ph*S0 + pw*S0*S1]; - for(size_t i = 0; i < S0; ++i){ - bool in_bounds_h = shift_h_[i] + ph >= 0 && shift_h_[i] + ph < BH_; - bool in_bounds_w = shift_w_[i] + pw >= 0 && shift_w_[i] + pw < BW_; - ptr[i] = in_bounds_h && in_bounds_w; - } + h_deltas_[TK + c] = offset(c + TK) - offset(c); // deltas (shift) + h_deltas_[TK + c + 256] = TK*ld_a_[0]; // deltas (shift) } } @@ -100,9 +90,7 @@ size_t shift::get_nflops() { void shift::init(driver::stream *stream, driver::cu_module *module) { triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta"); - triton::driver::buffer* masks = ((triton::driver::cu_module*)module)->symbol("masks"); stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data()); - stream->write(masks, false, 0, h_masks_.size()*4, h_masks_.data()); } void shift::enqueue(driver::stream *stream, driver::kernel *kernel, @@ -132,9 +120,10 @@ const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {8}; __constant__ int32* delta = alloc_const int32[512]; -__constant__ int32* masks = alloc_const int32[8192]; -void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, +void shift(restrict read_only align(16) fp32 *a, + restrict read_only align(16) fp32 *b, + fp32 *c, int32 M, int32 N, int32 K, int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) { int32 rxa[TM] = get_global_range[TM](0); @@ -142,7 +131,6 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 C[TM, TN] = 0; - fp32* pxa[TM, TK] = a + rxa[:, newaxis]; fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; int32 pad_h = AR/2; int32 pad_w = AS/2; @@ -152,16 +140,17 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, int32 rah[TM] = rahc % AH; int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); - int32 offd[TM] = (maskh && maskw) ? 0 : 256; + int1 mask[TM] = maskh && maskw; + int32 offd[TM] = mask ? 0 : 256; __constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis]; + fp32* pa[TM, TK] = a + rxa[:, newaxis] + (*pd); for(int32 k = K; k > 0; k = k - TK){ - int32 delta[TM, TK] = *pd; - fp32 *pa[TM, TK] = pxa + delta; fp32 a[TM, TK] = *pa; fp32 b[TN, TK] = *pb; C = dot(a, trans(b), C); pb = pb + TK*N; pd = pd + TK; + pa = pa + (*pd); } int32 rxc[TM] = get_global_range[TM](0); int32 ryc[TN] = get_global_range[TN](1);