some cleaning
This commit is contained in:
@@ -27,7 +27,7 @@ shift::shift(int B, int C,
|
||||
layout_(layout){
|
||||
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||
// max number of channels
|
||||
TK_ = 16;
|
||||
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 16;
|
||||
MAX_C_ = 8192 + TK_;
|
||||
// activation sizes
|
||||
CD_ = AD_ / stride_d_;
|
||||
@@ -53,8 +53,8 @@ shift::shift(int B, int C,
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
// Shift edge
|
||||
shift_edge_h_ = (AH_ == stride_h_);
|
||||
shift_edge_w_ = (AW_ == stride_w_);
|
||||
shift_edge_h_ = (AH_ == stride_h_ && stride_h_ > 1);
|
||||
shift_edge_w_ = (AW_ == stride_w_ && stride_w_ > 1);
|
||||
// B memory strides: [C, F]
|
||||
ldb_n_ = 1;
|
||||
ldb_h_ = 1;
|
||||
@@ -132,8 +132,8 @@ base* shift::clone() const {
|
||||
|
||||
void shift::build_delta_a() {
|
||||
h_delta_a.resize(MAX_C_);
|
||||
auto shift_h = [&](int c) { return shift_edge_h_ ? std::max(0, shift_h_[c]) : shift_h_[c]; };
|
||||
auto shift_w = [&](int c) { return shift_edge_w_ ? std::max(0, shift_w_[c]) : shift_w_[c]; };
|
||||
auto shift_h = [&](int c) { return shift_edge_h_ ? (c / AH_) % AH_ : shift_h_[c]; };
|
||||
auto shift_w = [&](int c) { return shift_edge_w_ ? c % AW_ : shift_w_[c]; };
|
||||
if(op_ == FPROP){
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
@@ -253,23 +253,24 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
|
||||
os <<
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {)" << TK_ << R"(};
|
||||
const tunable int32 TK = {)" + std::to_string(TK_) + R"(};
|
||||
|
||||
__constant__ int32* delta_a = alloc_const int32[)" << MAX_C_ << R"(];
|
||||
__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(];
|
||||
|
||||
void shift(restrict read_only align(16) )" << a_ty_ << R"( *A,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *B,
|
||||
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty_ + R"( *B,
|
||||
fp32 *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
int32 lda_b, int32 lda_w, int32 lda_h, int32 lda_c,
|
||||
int32 ldb_b, int32 ldb_w, int32 ldb_h, int32 ldb_c,
|
||||
int32 ldc_b, int32 ldc_w, int32 ldc_h, int32 ldc_c,
|
||||
int32 NB, int32 AH, int32 AW,
|
||||
int32 NB,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
@@ -282,31 +283,34 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *A,
|
||||
|
||||
/* A offsets */
|
||||
if(op_ == FPROP){
|
||||
os << R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = (rawh % CW) * stride_w;
|
||||
int32 rah[TM] = (rawh / CW) * stride_h;
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;
|
||||
raw = raw * stride_w;
|
||||
rah = rah * stride_h;
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta_a + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;\n)";
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
os << " int1 interiorh[TM] = 1;";
|
||||
result += " int1 interiorh[TM] = 1;\n";
|
||||
else
|
||||
os << " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));";
|
||||
result += " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
os << " int1 interiorw[TM] = 1;";
|
||||
result += " int1 interiorw[TM] = 1;";
|
||||
else
|
||||
os << " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
|
||||
os << R"(
|
||||
result += " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = (rawh % CW);
|
||||
@@ -316,12 +320,12 @@ if(op_ == BPROP){
|
||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == CHWN){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
@@ -333,34 +337,37 @@ if(op_ == WGRAD && layout_ == NCHW){
|
||||
|
||||
/* B offsets */
|
||||
if(op_ == FPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 offb0[TN, TK] = ryb[:, newaxis];
|
||||
int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
os << R"(
|
||||
result += R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = (rbwh % CW)*stride_w;
|
||||
int32 rbh[TK] = (rbwh / CW)*stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
os << " int1 interiorh[TK] = 1;\n";
|
||||
result += " int1 interiorh[TK] = 1;\n";
|
||||
else
|
||||
os << " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
result += " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
os << " int1 interiorw[TK] = 1;\n";
|
||||
result += " int1 interiorw[TK] = 1;";
|
||||
else
|
||||
os << " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
|
||||
os << R"(
|
||||
result += " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
int32 incb[TK, TN] = interior ? shift : 0;
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
@@ -368,21 +375,21 @@ if(op_ == WGRAD){
|
||||
}
|
||||
|
||||
/* Main loop */
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << R"(] = A + offa0 + offa1;
|
||||
)" << b_ty_ << "* pb[" << BS << R"(] = B + offb0 + offb1;
|
||||
int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(;
|
||||
int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0;
|
||||
)" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0;
|
||||
result += R"(
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
||||
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
|
||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
c = dot()" << usea << "," << useb << R"(, c);
|
||||
int1 checka[)" << AS << R"(] = k > TK;
|
||||
int1 checkb[)" << BS << R"(] = k > TK;)";
|
||||
c = dot()" + usea + "," + useb + R"(, c);
|
||||
int1 checka[)" + AS + R"(] = k > TK;
|
||||
int1 checkb[)" + BS + R"(] = k > TK;)";
|
||||
|
||||
/* Increment A pointers */
|
||||
if(op_ == FPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
offa_interior = d[newaxis, :];
|
||||
@@ -391,15 +398,15 @@ if(op_ == FPROP){
|
||||
pa = pa + offa;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
pa = pa + TK * lda_c;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == CHWN){
|
||||
os << R"(
|
||||
result += R"(
|
||||
pa = pa + TK;)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
os << R"(
|
||||
result += R"(
|
||||
rka = rka + TK;
|
||||
rawh = rka / NB;
|
||||
rab = rka % NB;
|
||||
@@ -408,40 +415,43 @@ if(op_ == WGRAD && layout_ == NCHW){
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
}
|
||||
os << R"(
|
||||
result += R"(
|
||||
@checka a = *pa;)";
|
||||
|
||||
/* Increment B pointers */
|
||||
if(op_ == WGRAD){
|
||||
os << R"(
|
||||
result += R"(
|
||||
rkb = rkb + TK;
|
||||
rbwh = rkb / NB;
|
||||
rbb = rkb % NB;
|
||||
rbw = (rbwh % CW)*stride_w;
|
||||
rbh = (rbwh / CW)*stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
|
||||
rbw = rbwh % CW;
|
||||
rbh = rbwh / CW;
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
os << " interiorh = 1;\n";
|
||||
result += " interiorh = 1;\n";
|
||||
else
|
||||
os << " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
result += " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
os << " interiorw = 1;\n";
|
||||
result += " interiorw = 1;";
|
||||
else
|
||||
os << " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
|
||||
os << R"(
|
||||
result += " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));";
|
||||
result += R"(
|
||||
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
incb = interior ? shift : 0;
|
||||
pb = B + offb0 + offkb[:, newaxis] + incb;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
pb = pb + TK * ldb_c;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
pb = pb + TK;)";
|
||||
}
|
||||
os << R"(
|
||||
result += R"(
|
||||
@checkb b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
@@ -449,44 +459,41 @@ if(op_ == BPROP){
|
||||
|
||||
/* C offsets */
|
||||
if(op_ == BPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = (rcwh % CW) * stride_w;
|
||||
int32 rch[TM] = (rcwh / CW) * stride_h;
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
||||
)";
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
os << R"(
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = (rcwh % CW);
|
||||
int32 rch[TM] = (rcwh / CW);
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
||||
)";
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
os << R"(
|
||||
int32 offxc[TM] = rxc;
|
||||
)";
|
||||
result += R"(
|
||||
int32 offxc[TM] = rxc;)";
|
||||
}
|
||||
os << R"("
|
||||
result += R"("
|
||||
fp32* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
if(op_ == BPROP){
|
||||
os << "\n";
|
||||
result += "\n";
|
||||
if(shift_edge_h_)
|
||||
os << " int1 interiorh[TM] = 1;\n";
|
||||
result += " int1 interiorh[TM] = 1;\n";
|
||||
else
|
||||
os << " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
|
||||
result += " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
os << " int1 interiorw[TM] = 1;\n";
|
||||
result += " int1 interiorw[TM] = 1;";
|
||||
else
|
||||
os << " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));\n";
|
||||
os << R"(
|
||||
result += " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
@@ -495,11 +502,13 @@ if(op_ == BPROP){
|
||||
)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
result += R"(
|
||||
@checkc *pc = c;)";
|
||||
}
|
||||
os << R"(
|
||||
result += R"(
|
||||
})";
|
||||
|
||||
os << result;
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user