#include #include #include #include #include "triton/driver/backend.h" #include "triton/driver/stream.h" #include "triton/tools/bench.hpp" #include "triton/external/half.hpp" #include "triton/runtime/function.h" #include "src/dot.h" #include "cuda/cublas.h" #include "util.h" template static void cc_dot(std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K){ for(size_t m = 0; m < M; m++) for(size_t n = 0; n < N; n++){ float acc = 0; for(size_t k = 0; k < K; k++) acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]); c[m*N + n] = static_cast(acc); } } template void cc_dot(bool AT_, bool BT_, size_t M, size_t N, size_t K, std::vector &c, const std::vector &a, const std::vector &b) { if(AT_ && BT_) cc_dot(c, a, b, M, N, K); else if(AT_ && !BT_) cc_dot(c, a, b, M, N, K); else if(!AT_ && BT_) cc_dot(c, a, b, M, N, K); else cc_dot(c, a, b, M, N, K); } enum run_mode_t { BENCH, TEST }; enum dtype_t { FLOAT, HALF, DOUBLE }; template struct to_string; template<> struct to_string{ static constexpr const char* value = "half"; }; template<> struct to_string{ static constexpr const char* value = "float"; }; template<> struct to_string{ static constexpr const char* value = "double"; }; template bool triton_dot(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, int32_t nwarp, const std::vector& a_order, const std::vector& b_order, run_mode_t mode, std::vector& bench, bool &test){ std::string ty = to_string::value; size_t dt_nbytes = sizeof(T); drv::context* context = stream->context(); int32_t lda = (AT ^ a_order[0]==1) ? K : M; int32_t ldb = (BT ^ b_order[0]==1) ? N : K; int32_t ldc = N; std::vector sa = { "1", "lda" }; std::vector sb = { "1", "ldb" }; // inputs auto dc = std::shared_ptr(drv::buffer::create(context, M*N*dt_nbytes)); auto da = std::shared_ptr(drv::buffer::create(context, M*K*dt_nbytes)); auto db = std::shared_ptr(drv::buffer::create(context, K*N*dt_nbytes)); // macros rt::function::options_space_t opt; // A access patterns opt.defines.push_back({"USEA", {AT? "a" : "a" }}); opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }}); opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }}); opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }}); opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); // B access patterns opt.defines.push_back({"USEB", {BT? "b" : "b" }}); opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }}); opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }}); opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }}); opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }}); opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }}); // data-type opt.defines.push_back({"TYPE", {ty}}); // tile sizes if(mode == TEST) { opt.defines.push_back({"TM", {std::to_string(TM)}}); opt.defines.push_back({"TN", {std::to_string(TN)}}); opt.defines.push_back({"TK", {std::to_string(TK)}}); opt.num_warps = {nwarp}; } if(mode == BENCH) { opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TN", {"32"}}); opt.defines.push_back({"TK", {to_string::value == "half" ? "16" : "8"}}); opt.num_warps = {4}; } // kernels rt::function function(src::dot, opt); std::vector args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc}; auto grid = grid2d(M, N); // metrics if(mode == BENCH){ auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); bench.push_back(tflops(triton_ns)); // // cublas // if(cublas::cublasinit()){ // T alpha(static_cast(1)); // T beta(static_cast(0)); // cublasGemmAlgo_t fastest; // cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); // double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, // &alpha, &*da, lda, &*db, ldb, &beta, &*dc, // ldc, nullptr, fastest); }, stream); // bench.push_back(tflops(cublas_ms)); // } } // test triton if(mode == TEST){ srand(0); // initialize buffers std::vector hc(M*N); std::vector ha(M*K); std::vector hb(K*N); for(size_t i = 0; i < ha.size(); i++) ha[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hb.size(); i++) hb[i] = (float)rand()/RAND_MAX; // copy buffer stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); // run kernel function(args, grid, stream); // write back stream->synchronize(); // compare with CPU stream->read(&*dc, true, 0, hc); std::vector rc(hc.size()); cc_dot(AT, BT, M, N, K, rc, ha, hb); test = testing::diff(hc, rc); } } std::vector bench_dot(drv::stream* stream, dtype_t dtype, bool AT, bool BT, int32_t M, int32_t N, int32_t K, const std::vector& a_order, const std::vector& b_order) { std::vector bench; bool test; switch(dtype){ case HALF: triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; case FLOAT: triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; case DOUBLE: triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break; default: break; } return bench; } bool test_dot(drv::stream* stream, dtype_t dtype, bool AT, bool BT, int32_t M, int32_t N, int32_t K, const std::vector& a_order, const std::vector& b_order, int32_t TM, int32_t TN, int32_t TK, size_t nwarp) { std::vector bench; bool test = false; switch(dtype){ case HALF: triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; case FLOAT: triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; case DOUBLE: triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break; default: break; } return test; }