stuff
This commit is contained in:
@@ -70,16 +70,26 @@ shift::shift(int B, int C,
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
|
||||
};
|
||||
h_deltas_.resize(MAX_C_);
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
if(ty_ == FPROP){
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
|
||||
};
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[c] = shift_h_[c]*ld_c_[1] + shift_w_[c]*ld_c_[2];
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[c] = shift_h_[c]*ld_b_[1] + shift_w_[c]*ld_b_[2];
|
||||
}
|
||||
}
|
||||
|
||||
size_t shift::a_size(){
|
||||
@@ -102,7 +112,7 @@ std::vector<int32_t> shift::c_shapes(){
|
||||
}
|
||||
|
||||
size_t shift::get_nflops() {
|
||||
return 2. * M_ * N_ * K_;
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
|
||||
@@ -114,15 +124,13 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
if(ty_ == WGRAD)
|
||||
std::swap(a, b);
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, B_*AH_*AW_);
|
||||
kernel->setArg(6, M_);
|
||||
kernel->setArg(7, N_);
|
||||
kernel->setArg(8, B_);
|
||||
kernel->setArg(9, AH_);
|
||||
@@ -177,7 +185,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
|
||||
int32 lda, int32 ldb,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
@@ -203,11 +211,13 @@ if(ty_ == FPROP){
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 shift[TK, TN] = 0;)";
|
||||
__constant__ int32* pd[TN] = delta + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];)";
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS << R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
@@ -239,7 +249,7 @@ if(ty_ == WGRAD){
|
||||
int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
int32 inc[TK, TN] = mask ? 0 : shift;
|
||||
pb = pb + TK;
|
||||
pb = pb + TK)" << ldb0 << R"(;
|
||||
)" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
|
||||
@checkb b = *pbb;)";
|
||||
}
|
||||
@@ -259,14 +269,15 @@ else{
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
int32 rcwhc[TM] = rxc / ABS;
|
||||
int32 rcw[TM] = rcwhc % AW;
|
||||
int32 rcw[TM] = (rcwhc % AW);
|
||||
int32 rchc[TM] = rcwhc / AW;
|
||||
int32 rch[TM] = rchc % AH;
|
||||
int32 rch[TM] = (rchc % AH);
|
||||
int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
||||
int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
||||
int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
fp32* shiftpc[TM, TN] = pc + 0;
|
||||
pc = interior ? shiftpc : pc;
|
||||
__constant__ int32* pd[TN] = delta + ryc;
|
||||
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
pc = interior ? shift_pc : pc;
|
||||
@checkc __atomic_add(pc, C);
|
||||
)";
|
||||
}
|
||||
|
Reference in New Issue
Block a user