[codegen] shift: added sketch for shift-convolution backpropagation
This commit is contained in:
@@ -114,6 +114,8 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
if(ty_ == WGRAD)
|
||||
std::swap(a, b);
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
@@ -121,24 +123,35 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, B_*AH_*AW_);
|
||||
kernel->setArg(7, B_);
|
||||
kernel->setArg(8, AH_);
|
||||
kernel->setArg(9, AW_);
|
||||
kernel->setArg(10, BH_);
|
||||
kernel->setArg(11, BW_);
|
||||
kernel->setArg(7, N_);
|
||||
kernel->setArg(8, B_);
|
||||
kernel->setArg(9, AH_);
|
||||
kernel->setArg(10, AW_);
|
||||
kernel->setArg(11, BH_);
|
||||
kernel->setArg(12, BW_);
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
if(ty_ == BPROP)
|
||||
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::src(std::ostream &os) {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
std::string rkb = "rkb";
|
||||
std::string rka = "rka";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
if(ty_ == FPROP){
|
||||
rka = "inc";
|
||||
bca0 = "";
|
||||
lda0 = "";
|
||||
}
|
||||
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
@@ -149,6 +162,8 @@ void shift::src(std::ostream &os) {
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
|
||||
os <<
|
||||
R"(
|
||||
@@ -161,8 +176,8 @@ __constant__ int32* delta = alloc_const int32[)" << MAX_C_ << R"(];
|
||||
void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||
fp32 *c,
|
||||
multiple_of(4) int32 M, multiple_of(4) int32 N, multiple_of(4) int32 K,
|
||||
multiple_of(4) int32 lda,
|
||||
int32 M, int32 N, int32 K,
|
||||
multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
@@ -170,7 +185,9 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 pad_h = AR / 2;
|
||||
int32 pad_w = AS / 2;
|
||||
int32 pad_w = AS / 2;)";
|
||||
if(ty_ == FPROP){
|
||||
os << R"(
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 raw[TM] = rawhc % AW;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
@@ -179,35 +196,86 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
|
||||
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
|
||||
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
|
||||
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
|
||||
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
|
||||
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
|
||||
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
int32 inc_true[TM, TK] = d[newaxis, :];
|
||||
int32 inc_false[TM, TK] = rka[newaxis, :] * lda;
|
||||
int32 inc[TM, TK] = mask ? inc_true : inc_false;)";
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 shift[TK, TN] = 0;)";
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS << R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot()" << usea << "," << useb << R"(, C);
|
||||
pb = pb + TK*N;
|
||||
int1 checka[)" << AS << R"(] = k > TK;
|
||||
int1 checkb[)" << BS << R"(] = k > TK;)";
|
||||
if(ty_ == FPROP){
|
||||
os << R"(
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
inc_true = d)" << bca0 << R"(;
|
||||
inc_true = d[newaxis, :];
|
||||
inc_false = TK * lda;
|
||||
pa = pa + (mask ? inc_true : inc_false);
|
||||
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
|
||||
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
|
||||
@checka a = *pa;
|
||||
@checkb b = *pb;
|
||||
inc = mask ? inc_true : inc_false;
|
||||
pa = pa + inc;
|
||||
@checka a = *pa;)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
pa = pa + TK)" << lda0 << R"(;
|
||||
@checka a = *pa;)";
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 rbwhc[TK] = rkb / ABS;
|
||||
int32 rbw[TK] = rbwhc % AW;
|
||||
int32 rbhc[TK] = rbwhc / AW;
|
||||
int32 rbh[TK] = rbhc % AH;
|
||||
int1 maskh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||
int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
int32 inc[TK, TN] = mask ? 0 : shift;
|
||||
pb = pb + TK;
|
||||
)" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
|
||||
@checkb b = *pbb;)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
pb = pb + TK)" << ldb0 << R"(;
|
||||
@checkb b = *pb;)";
|
||||
}
|
||||
os << R"(
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
int32 rcwhc[TM] = rxc / ABS;
|
||||
int32 rcw[TM] = rcwhc % AW;
|
||||
int32 rchc[TM] = rcwhc / AW;
|
||||
int32 rch[TM] = rchc % AH;
|
||||
int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
||||
int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
||||
int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
fp32* shiftpc[TM, TN] = pc + 0;
|
||||
pc = interior ? shiftpc : pc;
|
||||
@checkc __atomic_add(pc, C);
|
||||
)";
|
||||
}
|
||||
)";
|
||||
else{
|
||||
os << R"(
|
||||
@checkc *pc = C;)";
|
||||
}
|
||||
os << R"(
|
||||
})";
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user