diff --git a/tests/bench/conv.cc b/tests/bench/conv.cc index 078029473..67aa090c4 100644 --- a/tests/bench/conv.cc +++ b/tests/bench/conv.cc @@ -19,9 +19,9 @@ int main() { // {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1}, // {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1}, -// {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1}, + {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1}, // {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1}, - {1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1}, +// {1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1}, // {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1} diff --git a/tests/common/conv.h b/tests/common/conv.h index 3101f396f..d2abdd718 100644 --- a/tests/common/conv.h +++ b/tests/common/conv.h @@ -11,34 +11,6 @@ #include "cuda/cublas.h" #include "util.h" - -struct conv_arg_t{ - CUdeviceptr a; - CUdeviceptr b; - CUdeviceptr c; - float alpha; - int M; - int N; - int K; - int pad_h; - int pad_w; - int stride_h; - int stride_w; - CUdeviceptr adelta; - int lda_z; - int lda_ci; - int lda_h; - int lda_w; - int ldb_ci; - int ldb_r; - int ldb_s; - int ldb_co; - int ldc_z; - int ldc_co; - int ldc_p; - int ldc_q; -}; - enum run_mode_t { BENCH, TEST @@ -104,8 +76,8 @@ void triton_conv(drv::context* context, drv::stream* stream, // macros rt::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); + opt.defines.push_back({"TM", {"64", "128"}}); + opt.defines.push_back({"TN", {"64", "128"}}); opt.defines.push_back({"TK", {std::to_string(TK)}}); opt.defines.push_back({"TZ", {"1"}}); opt.defines.push_back({"RR", {std::to_string(R)}}); @@ -114,24 +86,42 @@ void triton_conv(drv::context* context, drv::stream* stream, opt.defines.push_back({"QQ", {std::to_string(Q)}}); opt.defines.push_back({"HH", {std::to_string(H)}}); opt.defines.push_back({"WW", {std::to_string(W)}}); - - opt.num_warps = {2, 4}; - + opt.num_warps = {4}; + // arguments + std::stringstream oss; + rt::add_arg(oss, *da->cu()); + rt::add_arg(oss, *db->cu()); + rt::add_arg(oss, *dc->cu()); + rt::add_arg(oss, (float)1); + rt::add_arg(oss, Z*P*Q); + rt::add_arg(oss, CO); + rt::add_arg(oss, CI*R*S); + rt::add_arg(oss, pad_h); + rt::add_arg(oss, pad_w); + rt::add_arg(oss, stride_h); + rt::add_arg(oss, stride_w); + rt::add_arg(oss, *ddelta->cu()); + rt::add_arg(oss, W*H*CI); + rt::add_arg(oss, W*H); + rt::add_arg(oss, W); + rt::add_arg(oss, 1); + rt::add_arg(oss, CO*S*R); + rt::add_arg(oss, CO*S); + rt::add_arg(oss, CO); + rt::add_arg(oss, 1); + rt::add_arg(oss, Q*P*CO); + rt::add_arg(oss, Q*P); + rt::add_arg(oss, Q); + rt::add_arg(oss, 1); // kernels rt::function function(src::conv, opt); - conv_arg_t args{*da->cu(), *db->cu(), *dc->cu(), 1, Z*P*Q, CO, CI*R*S, - pad_h, pad_w, stride_h, stride_w, - *ddelta->cu(), - W*H*CI, W*H, W, 1, - CO*S*R , CO*S, CO, 1, - Q*P*CO, Q*P, Q, 1}; auto grid = [Z,P,Q,CO](const rt::options_t& x) { return rt::grid_t{ceil(Z*P*Q, x.D("TM")), ceil(CO , x.D("TN")), (size_t)x.D("TZ")}; }; auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); bench.push_back(tflops(triton_ns)); } diff --git a/tests/common/copy.h b/tests/common/copy.h index 60dcd8233..4925c4c84 100644 --- a/tests/common/copy.h +++ b/tests/common/copy.h @@ -12,15 +12,6 @@ int32_t off(const std::vector& idx, const std::vector& strides return res; } -struct copy_arg_t{ - CUdeviceptr X; - CUdeviceptr Y; - int S0; - int S1; - int S2; -}; - - enum run_mode_t { BENCH, TEST @@ -124,16 +115,21 @@ void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vecto // kernel rt::function function(src::copy_nd[rank - 1], opt); - copy_arg_t args = {*dx->cu(), *dy->cu(), shape[0]}; - if(shape.size() > 1) args.S1 = shape[1]; - if(shape.size() > 2) args.S2 = shape[2]; + + std::stringstream oss; + rt::add_arg(oss, *dx->cu()); + rt::add_arg(oss, *dy->cu()); + rt::add_arg(oss, (uint32_t)shape[0]); + if(shape.size() > 1) rt::add_arg(oss, (uint32_t)shape[1]); + if(shape.size() > 2) rt::add_arg(oss, (uint32_t)shape[2]); + std::vector ts = {"TS0", "TS1", "TS2"}; auto grid = grid_nd(shape, ts); // metrics if(mode == BENCH){ auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; }; - double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); bench.push_back(gbps(triton_ns)); } @@ -145,7 +141,7 @@ void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vecto for(size_t i = 0; i < hx.size(); i++) hx[i] = static_cast((float)rand()/RAND_MAX); stream->write(&*dx, true, 0, hx); - function((void**)&args, sizeof(args), grid, stream, device); + function((void**)oss.str().data(), oss.str().size(), grid, stream, device); stream->synchronize(); stream->read(&*dy, true, 0, hy); cc_copy_nd(hx, ry, shape, x_order, y_order); diff --git a/tests/common/dot.h b/tests/common/dot.h index 2fe46d9cc..5b8d4c15d 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -12,21 +12,6 @@ #include "cuda/cublas.h" #include "util.h" - -struct dot_arg_t{ - uintptr_t a; - uintptr_t b; - uintptr_t c; - float alpha; - int M; - int N; - int K; - int lda; - int ldb; - int ldc; - uintptr_t locks; -}; - template static void cc_dot(std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K){ @@ -126,11 +111,22 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, opts.num_warps = {4}; } - // kernels + // arguments + std::stringstream oss; + rt::add_arg(oss, *da->cu()); + rt::add_arg(oss, *db->cu()); + rt::add_arg(oss, *dc->cu()); + rt::add_arg(oss, (float)1); + rt::add_arg(oss, M); + rt::add_arg(oss, N); + rt::add_arg(oss, K); + rt::add_arg(oss, lda); + rt::add_arg(oss, ldb); + rt::add_arg(oss, ldc); + rt::add_arg(oss, *dlocks->cu()); + // kernel rt::function function(src::dot, opts); - dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(), - 1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()}; - + // grid auto grid = [M, N](const rt::options_t& x) { return rt::grid_t{ceil(M, x.D("TM"))* ceil(N, x.D("TN")), @@ -140,7 +136,7 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, // metrics if(mode == BENCH){ auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream); bench.push_back(tflops(triton_ns)); // cublas @@ -177,7 +173,7 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT, stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); // run kernel - function((void**)&args, sizeof(args), grid, stream, device); + function((void**)oss.str().data(), oss.str().size(), grid, stream, device); // write back stream->synchronize(); // compare with CPU diff --git a/tests/common/reduce.h b/tests/common/reduce.h index 34369e6e7..3c0b79fc9 100644 --- a/tests/common/reduce.h +++ b/tests/common/reduce.h @@ -13,15 +13,6 @@ namespace drv = triton::driver; namespace rt = triton::runtime; -struct reduce_arg_t{ - CUdeviceptr X; - CUdeviceptr Y; - int S0; - int S1; - int S2; -}; - - template void cc_reduce_nd(std::vector &y, const std::vector &x, reduce_op_t op, size_t axis, const std::vector& shapes) { assert(axis <= shapes.size() - 1);