diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 3f04d01ad..6d4c728a8 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -47,7 +47,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
+  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
   // benchmark triton
   double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
   // benchmark cublas
@@ -77,7 +77,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   std::vector<NumericT> rc(hc.size());
   dot.cpu_ref(rc, ha, hb);
   for(size_t i = 0; i < M*N; i++)
-    if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
+    if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
       std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
       exit(EXIT_FAILURE);
     }
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index bdcb5c62c..453bb87cb 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -49,8 +49,8 @@ class DotOp : public OpKernel {
     triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
     triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
-    triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
-    dot.enqueue(stream, {&da, &db, &dc});
+    triton::dnn::dot dot(M, N, K, false, false, "half", "half", "float", 8, 8, 8);
+    dot.enqueue(stream, {&da, &db, &dc}, triton::dnn::autotuning_t::NO_TUNING);
   }
 
 private:
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 4b1f7ac53..ffdde3f76 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -23,7 +23,7 @@ def run_dot():
   result = sess.run([c], feed_dict = {a: ha,
                                       b: hb})[0]
   # Test
-  hresult = np.dot(ha.T, hb).T
+  hresult = np.dot(ha.T, hb.T).T
   dif = np.abs(result - hresult)
   np.savetxt('dif.dat', dif, '%2.4f')
   print(hresult)
diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h
index 2beeede7b..f36d05db5 100644
--- a/include/triton/dnn/dot.h
+++ b/include/triton/dnn/dot.h
@@ -24,7 +24,7 @@ private:
 
 public:
   dot(int M, int N, int K, bool AT, bool BT,
-      std::string a_ty, std::string b_ty,
+      std::string a_ty, std::string b_ty, std::string c_ty,
      unsigned align_lda, unsigned align_ldb, unsigned align_ldc);
   // number of flops
@@ -42,10 +42,10 @@ public:
                size_t M, size_t N, size_t K){
     for(size_t m = 0; m < M; m++)
     for(size_t n = 0; n < N; n++){
-      T acc = static_cast<T>((double)0);
+      float acc = 0;
       for(size_t k = 0; k < K; k++)
-        acc = acc + (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
-      c[m + n*M] = acc;
+        acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
+      c[m + n*M] = static_cast<T>(acc);
     }
   }
   template<class T>
@@ -68,6 +68,7 @@ private:
   bool BT_;
   std::string a_ty_;
   std::string b_ty_;
+  std::string c_ty_;
   unsigned align_lda_;
   unsigned align_ldb_;
   unsigned align_ldc_;
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index 3b9a2e300..ddad107f0 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -9,11 +9,11 @@ namespace dnn{
 
 dot::dot(int M, int N, int K,
          bool AT, bool BT,
-         std::string a_ty, std::string b_ty,
+         std::string a_ty, std::string b_ty, std::string c_ty,
          unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
   : base("matmul"),
     M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
-    a_ty_(a_ty), b_ty_(b_ty),
+    a_ty_(a_ty), b_ty_(b_ty), c_ty_(c_ty),
     align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
     locks_(nullptr) {
@@ -74,24 +74,33 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
 void dot::triton_c_src(std::ostream &os) const {
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
+  std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1";
+  std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
   std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
   std::string lda0 = "*lda", lda1 = "";
   std::string ldb0 = "", ldb1 = "*ldb";
-  std::string usea = AT_ ? "trans(a)" : "a";
-  std::string useb = BT_ ? "trans(b)" : "b";
+  std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
+  std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
   if(AT_){
     std::swap(AS0, AS1);
+    std::swap(XAS0, XAS1);
+    std::swap(XAS1, XAS2);
     std::swap(bca0, bca1);
     std::swap(lda0, lda1);
   }
   if(BT_){
     std::swap(BS0, BS1);
+    std::swap(XBS1, XBS2);
+    std::swap(XBS0, XBS1);
     std::swap(bcb0, bcb1);
     std::swap(ldb0, ldb1);
   }
   std::string AS = AS0 + ", " + AS1;
   std::string BS = BS0 + ", " + BS1;
+  std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
+  std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
+  std::string XCS = "TM, TN, 1";
   std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
@@ -101,9 +110,10 @@ const tunable int TN = {16, 32, 64, 128};
 const tunable int TK = {32};
 const tunable int GZ = {1};
+
 void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
             restrict read_only align(16) )" + b_ty_ + R"( *B,
-            restrict read_only align(16) float *C,
+            restrict read_only align(16) )" + c_ty_ + R"( *C,
             int M, int N, int K,
             )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
             int bound, int *locks, int grid0, int grid1) {
@@ -113,7 +123,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   int ryb[TN] = ridy * TN + (0 ... TN);
   int rka[TK] = 0 ... TK;
   int rkb[TK] = 0 ... TK;
-  float c[TM, TN] = 0;
+  float xc[)" + XCS + R"(] = 0;
   )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
   )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
   bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
@@ -121,7 +131,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
   for(int k = K; k > 0; k = k - TK){
-    c = dot()" + usea + ", " + useb + R"(, c);
+    )" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
+    )" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
+    xc = dot()" + usea + ", " + useb + R"(, xc);
     pa = pa + TK)" + lda0 + R"(;
     pb = pb + TK)" + ldb0 + R"(;
     bool checka[)" + AS + R"(] = k > TK;
@@ -131,14 +143,13 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   }
   int rxc[TM] = ridx * TM + (0 ... TM);
   int ryc[TN] = ridy * TN + (0 ... TN);
-  bool checkc0[TM] = rxc < M;
-  bool checkc1[TN] = ryc < N;
-  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
-  @checkc *pc = c;
+  )" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
+  )" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
+  *pc = c;
 }
 )";
 
+//  std::cout << res << std::endl;
   os << res;
 }
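
Usage sketch (not part of the diff; the surrounding variables are assumptions for illustration): after this change, triton::dnn::dot takes a third type string, c_ty, for the output, and the updated call sites pass an explicit autotuning mode to enqueue. A minimal caller against the new signature, mirroring the updated TensorFlow op above and assuming a valid triton::driver::stream and device buffers da/db/dc already exist:

    #include "triton/dnn/dot.h"

    // M, N, K: GEMM shape; AT/BT: whether A/B are transposed.
    // a_ty/b_ty are the input element types; c_ty (new) is the output type.
    // The generated kernel accumulates in fp32 (xc) and stores to c_ty.
    triton::dnn::dot dot(M, N, K, /*AT=*/false, /*BT=*/false,
                         /*a_ty=*/"half", /*b_ty=*/"half", /*c_ty=*/"float",
                         /*align_lda=*/8, /*align_ldb=*/8, /*align_ldc=*/8);
    dot.enqueue(stream, {&da, &db, &dc}, triton::dnn::autotuning_t::NO_TUNING);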
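For reference, substituting the strings above for the case AT_ = false, BT_ = false, a_ty_ = b_ty_ = "half", c_ty_ = "float" (so XAS = "TM, TK, 1", XBS = "TK, 1, TN", XCS = "TM, TN, 1"), the new inner loop and epilogue of the generated Triton-C roughly expand to the sketch below. The accumulator is now a 3D fp32 tile reduced with __sum over the extra unit axis before being stored as c_ty; note also that the store is no longer masked by checkc:

    float xc[TM, TN, 1] = 0;
    ...
    for(int k = K; k > 0; k = k - TK){
      half xa[TM, TK, 1] = __reshape(a, TM, TK, 1);
      half xb[TK, 1, TN] = __reshape(b, TK, 1, TN);
      xc = dot(xa, trans(xb, 0, 2, 1), xc);
      pa = pa + TK*lda;
      pb = pb + TK;
      ...
    }
    float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
    float c[TM, TN] = __sum(xc, 2);
    *pc = c;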