From 150ba0c70bdb96326726fdd8be962d307523ba03 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 12 Aug 2020 19:37:08 -0400 Subject: [PATCH] [TESTS] Updated the test to be compatible with the new runtime API --- tests/common/conv.h | 40 +++++++++++++++++++++++++------ tests/common/copy.h | 19 +++++++++++---- tests/common/dot.h | 53 +++++++++++++---------------------------- tests/common/reduce.h | 19 +++++++++++---- tests/common/src/conv.h | 6 +++-- tests/common/src/copy.h | 12 +++++----- 6 files changed, 88 insertions(+), 61 deletions(-) diff --git a/tests/common/conv.h b/tests/common/conv.h index b6166bb94..006b5dc0a 100644 --- a/tests/common/conv.h +++ b/tests/common/conv.h @@ -12,6 +12,32 @@ #include "util.h" +struct conv_arg_t{ + CUdeviceptr a; + CUdeviceptr b; + CUdeviceptr c; + float alpha; + int M; + int N; + int K; + int pad_h; + int pad_w; + int stride_h; + int stride_w; + CUdeviceptr adelta; + int lda_z; + int lda_ci; + int lda_h; + int lda_w; + int ldb_ci; + int ldb_r; + int ldb_s; + int ldb_co; + int ldc_z; + int ldc_co; + int ldc_p; + int ldc_q; +}; enum run_mode_t { BENCH, @@ -93,19 +119,19 @@ void triton_conv(drv::stream* stream, // kernels rt::function function(src::conv, opt); - std::vector args = {&*da, &*db, &*dc, (float)1, Z*P*Q, CO, CI*R*S, - pad_h, pad_w, stride_h, stride_w, - &*ddelta, - W*H*CI, W*H, W, 1, - CO*S*R , CO*S, CO, 1, - Q*P*CO, Q*P, Q, 1}; + conv_arg_t args{*da->cu(), *db->cu(), *dc->cu(), 1, Z*P*Q, CO, CI*R*S, + pad_h, pad_w, stride_h, stride_w, + *ddelta->cu(), + W*H*CI, W*H, W, 1, + CO*S*R , CO*S, CO, 1, + Q*P*CO, Q*P, Q, 1}; auto grid = [Z,P,Q,CO](const rt::function::options_t& x) { return rt::grid_t{ceil(Z*P*Q, x.D("TM")), ceil(CO , x.D("TN")), (size_t)x.D("TZ")}; }; auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream); bench.push_back(tflops(triton_ns)); } diff --git a/tests/common/copy.h b/tests/common/copy.h index 0398be7c9..aac462789 100644 --- a/tests/common/copy.h +++ b/tests/common/copy.h @@ -12,6 +12,15 @@ int32_t off(const std::vector& idx, const std::vector& strides return res; } +struct copy_arg_t{ + CUdeviceptr X; + CUdeviceptr Y; + int S0; + int S1; + int S2; +}; + + enum run_mode_t { BENCH, TEST @@ -115,16 +124,16 @@ void triton_copy_nd(drv::stream* stream, const std::vector& shape, // kernel rt::function function(src::copy_nd[rank - 1], opt); - std::vector args = {&*dx, &*dy}; - for(int32_t d: shape) - args.push_back(d); + copy_arg_t args = {*dx->cu(), *dy->cu(), shape[0]}; + if(shape.size() > 1) args.S1 = shape[1]; + if(shape.size() > 2) args.S2 = shape[2]; std::vector ts = {"TS0", "TS1", "TS2"}; auto grid = grid_nd(shape, ts); // metrics if(mode == BENCH){ auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; }; - double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream); bench.push_back(gbps(triton_ns)); } @@ -136,7 +145,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector& shape, for(size_t i = 0; i < hx.size(); i++) hx[i] = static_cast((float)rand()/RAND_MAX); stream->write(&*dx, true, 0, hx); - function(args, grid, stream); + function((void**)&args, sizeof(args), grid, stream); stream->synchronize(); stream->read(&*dy, true, 0, hy); cc_copy_nd(hx, ry, shape, x_order, y_order); diff --git a/tests/common/dot.h b/tests/common/dot.h index 74c7d78ee..9c3f21091 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -13,23 +13,19 @@ #include "util.h" -//struct dot_arg_t{ -// CUdeviceptr a; -// CUdeviceptr b; -// CUdeviceptr c; -// float alpha; -// int M; -// int N; -// int K; -// int lda; -// int ldb; -// int ldc; -// CUdeviceptr locks; -//}; - -//typedef std::tuple dot_arg_t; +struct dot_arg_t{ + CUdeviceptr a; + CUdeviceptr b; + CUdeviceptr c; + float alpha; + int M; + int N; + int K; + int lda; + int ldb; + int ldc; + CUdeviceptr locks; +}; template static void cc_dot(std::vector &c, const std::vector &a, const std::vector &b, @@ -140,24 +136,9 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, // kernels rt::function function(src::dot, opt); - float alpha = 1; - char args[60]; - memcpy(args + 0, &*da->cu(), 8); - memcpy(args + 8, &*db->cu(), 8); - memcpy(args + 16, &*dc->cu(), 8); - memcpy(args + 24, &alpha, 4); - memcpy(args + 28, &M, 4); - memcpy(args + 32, &N, 4); - memcpy(args + 36, &K, 4); - memcpy(args + 40, &lda, 4); - memcpy(args + 44, &ldb, 4); - memcpy(args + 48, &ldc, 4); - memcpy(args + 52, &*dlocks->cu(), 8); + dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(), + 1, M, N, K, lda, ldb, ldc, *dlocks->cu()}; - -// dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(), -// 1, M, N, K, lda, ldb, ldc, *dlocks->cu()}; -// std::cout << sizeof(dot_arg_t) << std::endl; auto grid = [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN")), @@ -167,7 +148,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, // metrics if(mode == BENCH){ auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function((void**)&args, grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream); bench.push_back(tflops(triton_ns)); // cublas @@ -198,7 +179,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); // run kernel - function((void**)&args, grid, stream); + function((void**)&args, sizeof(args), grid, stream); // write back stream->synchronize(); // compare with CPU diff --git a/tests/common/reduce.h b/tests/common/reduce.h index 0f5d63612..60923274e 100644 --- a/tests/common/reduce.h +++ b/tests/common/reduce.h @@ -13,6 +13,15 @@ namespace drv = triton::driver; namespace rt = triton::runtime; +struct reduce_arg_t{ + CUdeviceptr X; + CUdeviceptr Y; + int S0; + int S1; + int S2; +}; + + template void cc_reduce_nd(std::vector &y, const std::vector &x, reduce_op_t op, size_t axis, const std::vector& shapes) { assert(axis <= shapes.size() - 1); @@ -123,16 +132,16 @@ void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x, auto dy = std::unique_ptr(drv::buffer::create(context, size_y*dtsize)); // grid - std::vector args = {&*dx, &*dy}; - for(int32_t d: shape_x) - args.push_back(d); + reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]}; + if(shape_x.size() > 1) args.S1 = shape_x[1]; + if(shape_x.size() > 2) args.S2 = shape_x[2]; std::vector ts = {"TS0", "TS1", "TS2"}; auto grid = grid_nd(shape_x, ts); // metrics if(mode == BENCH){ auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; }; - double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream); bench.push_back(gbps(triton_ns)); } @@ -144,7 +153,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x, init_zeros(hy); init_rand(hx); stream->write(&*dx, true, 0, hx); - function(args, grid, stream); + function((void**)&args, sizeof(args), grid, stream); stream->synchronize(); stream->read(&*dy, true, 0, hy); cc_reduce_nd(ry, hx, op, axis, shape_x); diff --git a/tests/common/src/conv.h b/tests/common/src/conv.h index c0786cb6f..ace395575 100644 --- a/tests/common/src/conv.h +++ b/tests/common/src/conv.h @@ -7,7 +7,9 @@ R"( TYPE *C __noalias __aligned(16), float alpha, // equivalent matmul - int M, int N, int K, + int M __retune, + int N __retune, + int K __retune, // convolution properties int pad_h, int pad_w, int stride_h, int stride_w, // pointer increment @@ -16,7 +18,7 @@ R"( int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8), int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8), int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) { - // prologue + // prologue int ridx = get_program_id(0); int ridy = get_program_id(1); int ridz = get_program_id(2); diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h index cd35cbf4e..2b9eb6cdc 100644 --- a/tests/common/src/copy.h +++ b/tests/common/src/copy.h @@ -7,7 +7,7 @@ namespace src { R"( void copy1d(TYPE * X __noalias __readonly __aligned(16), TYPE * Y __noalias __readonly __aligned(16), - int S0) { + int S0 __retune) { int pid0 = get_program_id(0); int rs0[TS0] = pid0 * TS0 + 0 ... TS0; TYPE* px[TS0] = X + rs0; @@ -20,8 +20,8 @@ void copy1d(TYPE * X __noalias __readonly __aligned(16), R"( void copy2d(TYPE * X __noalias __readonly __aligned(16), TYPE * Y __noalias __writeonly __aligned(16), - int S0 __multipleof(8), - int S1 __multipleof(8)) { + int S0 __multipleof(8) __retune, + int S1 __multipleof(8) __retune) { int pid0 = get_program_id(0); int pid1 = get_program_id(1); int rs0[TS0] = pid0 * TS0 + 0 ... TS0; @@ -37,9 +37,9 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16), R"( void copy3d(TYPE * X __noalias __readonly __aligned(16), TYPE * Y __noalias __writeonly __aligned(16), - int S0 __multipleof(8), - int S1 __multipleof(8), - int S2 __multipleof(8)) { + int S0 __multipleof(8) __retune, + int S1 __multipleof(8) __retune, + int S2 __multipleof(8) __retune) { // program id int pid0 = get_program_id(0); int pid1 = get_program_id(1);