more stuff

This commit is contained in:
Philippe Tillet
2019-06-30 16:55:02 -07:00
parent 9a86bc51e1
commit c172bd518b
9 changed files with 124 additions and 90 deletions

View File

@@ -8,42 +8,63 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[3] = 1;
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}
shift::shift(int B, int NC,
shift::shift(int B, int C,
int D, int H, int W,
int T, int R, int S,
int NF,
int F,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: NB_(B), NC_(NC),
: B_(B), C_(C),
AD_(D), AH_(H), AW_(W),
BD_(T), BH_(R), BW_(S),
NF_(NF),
F_(F),
shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias) {
// max number of channels
TK_ = 16;
MAX_C_ = 8192 + TK_;
// transpose
AT_ = false;
BT_ = true;
// equivalent matmul
M_ = NB_*AH_*AW_;
N_ = NF_;
K_ = NC_;
M_ = B_*AH_*AW_;
N_ = F_;
K_ = C_;
// shapes
// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
shapes_a_ = {NC, H, W, B};
shapes_b_ = {NC, NF};
shapes_c_ = {NF, H, W, B};
// input layout: C, H, W, B
// filter layout: C, F
// output layout: F, H, W, B
shapes_a_ = {C, H, W, B};
shapes_b_ = {C, F};
shapes_c_ = {F, H, W, B};
if(ty_ == WGRAD){
shapes_b_.swap(shapes_c_);
shapes_a_.swap(shapes_b_);
AT_ = true;
BT_ = false;
M_ = K_;
N_ = C_;
K_ = B_*AH_*AW_;
}
if(ty_ == BPROP){
shapes_a_.swap(shapes_c_);
AT_ = false;
BT_ = false;
K_ = F_;
M_ = B_*AH_*AW_;
N_ = C_;
}
// memory strides
set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// build LUTs
build_deltas();
}
@@ -57,7 +78,7 @@ void shift::build_deltas() {
// populate look-up table
for(unsigned c = 0; c < TK_; c++)
h_deltas_[c] = offset(c);
for(unsigned c = 0; c < NC_; c++)
for(unsigned c = 0; c < C_; c++)
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
}
@@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, NB_*AH_*AW_);
kernel->setArg(7, NB_);
kernel->setArg(6, B_*AH_*AW_);
kernel->setArg(7, B_);
kernel->setArg(8, AH_);
kernel->setArg(9, AW_);
kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
// dry run
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
void shift::src(std::ostream &os) {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
os <<
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
int32 raw[TM] = rawhc % AW;
int32 rahc[TM] = rawhc / AW;
int32 rah[TM] = rahc % AH;
__constant__ int32* pd[TK] = delta + rka;
multiple_of(4) int32 d[TK] = *pd;
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
__constant__ int32* pd[TK] = delta + rka;
multiple_of(4) int32 d[TK];
d = *pd;
int32 offa1[TK] = rka*lda;
int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
)" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
)" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
)" << a_ty_ << R"( a[TM, TK] = *pa;
)" << b_ty_ << R"( b[TN, TK] = *pb;
int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(;
int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(;
int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda;
)" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false);
)" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N;
)" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa;
)" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
C = dot()" << usea << "," << useb << R"(, C);
pb = pb + TK*N;
pd = pd + TK;
d = *pd;
pa = pa + (mask ? d[newaxis, :] : TK*lda);
int1 checka[TM, TK] = k > TK;
int1 checkb[TN, TK] = k > TK;
inc_true = d)" << bca0 << R"(;
inc_false = TK * lda;
pa = pa + (mask ? inc_true : inc_false);
int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK;
int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK;
@checka a = *pa;
@checkb b = *pb;
}