[dnn][blocksparse] added dw code
This commit is contained in:
@@ -52,40 +52,56 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *b = args[1];
|
||||
driver::buffer *c = args[2];
|
||||
driver::buffer *lut = args[3];
|
||||
int32_t lda = N_;
|
||||
int32_t ldc = N_;
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, lda);
|
||||
kernel->setArg(4, ldc);
|
||||
kernel->setArg(5, N_);
|
||||
kernel->setArg(6, lut);
|
||||
kernel->setArg(7, locks_.get());
|
||||
kernel->setArg(8, nlocks_);
|
||||
int32_t TM = info.globals["TM"];
|
||||
size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
size_t grid_1 = S_;
|
||||
if(nlocks_)
|
||||
((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
kernel->setArg(3, N_);
|
||||
kernel->setArg(4, BS_);
|
||||
kernel->setArg(5, N_);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(3, N_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, BS_);
|
||||
}
|
||||
kernel->setArg(6, N_);
|
||||
kernel->setArg(7, lut);
|
||||
kernel->setArg(8, locks_.get());
|
||||
kernel->setArg(9, nlocks_);
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
int32_t TM = info.globals["TM"];
|
||||
size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
size_t grid_1 = S_;
|
||||
if(nlocks_)
|
||||
((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
|
||||
}
|
||||
else{
|
||||
size_t grid_0 = nblocks_;
|
||||
stream->enqueue(kernel, {grid_0, 1, 1}, {info.num_threads, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
driver::buffer* dot::get_locks() const {
|
||||
return locks_.get();
|
||||
}
|
||||
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string usea = (op_ == WGRAD) ? "trans(a)" : "a";
|
||||
std::string useb = (op_ == FPROP) ? "trans(b)" : "b";
|
||||
std::string dot::triton_c_src_ydx() const {
|
||||
bool AT = (op_ == WGRAD);
|
||||
bool BT = (op_ == FPROP);
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string sizea = "TM, TK";
|
||||
std::string sizeb = (op_ == FPROP) ? "TN, TK" : "TK, TN";
|
||||
std::string sizeb = BT ? "TN, TK" : "TK, TN";
|
||||
std::string bca0 = ":, newaxis";
|
||||
std::string bca1 = "newaxis, :";
|
||||
std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
|
||||
std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
|
||||
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
|
||||
std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
|
||||
std::string lda0 = AT ? "*lda" : "";
|
||||
std::string lda1 = AT ? "" : "*lda";
|
||||
std::string ldb0 = BT ? "" : "*ldb";
|
||||
std::string ldb1 = BT ? "*ldb" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int TM = {16, 32, 64, 128};
|
||||
@@ -95,26 +111,25 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
)" + c_ty_ + R"(* C,
|
||||
int lda, int ldc, int N,
|
||||
int* lut, int* locks, int nlocks) {
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
float acc[TM, TN] = 0;
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
bool checka[TM, TK] = (rxa < N)[:, newaxis];
|
||||
int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
|
||||
int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
int *header = lut + ridy * 4;
|
||||
int *header = lut + get_range_id(1) * 4;
|
||||
int offset = *(header + 0);
|
||||
int K = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
int lockid = *(header + 3);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = 0 ... TN;
|
||||
int *plut = lut + offset * 2;
|
||||
for(int k = K; k > 0; k = k - 1)
|
||||
{
|
||||
int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
|
||||
int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
bool checka[TM, TK] = (rxa < N)[:, newaxis];
|
||||
for(int k = K; k > 0; k = k - 1) {
|
||||
int ak = *(plut + 0);
|
||||
int bk = *(plut + 1);
|
||||
)" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
|
||||
@@ -137,17 +152,83 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
int *pcount = plock + get_num_program(0)*nlocks;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0){
|
||||
if(count == 0)
|
||||
@checkc *pc = c;
|
||||
}
|
||||
else{
|
||||
else
|
||||
@checkc *pc = c + *pc;
|
||||
}
|
||||
__atomic_exch(pcount, 1);
|
||||
__atomic_exch(plock, 0);
|
||||
}
|
||||
})";
|
||||
os << result;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string dot::triton_c_src_dw() const {
|
||||
bool AT = (op_ == WGRAD);
|
||||
bool BT = (op_ == FPROP);
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string sizea = AT ? "TK, TM" : "TM, TK";
|
||||
std::string sizeb = BT ? "TN, TK" : "TK, TN";
|
||||
std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
|
||||
std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
|
||||
std::string lda0 = AT ? "*lda" : "";
|
||||
std::string lda1 = AT ? "" : "*lda";
|
||||
std::string ldb0 = BT ? "" : "*ldb";
|
||||
std::string ldb1 = BT ? "*ldb" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int TM = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int TN = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int TK = {32};
|
||||
|
||||
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
)" + c_ty_ + R"(* C,
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_range_id(0);
|
||||
float acc[TM, TN] = 0;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int *header = lut + ridx * 2;
|
||||
int offx = *(header + 0);
|
||||
int offy = *(header + 1);
|
||||
int rxa[TM] = offx*TM + (0 ... TM);
|
||||
int ryb[TN] = offy*TN + (0 ... TN);
|
||||
bool checka[TK, TM] = (rka < N)[:, newaxis];
|
||||
bool checkb[TK, TN] = (rkb < N)[:, newaxis];
|
||||
int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
|
||||
int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
)" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
|
||||
)" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
|
||||
)" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
|
||||
)" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
|
||||
for(int k = N; k > 0; k = k - TK) {
|
||||
acc = dot()" + usea + ", " + useb + R"(, acc);
|
||||
pa = pa + TK)" + lda1 + R"(;
|
||||
pb = pb + TK)" + ldb1 + R"(;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
int rxc[TM] = (0 ... TM);
|
||||
int ryc[TN] = (0 ... TN);
|
||||
)" + c_ty_ + R"( c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
|
||||
*pc = c;
|
||||
})";
|
||||
|
||||
return result;
|
||||
}
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
if(op_ == FPROP || op_ == BPROP)
|
||||
os << triton_c_src_ydx();
|
||||
else
|
||||
os << triton_c_src_dw();
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user