trying interior shift

This commit is contained in:
Philippe Tillet
2019-06-27 14:13:48 -07:00
parent d8526669f5
commit 12e6036e5f
3 changed files with 20 additions and 51 deletions

View File

@@ -7,37 +7,6 @@
#include "triton/tools/bench.hpp"
#include "triton/dnn/shift.h"
// Reference (host-side) shift-convolution.
//
// For every output position (p, q), batch element bs and output feature k,
// accumulates over the input channels c:
//     O[k, p, q, bs] += I[c, p + shift_h[c], q + shift_w[c], bs] * F[c, k]
// Input reads that fall outside [0, H) x [0, W) contribute zero.
//
// input layout:  C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
//
// O must already hold at least K*H*W*BS elements; shift_h and shift_w must
// hold at least C elements each. No bounds are checked on these containers.
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
                int32_t K,
                std::vector<OUT_DTYPE>& O,
                const std::vector<IN_DTYPE>& I,
                const std::vector<IN_DTYPE>& F,
                const std::vector<int32_t>& shift_h,
                const std::vector<int32_t>& shift_w)
{
  // Strides of the flattened (channel, h, w, bs) layout. They are kept in
  // 64 bits so that offsets such as c*BS*H*W cannot overflow int32_t when
  // the tensors are large.
  const int64_t stride_w = BS;
  const int64_t stride_h = stride_w * W;
  const int64_t stride_c = stride_h * H;

  for(int32_t p = 0; p < H; ++p)
  for(int32_t q = 0; q < W; ++q)
  for(int32_t bs = 0; bs < BS; ++bs)
  for(int32_t k = 0; k < K; ++k)
  {
    OUT_DTYPE acc = 0;
    for(int32_t c = 0; c < C; ++c){
      // Each input channel is read at its own spatial offset.
      const int32_t h = p + shift_h[c];
      const int32_t w = q + shift_w[c];
      const bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
      // The offset is only evaluated when (h, w) is valid, so it is never negative.
      const IN_DTYPE a = in_bounds
                       ? I[bs + w*stride_w + h*stride_h + c*stride_c]
                       : static_cast<IN_DTYPE>(0);
      const IN_DTYPE b = F[k + static_cast<int64_t>(c)*K];
      acc = std::fma(a, b, acc);
    }
    O[bs + q*stride_w + p*stride_h + k*stride_c] = acc;
  }
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
@@ -104,7 +73,7 @@ int main() {
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark);
jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");

View File

@@ -90,10 +90,13 @@ public:
{
acc = 0;
for(int32_t c = 0; c < NC_; ++c){
int32_t h = p + shift_h_[c];
int32_t w = q + shift_w_[c];
bool in_bounds = (h >= 0 && w >= 0 && h < AH_ && w < AW_);
IN_DTYPE a = in_bounds?I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]:0;
int32_t h = p;
int32_t w = q;
if(h >= BH_/2 && h < AH_ - BH_/2)
h += shift_h_[c];
if(w > BW_/2 && w < AW_ - BW_/2)
w += shift_w_[c];
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_];
IN_DTYPE b = F[k + c*NF_];
acc = std::fma(a, b, acc);
}

View File

@@ -47,7 +47,7 @@ shift::shift(int B, int NC,
}
void shift::build_deltas() {
h_deltas_.resize(NC_);
h_deltas_ = std::vector<int32_t>(512, 0);
for(unsigned c = 0; c < NC_; c++){
h_deltas_[c] = c*ld_a_[0];
h_deltas_[c] += shift_h_[c]*ld_a_[1];
@@ -128,12 +128,12 @@ const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
__constant__ int32* delta = alloc_const int32[256];
__constant__ int32* delta = alloc_const int32[512];
__constant__ int32* masks = alloc_const int32[8192];
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
int32 M, int32 N, int32 K,
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
@@ -141,27 +141,24 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
fp32 C[TM, TN] = 0;
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
__constant__ int32* pd[TK] = delta + rka;
int32 pad_h = AR/2;
int32 pad_w = AS/2;
int32 rawhc[TM] = rxa / ABS;
int32 raw[TM] = rawhc % AW - pad_w;
int32 raw[TM] = rawhc % AW;
int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH - pad_h;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
int32 rah[TM] = rahc % AH;
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int32 offd[TM] = (maskh && maskw) ? 0 : 256;
__constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis];
for(int32 k = K; k > 0; k = k - TK){
int32 delta[TK] = *pd;
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
int1 m[TM, TK] = *pm > 0;
fp32 a[TM, TK] = m ? *pa : 0;
int32 delta[TM, TK] = *pd;
fp32 *pa[TM, TK] = pxa + delta;
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
C = dot(a, trans(b), C);
pb = pb + TK*N;
pd = pd + TK;
pm = pm + TK;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);