fixed simple FP16 test
This commit is contained in:
@@ -9,11 +9,11 @@ namespace dnn{
|
||||
|
||||
dot::dot(int M, int N, int K,
|
||||
bool AT, bool BT,
|
||||
std::string a_ty, std::string b_ty,
|
||||
std::string a_ty, std::string b_ty, std::string c_ty,
|
||||
unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
|
||||
: base("matmul"),
|
||||
M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
a_ty_(a_ty), b_ty_(b_ty), c_ty_(c_ty),
|
||||
align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
|
||||
locks_(nullptr) {
|
||||
|
||||
@@ -74,24 +74,33 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1";
|
||||
std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
|
||||
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(XAS1, XAS2);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS1, XBS2);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN, 1";
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
@@ -101,9 +110,10 @@ const tunable int TN = {16, 32, 64, 128};
|
||||
const tunable int TK = {32};
|
||||
const tunable int GZ = {1};
|
||||
|
||||
|
||||
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty_ + R"( *B,
|
||||
restrict read_only align(16) float *C,
|
||||
restrict read_only align(16) )" + c_ty_ + R"( *C,
|
||||
int M, int N, int K,
|
||||
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
|
||||
int bound, int *locks, int grid0, int grid1) {
|
||||
@@ -113,7 +123,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float c[TM, TN] = 0;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
||||
@@ -121,7 +131,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
|
||||
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
bool checka[)" + AS + R"(] = k > TK;
|
||||
@@ -131,14 +143,13 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
@checkc *pc = c;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
// std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user