From 6c39cdbace56909199e521a4eabccb5b409ab239 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 6 Aug 2019 16:48:53 -0700 Subject: [PATCH] making sure changes didn't break HMMA --- lib/dnn/dot.cpp | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp index 65395695c..3b9a2e300 100644 --- a/lib/dnn/dot.cpp +++ b/lib/dnn/dot.cpp @@ -74,33 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, void dot::triton_c_src(std::ostream &os) const { std::string AS0 = "TM", AS1 = "TK"; std::string BS0 = "TK", BS1 = "TN"; - std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1"; - std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string lda0 = "*lda", lda1 = ""; std::string ldb0 = "", ldb1 = "*ldb"; - std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa"; - std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)"; + std::string usea = AT_ ? "trans(a)" : "a"; + std::string useb = BT_ ? "trans(b)" : "b"; if(AT_){ std::swap(AS0, AS1); - std::swap(XAS0, XAS1); - std::swap(XAS1, XAS2); std::swap(bca0, bca1); std::swap(lda0, lda1); } if(BT_){ std::swap(BS0, BS1); - std::swap(XBS1, XBS2); - std::swap(XBS0, XBS1); std::swap(bcb0, bcb1); std::swap(ldb0, ldb1); } std::string AS = AS0 + ", " + AS1; std::string BS = BS0 + ", " + BS1; - std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2; - std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2; - std::string XCS = "TM, TN, 1"; std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string res = @@ -110,7 +101,6 @@ const tunable int TN = {16, 32, 64, 128}; const tunable int TK = {32}; const tunable int GZ = {1}; - void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, restrict read_only align(16) )" + b_ty_ + R"( *B, restrict read_only align(16) float *C, @@ -123,7 +113,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, int ryb[TN] = ridy * TN + (0 ... TN); int rka[TK] = 0 ... TK; int rkb[TK] = 0 ... TK; - float xc[)" + XCS + R"(] = 0; + float c[TM, TN] = 0; )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(; @@ -131,9 +121,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0; )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0; for(int k = K; k > 0; k = k - TK){ - )" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"(); - )" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"(); - xc = dot()" + usea + ", " + useb + R"(, xc); + c = dot()" + usea + ", " + useb + R"(, c); pa = pa + TK)" + lda0 + R"(; pb = pb + TK)" + ldb0 + R"(; bool checka[)" + AS + R"(] = k > TK; @@ -143,13 +131,14 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, } int rxc[TM] = ridx * TM + (0 ... TM); int ryc[TN] = ridy * TN + (0 ... TN); + bool checkc0[TM] = rxc < M; + bool checkc1[TN] = ryc < N; + bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - float c[TM, TN] = __sum(xc, 2); - *pc = c; + @checkc *pc = c; } )"; -// std::cout << res << std::endl; os << res; }