more cleaning
This commit is contained in:
@@ -68,7 +68,7 @@ int main() {
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
4, 2, 32, 8, 2, 32, 8, 4, 2, 2, 8, 8, 4
|
||||
4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
|
@@ -43,34 +43,24 @@ shift::shift(int B, int NC,
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
// build LUTs
|
||||
build_deltas();
|
||||
build_masks();
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
|
||||
};
|
||||
// allocate look-up table
|
||||
size_t TK = 8;
|
||||
h_deltas_ = std::vector<int32_t>(512, 0);
|
||||
for(unsigned c = 0; c < NC_; c++){
|
||||
h_deltas_[c] = c*ld_a_[0];
|
||||
h_deltas_[c] += shift_h_[c]*ld_a_[1];
|
||||
h_deltas_[c] += shift_w_[c]*ld_a_[2];
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK; c++){
|
||||
h_deltas_[c] = offset(c); // init (shift)
|
||||
h_deltas_[c + 256] = c*ld_a_[0]; // init (no shift)
|
||||
}
|
||||
for(unsigned c = 0; c < NC_; c++){
|
||||
h_deltas_[c + 256] = c*ld_a_[0];
|
||||
}
|
||||
}
|
||||
|
||||
void shift::build_masks() {
|
||||
size_t S0 = NC_;
|
||||
size_t S1 = BH_;
|
||||
size_t S2 = BW_;
|
||||
h_masks_.resize(S0*S1*S2);
|
||||
for(size_t ph = 0; ph < S1; ++ph)
|
||||
for(size_t pw = 0; pw < S2; ++pw){
|
||||
int32_t* ptr = &h_masks_[ph*S0 + pw*S0*S1];
|
||||
for(size_t i = 0; i < S0; ++i){
|
||||
bool in_bounds_h = shift_h_[i] + ph >= 0 && shift_h_[i] + ph < BH_;
|
||||
bool in_bounds_w = shift_w_[i] + pw >= 0 && shift_w_[i] + pw < BW_;
|
||||
ptr[i] = in_bounds_h && in_bounds_w;
|
||||
}
|
||||
h_deltas_[TK + c] = offset(c + TK) - offset(c); // deltas (shift)
|
||||
h_deltas_[TK + c + 256] = TK*ld_a_[0]; // deltas (shift)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,9 +90,7 @@ size_t shift::get_nflops() {
|
||||
|
||||
void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
|
||||
triton::driver::buffer* masks = ((triton::driver::cu_module*)module)->symbol("masks");
|
||||
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
|
||||
stream->write(masks, false, 0, h_masks_.size()*4, h_masks_.data());
|
||||
}
|
||||
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
@@ -132,9 +120,10 @@ const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[512];
|
||||
__constant__ int32* masks = alloc_const int32[8192];
|
||||
|
||||
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
void shift(restrict read_only align(16) fp32 *a,
|
||||
restrict read_only align(16) fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
@@ -142,7 +131,6 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
int32 pad_h = AR/2;
|
||||
int32 pad_w = AS/2;
|
||||
@@ -152,16 +140,17 @@ void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 rah[TM] = rahc % AH;
|
||||
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
||||
int32 offd[TM] = (maskh && maskw) ? 0 : 256;
|
||||
int1 mask[TM] = maskh && maskw;
|
||||
int32 offd[TM] = mask ? 0 : 256;
|
||||
__constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis];
|
||||
fp32* pa[TM, TK] = a + rxa[:, newaxis] + (*pd);
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 delta[TM, TK] = *pd;
|
||||
fp32 *pa[TM, TK] = pxa + delta;
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
pa = pa + (*pd);
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
|
Reference in New Issue
Block a user