trying interior shift
This commit is contained in:
@@ -7,37 +7,6 @@
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/dnn/shift.h"
|
||||
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
|
||||
int32_t K,
|
||||
std::vector<OUT_DTYPE>& O,
|
||||
const std::vector<IN_DTYPE>& I,
|
||||
const std::vector<IN_DTYPE>& F,
|
||||
const std::vector<int32_t> shift_h,
|
||||
const std::vector<int32_t> shift_w)
|
||||
{
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < H; ++p)
|
||||
for(int32_t q = 0; q < W; ++q)
|
||||
for(int32_t bs = 0; bs < BS; ++bs)
|
||||
for(int32_t k = 0; k < K; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < C; ++c){
|
||||
int32_t h = p + shift_h[c];
|
||||
int32_t w = q + shift_w[c];
|
||||
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
|
||||
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
|
||||
IN_DTYPE b = F[k + c*K];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
@@ -104,7 +73,7 @@ int main() {
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
|
@@ -90,10 +90,13 @@ public:
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < NC_; ++c){
|
||||
int32_t h = p + shift_h_[c];
|
||||
int32_t w = q + shift_w_[c];
|
||||
bool in_bounds = (h >= 0 && w >= 0 && h < AH_ && w < AW_);
|
||||
IN_DTYPE a = in_bounds?I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]:0;
|
||||
int32_t h = p;
|
||||
int32_t w = q;
|
||||
if(h >= BH_/2 && h < AH_ - BH_/2)
|
||||
h += shift_h_[c];
|
||||
if(w > BW_/2 && w < AW_ - BW_/2)
|
||||
w += shift_w_[c];
|
||||
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_];
|
||||
IN_DTYPE b = F[k + c*NF_];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
|
@@ -47,7 +47,7 @@ shift::shift(int B, int NC,
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
h_deltas_.resize(NC_);
|
||||
h_deltas_ = std::vector<int32_t>(512, 0);
|
||||
for(unsigned c = 0; c < NC_; c++){
|
||||
h_deltas_[c] = c*ld_a_[0];
|
||||
h_deltas_[c] += shift_h_[c]*ld_a_[1];
|
||||
@@ -128,12 +128,12 @@ const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[256];
|
||||
__constant__ int32* delta = alloc_const int32[512];
|
||||
__constant__ int32* masks = alloc_const int32[8192];
|
||||
|
||||
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
@@ -141,27 +141,24 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
fp32 C[TM, TN] = 0;
|
||||
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
int32 pad_h = AR/2;
|
||||
int32 pad_w = AS/2;
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 raw[TM] = rawhc % AW - pad_w;
|
||||
int32 raw[TM] = rawhc % AW;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH - pad_h;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
|
||||
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
|
||||
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
|
||||
int32 rah[TM] = rahc % AH;
|
||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int32 offd[TM] = (maskh && maskw) ? 0 : 256;
|
||||
__constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis];
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 delta[TK] = *pd;
|
||||
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
|
||||
int1 m[TM, TK] = *pm > 0;
|
||||
fp32 a[TM, TK] = m ? *pa : 0;
|
||||
int32 delta[TM, TK] = *pd;
|
||||
fp32 *pa[TM, TK] = pxa + delta;
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
pm = pm + TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
|
Reference in New Issue
Block a user