|
|
|
@@ -51,6 +51,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|
|
|
|
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
|
|
|
|
unsigned TM = info.globals.at("TM");
|
|
|
|
|
unsigned TN = info.globals.at("TN");
|
|
|
|
|
unsigned TK = info.globals.at("TK");
|
|
|
|
|
unsigned grid_0 = (M_ + TM - 1)/TM;
|
|
|
|
|
unsigned grid_1 = (N_ + TN - 1)/TN;
|
|
|
|
|
unsigned grid_2 = 1;
|
|
|
|
@@ -67,23 +68,13 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|
|
|
|
kernel->setArg(6, lda);
|
|
|
|
|
kernel->setArg(7, ldb);
|
|
|
|
|
kernel->setArg(8, ldc);
|
|
|
|
|
kernel->setArg(9, locks_);
|
|
|
|
|
kernel->setArg(10, grid_0);
|
|
|
|
|
kernel->setArg(11, grid_1);
|
|
|
|
|
kernel->setArg(9, TK);
|
|
|
|
|
kernel->setArg(10, locks_);
|
|
|
|
|
kernel->setArg(11, grid_0);
|
|
|
|
|
kernel->setArg(12, grid_1);
|
|
|
|
|
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<unsigned> gemm::default_params() {
|
|
|
|
|
if(AT_ && BT_)
|
|
|
|
|
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
|
|
|
|
else if(AT_ && !BT_)
|
|
|
|
|
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
|
|
|
|
else if(!AT_ && BT_)
|
|
|
|
|
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
|
|
|
|
|
else
|
|
|
|
|
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void gemm::triton_c_src(std::ostream &os) const {
|
|
|
|
|
std::string AS0 = "TM", AS1 = "TK";
|
|
|
|
|
std::string BS0 = "TK", BS1 = "TN";
|
|
|
|
@@ -103,12 +94,14 @@ void gemm::triton_c_src(std::ostream &os) const {
|
|
|
|
|
std::swap(bcb0, bcb1);
|
|
|
|
|
std::swap(ldb0, ldb1);
|
|
|
|
|
}
|
|
|
|
|
std::string AS = AS0 + ", " + AS1;
|
|
|
|
|
std::string BS = BS0 + ", " + BS1;
|
|
|
|
|
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
|
|
|
|
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
|
|
|
|
std::string res =
|
|
|
|
|
R"(
|
|
|
|
|
const tunable int32 TM = {16, 32, 64, 128};
|
|
|
|
|
const tunable int32 TN = {16, 32, 64, 128};
|
|
|
|
|
const tunable int32 TM = {32, 64, 128, 256};
|
|
|
|
|
const tunable int32 TN = {32, 64, 128, 256};
|
|
|
|
|
const tunable int32 TK = {32};
|
|
|
|
|
const tunable int32 GZ = {1};
|
|
|
|
|
|
|
|
|
@@ -117,27 +110,36 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
|
|
|
|
fp32 *C,
|
|
|
|
|
int32 M, int32 N, int32 K,
|
|
|
|
|
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
|
|
|
|
|
int32 *locks, int32 grid0, int32 grid1) {
|
|
|
|
|
int32 rxa[TM] = get_global_range[TM](0);
|
|
|
|
|
int32 ryb[TN] = get_global_range[TN](1);
|
|
|
|
|
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
|
|
|
|
|
int32 ridx = get_range_id(0);
|
|
|
|
|
int32 ridy = get_range_id(1);
|
|
|
|
|
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
|
|
|
|
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
|
|
|
|
int32 rka[TK] = 0 ... TK;
|
|
|
|
|
int32 rkb[TK] = 0 ... TK;
|
|
|
|
|
fp32 c[TM, TN] = 0;
|
|
|
|
|
)" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
|
|
|
|
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
|
|
|
|
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
|
|
|
|
|
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
|
|
|
|
|
for(int32 k = K; k > TK; k = k - TK){
|
|
|
|
|
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
|
|
|
|
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
|
|
|
|
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
|
|
|
|
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
|
|
|
|
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
|
|
|
|
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
|
|
|
|
for(int32 k = K; k > 0; k = k - TK){
|
|
|
|
|
c = dot()" + usea + ", " + useb + R"(, c);
|
|
|
|
|
pa = pa + TK)" + lda0 + R"(;
|
|
|
|
|
pb = pb + TK)" + ldb0 + R"(;
|
|
|
|
|
a = *pa;
|
|
|
|
|
b = *pb;
|
|
|
|
|
int1 checka[)" + AS + R"(] = k > bound;
|
|
|
|
|
int1 checkb[)" + BS + R"(] = k > bound;
|
|
|
|
|
@checka a = *pa;
|
|
|
|
|
@checkb b = *pb;
|
|
|
|
|
}
|
|
|
|
|
int32 rxc[TM] = get_global_range[TM](0);
|
|
|
|
|
int32 ryc[TN] = get_global_range[TN](1);
|
|
|
|
|
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
|
|
|
|
int32 ryc[TN] = ridy*TN + (0 ... TN);
|
|
|
|
|
int1 checkc0[TM] = rxc < M;
|
|
|
|
|
int1 checkc1[TN] = ryc < N;
|
|
|
|
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
|
|
|
|
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
|
|
|
|
*pc = c;
|
|
|
|
|
@checkc *pc = c;
|
|
|
|
|
}
|
|
|
|
|
)";
|
|
|
|
|
os << res;
|
|
|
|
|