This commit is contained in:
Philippe Tillet
2019-07-16 12:59:27 -07:00
parent f50d7a420a
commit 7d1797cd32
6 changed files with 27 additions and 28 deletions

View File

@@ -223,6 +223,7 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
driver::buffer *a = args[0], *b = args[1], *c = args[2];
// std::cout << op_ << " " << M_ << " " << N_ << " " << K_ << std::endl;
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
@@ -253,6 +254,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
kernel->setArg(28, (int32_t)grid[0]);
kernel->setArg(29, (int32_t)grid[1]);
kernel->setArg(30, (int32_t)grid[2]);
if(locks_)
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
if(op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
@@ -317,7 +321,7 @@ const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
if(op_ == WGRAD)
result += "const tunable int32 GZ = {1};";
result += "const tunable int32 GZ = {1, 4, 16};";
else
result += "const tunable int32 GZ = {1};";
@@ -329,14 +333,14 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
)" + c_ty_ + R"( *C,
int32 M, int32 N, int32 K,
int32 stride_h, int32 stride_w,
int32 lda_b, int32 lda_w, int32 lda_h, int32 lda_c,
int32 ldb_b, int32 ldb_w, int32 ldb_h, int32 ldb_c,
int32 ldc_b, int32 ldc_w, int32 ldc_h, int32 ldc_c,
multiple_of(4) int32 lda_b, multiple_of(4) int32 lda_w, multiple_of(4) int32 lda_h, multiple_of(4) int32 lda_c,
multiple_of(4) int32 ldb_b, multiple_of(4) int32 ldb_w, multiple_of(4) int32 ldb_h, multiple_of(4) int32 ldb_c,
multiple_of(4) int32 ldc_b, multiple_of(4) int32 ldc_w, multiple_of(4) int32 ldc_h, multiple_of(4) int32 ldc_c,
int32 NB,
int32 AH, int32 AW,
int32 BH, int32 BW,
int32 CH, int32 CW,
int32* locks, int32 grid0, int32 grid1) {
int32* locks, int32 grid0, int32 grid1, int32 grid2) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
@@ -345,9 +349,8 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2;
int32 pad_w = BW / 2;
int32 split = select(locks == 0, 1, GZ);
int32 div = K / split;
int32 rem = K % split;
int32 div = K / grid2;
int32 rem = K % grid2;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
if(op_ == WGRAD){
@@ -366,7 +369,7 @@ if(op_ == FPROP){
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis];
__constant__ int32* pd[TK] = delta_a + rka;
int32 d[TK] = *pd;
multiple_of(4) int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :];
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
)" + compute_interior("ra", "TM", "TK") + R"(
@@ -524,18 +527,17 @@ else{
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
while(__atomic_cas(plock, 0, 1) == 1);
int32 count = *pcount;
int32 countp1 = select(count == split - 1, 0, count + 1);
int32 countp1 = select(count == grid2 - 1, 0, count + 1);
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
*pcount = countp1;
*plock = 0;
}
else{