more stuff
This commit is contained in:
@@ -8,42 +8,63 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[3] = 1;
|
||||
ld[2] = shapes[3]*ld[3];
|
||||
ld[1] = shapes[2]*ld[2];
|
||||
ld[0] = shapes[1]*ld[1];
|
||||
ld[size - 1] = 1;
|
||||
for(int i = size - 1; i >= 1; i--)
|
||||
ld[i - 1] = shapes[i] * ld[i];
|
||||
}
|
||||
|
||||
shift::shift(int B, int NC,
|
||||
shift::shift(int B, int C,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S,
|
||||
int NF,
|
||||
int F,
|
||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
: NB_(B), NC_(NC),
|
||||
: B_(B), C_(C),
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
NF_(NF),
|
||||
F_(F),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias) {
|
||||
// max number of channels
|
||||
TK_ = 16;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
// transpose
|
||||
AT_ = false;
|
||||
BT_ = true;
|
||||
// equivalent matmul
|
||||
M_ = NB_*AH_*AW_;
|
||||
N_ = NF_;
|
||||
K_ = NC_;
|
||||
M_ = B_*AH_*AW_;
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// shapes
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
shapes_a_ = {NC, H, W, B};
|
||||
shapes_b_ = {NC, NF};
|
||||
shapes_c_ = {NF, H, W, B};
|
||||
// input layout: C, H, W, B
|
||||
// filter layout: C, F
|
||||
// output layout: F, H, W, B
|
||||
shapes_a_ = {C, H, W, B};
|
||||
shapes_b_ = {C, F};
|
||||
shapes_c_ = {F, H, W, B};
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
shapes_a_.swap(shapes_b_);
|
||||
AT_ = true;
|
||||
BT_ = false;
|
||||
M_ = K_;
|
||||
N_ = C_;
|
||||
K_ = B_*AH_*AW_;
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
shapes_a_.swap(shapes_c_);
|
||||
AT_ = false;
|
||||
BT_ = false;
|
||||
K_ = F_;
|
||||
M_ = B_*AH_*AW_;
|
||||
N_ = C_;
|
||||
}
|
||||
// memory strides
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
set_ld(shapes_b_, ld_b_);
|
||||
set_ld(shapes_c_, ld_c_);
|
||||
// build LUTs
|
||||
build_deltas();
|
||||
}
|
||||
@@ -57,7 +78,7 @@ void shift::build_deltas() {
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < NC_; c++)
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
}
|
||||
|
||||
@@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, NB_*AH_*AW_);
|
||||
kernel->setArg(7, NB_);
|
||||
kernel->setArg(6, B_*AH_*AW_);
|
||||
kernel->setArg(7, B_);
|
||||
kernel->setArg(8, AH_);
|
||||
kernel->setArg(9, AW_);
|
||||
kernel->setArg(10, BH_);
|
||||
kernel->setArg(11, BW_);
|
||||
// dry run
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::src(std::ostream &os) {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
@@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
int32 raw[TM] = rawhc % AW;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH;
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK];
|
||||
d = *pd;
|
||||
int32 offa1[TK] = rka*lda;
|
||||
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
|
||||
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
|
||||
)" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
)" << a_ty_ << R"( a[TM, TK] = *pa;
|
||||
)" << b_ty_ << R"( b[TN, TK] = *pb;
|
||||
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
|
||||
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
|
||||
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
|
||||
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
|
||||
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
|
||||
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
C = dot()" << usea << "," << useb << R"(, C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
pa = pa + (mask ? d[newaxis, :] : TK*lda);
|
||||
int1 checka[TM, TK] = k > TK;
|
||||
int1 checkb[TN, TK] = k > TK;
|
||||
inc_true = d)" << bca0 << R"(;
|
||||
inc_false = TK * lda;
|
||||
pa = pa + (mask ? inc_true : inc_false);
|
||||
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
|
||||
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
|
||||
@checka a = *pa;
|
||||
@checkb b = *pb;
|
||||
}
|
||||
|
Reference in New Issue
Block a user