ugh
This commit is contained in:
@@ -223,6 +223,7 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
|
||||
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
// std::cout << op_ << " " << M_ << " " << N_ << " " << K_ << std::endl;
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
@@ -253,6 +254,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
|
||||
kernel->setArg(28, (int32_t)grid[0]);
|
||||
kernel->setArg(29, (int32_t)grid[1]);
|
||||
kernel->setArg(30, (int32_t)grid[2]);
|
||||
if(locks_)
|
||||
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
|
||||
if(op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
||||
@@ -317,7 +321,7 @@ const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
|
||||
if(op_ == WGRAD)
|
||||
result += "const tunable int32 GZ = {1};";
|
||||
result += "const tunable int32 GZ = {1, 4, 16};";
|
||||
else
|
||||
result += "const tunable int32 GZ = {1};";
|
||||
|
||||
@@ -329,14 +333,14 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
)" + c_ty_ + R"( *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
int32 lda_b, int32 lda_w, int32 lda_h, int32 lda_c,
|
||||
int32 ldb_b, int32 ldb_w, int32 ldb_h, int32 ldb_c,
|
||||
int32 ldc_b, int32 ldc_w, int32 ldc_h, int32 ldc_c,
|
||||
multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c,
|
||||
multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c,
|
||||
multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c,
|
||||
int32 NB,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32* locks, int32 grid0, int32 grid1) {
|
||||
int32* locks, int32 grid0, int32 grid1, int32 grid2) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rz = get_global_range[1](2);
|
||||
@@ -345,9 +349,8 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;
|
||||
int32 split = select(locks == 0, 1, GZ);
|
||||
int32 div = K / split;
|
||||
int32 rem = K % split;
|
||||
int32 div = K / grid2;
|
||||
int32 rem = K % grid2;
|
||||
K = select(rz < rem, div - 1, div);
|
||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
|
||||
if(op_ == WGRAD){
|
||||
@@ -366,7 +369,7 @@ if(op_ == FPROP){
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta_a + rka;
|
||||
int32 d[TK] = *pd;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
||||
)" + compute_interior("ra", "TM", "TK") + R"(
|
||||
@@ -524,18 +527,17 @@ else{
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
int32 *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 *pcount = plock + grid0*grid1;
|
||||
while(__atomic_cas(plock, 0, 1) == 1);
|
||||
int32 count = *pcount;
|
||||
int32 countp1 = select(count == split - 1, 0, count + 1);
|
||||
int32 countp1 = select(count == grid2 - 1, 0, count + 1);
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
*pcount = countp1;
|
||||
*plock = 0;
|
||||
}
|
||||
else{
|
||||
|
Reference in New Issue
Block a user