[TESTS] Updated the test to be compatible with the new runtime API
This commit is contained in:
committed by
Philippe Tillet
parent
acff1b5e05
commit
150ba0c70b
@@ -12,6 +12,32 @@
|
|||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
|
|
||||||
|
struct conv_arg_t{
|
||||||
|
CUdeviceptr a;
|
||||||
|
CUdeviceptr b;
|
||||||
|
CUdeviceptr c;
|
||||||
|
float alpha;
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int pad_h;
|
||||||
|
int pad_w;
|
||||||
|
int stride_h;
|
||||||
|
int stride_w;
|
||||||
|
CUdeviceptr adelta;
|
||||||
|
int lda_z;
|
||||||
|
int lda_ci;
|
||||||
|
int lda_h;
|
||||||
|
int lda_w;
|
||||||
|
int ldb_ci;
|
||||||
|
int ldb_r;
|
||||||
|
int ldb_s;
|
||||||
|
int ldb_co;
|
||||||
|
int ldc_z;
|
||||||
|
int ldc_co;
|
||||||
|
int ldc_p;
|
||||||
|
int ldc_q;
|
||||||
|
};
|
||||||
|
|
||||||
enum run_mode_t {
|
enum run_mode_t {
|
||||||
BENCH,
|
BENCH,
|
||||||
@@ -93,19 +119,19 @@ void triton_conv(drv::stream* stream,
|
|||||||
|
|
||||||
// kernels
|
// kernels
|
||||||
rt::function function(src::conv, opt);
|
rt::function function(src::conv, opt);
|
||||||
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, Z*P*Q, CO, CI*R*S,
|
conv_arg_t args{*da->cu(), *db->cu(), *dc->cu(), 1, Z*P*Q, CO, CI*R*S,
|
||||||
pad_h, pad_w, stride_h, stride_w,
|
pad_h, pad_w, stride_h, stride_w,
|
||||||
&*ddelta,
|
*ddelta->cu(),
|
||||||
W*H*CI, W*H, W, 1,
|
W*H*CI, W*H, W, 1,
|
||||||
CO*S*R , CO*S, CO, 1,
|
CO*S*R , CO*S, CO, 1,
|
||||||
Q*P*CO, Q*P, Q, 1};
|
Q*P*CO, Q*P, Q, 1};
|
||||||
auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
|
auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
|
||||||
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
|
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
|
||||||
ceil(CO , x.D<int>("TN")),
|
ceil(CO , x.D<int>("TN")),
|
||||||
(size_t)x.D<int>("TZ")};
|
(size_t)x.D<int>("TZ")};
|
||||||
};
|
};
|
||||||
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
|
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
|
||||||
bench.push_back(tflops(triton_ns));
|
bench.push_back(tflops(triton_ns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -12,6 +12,15 @@ int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct copy_arg_t{
|
||||||
|
CUdeviceptr X;
|
||||||
|
CUdeviceptr Y;
|
||||||
|
int S0;
|
||||||
|
int S1;
|
||||||
|
int S2;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
enum run_mode_t {
|
enum run_mode_t {
|
||||||
BENCH,
|
BENCH,
|
||||||
TEST
|
TEST
|
||||||
@@ -115,16 +124,16 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
|
|||||||
|
|
||||||
// kernel
|
// kernel
|
||||||
rt::function function(src::copy_nd[rank - 1], opt);
|
rt::function function(src::copy_nd[rank - 1], opt);
|
||||||
std::vector<rt::arg> args = {&*dx, &*dy};
|
copy_arg_t args = {*dx->cu(), *dy->cu(), shape[0]};
|
||||||
for(int32_t d: shape)
|
if(shape.size() > 1) args.S1 = shape[1];
|
||||||
args.push_back(d);
|
if(shape.size() > 2) args.S2 = shape[2];
|
||||||
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
|
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
|
||||||
auto grid = grid_nd(shape, ts);
|
auto grid = grid_nd(shape, ts);
|
||||||
|
|
||||||
// metrics
|
// metrics
|
||||||
if(mode == BENCH){
|
if(mode == BENCH){
|
||||||
auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
|
auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
|
||||||
bench.push_back(gbps(triton_ns));
|
bench.push_back(gbps(triton_ns));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +145,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
|
|||||||
for(size_t i = 0; i < hx.size(); i++)
|
for(size_t i = 0; i < hx.size(); i++)
|
||||||
hx[i] = static_cast<T>((float)rand()/RAND_MAX);
|
hx[i] = static_cast<T>((float)rand()/RAND_MAX);
|
||||||
stream->write(&*dx, true, 0, hx);
|
stream->write(&*dx, true, 0, hx);
|
||||||
function(args, grid, stream);
|
function((void**)&args, sizeof(args), grid, stream);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
stream->read(&*dy, true, 0, hy);
|
stream->read(&*dy, true, 0, hy);
|
||||||
cc_copy_nd(hx, ry, shape, x_order, y_order);
|
cc_copy_nd(hx, ry, shape, x_order, y_order);
|
||||||
|
@@ -13,23 +13,19 @@
|
|||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
|
|
||||||
//struct dot_arg_t{
|
struct dot_arg_t{
|
||||||
// CUdeviceptr a;
|
CUdeviceptr a;
|
||||||
// CUdeviceptr b;
|
CUdeviceptr b;
|
||||||
// CUdeviceptr c;
|
CUdeviceptr c;
|
||||||
// float alpha;
|
float alpha;
|
||||||
// int M;
|
int M;
|
||||||
// int N;
|
int N;
|
||||||
// int K;
|
int K;
|
||||||
// int lda;
|
int lda;
|
||||||
// int ldb;
|
int ldb;
|
||||||
// int ldc;
|
int ldc;
|
||||||
// CUdeviceptr locks;
|
CUdeviceptr locks;
|
||||||
//};
|
};
|
||||||
|
|
||||||
//typedef std::tuple<CUdeviceptr, CUdeviceptr, CUdeviceptr,
|
|
||||||
// float, int, int, int, int, int, int,
|
|
||||||
// CUdeviceptr> dot_arg_t;
|
|
||||||
|
|
||||||
template<class T, bool AT, bool BT>
|
template<class T, bool AT, bool BT>
|
||||||
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
||||||
@@ -140,24 +136,9 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
|
|
||||||
// kernels
|
// kernels
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
float alpha = 1;
|
dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(),
|
||||||
char args[60];
|
1, M, N, K, lda, ldb, ldc, *dlocks->cu()};
|
||||||
memcpy(args + 0, &*da->cu(), 8);
|
|
||||||
memcpy(args + 8, &*db->cu(), 8);
|
|
||||||
memcpy(args + 16, &*dc->cu(), 8);
|
|
||||||
memcpy(args + 24, &alpha, 4);
|
|
||||||
memcpy(args + 28, &M, 4);
|
|
||||||
memcpy(args + 32, &N, 4);
|
|
||||||
memcpy(args + 36, &K, 4);
|
|
||||||
memcpy(args + 40, &lda, 4);
|
|
||||||
memcpy(args + 44, &ldb, 4);
|
|
||||||
memcpy(args + 48, &ldc, 4);
|
|
||||||
memcpy(args + 52, &*dlocks->cu(), 8);
|
|
||||||
|
|
||||||
|
|
||||||
// dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(),
|
|
||||||
// 1, M, N, K, lda, ldb, ldc, *dlocks->cu()};
|
|
||||||
// std::cout << sizeof(dot_arg_t) << std::endl;
|
|
||||||
auto grid = [M, N](const rt::function::options_t& x) {
|
auto grid = [M, N](const rt::function::options_t& x) {
|
||||||
return rt::grid_t{ceil(M, x.D<int>("TM")),
|
return rt::grid_t{ceil(M, x.D<int>("TM")),
|
||||||
ceil(N, x.D<int>("TN")),
|
ceil(N, x.D<int>("TN")),
|
||||||
@@ -167,7 +148,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
// metrics
|
// metrics
|
||||||
if(mode == BENCH){
|
if(mode == BENCH){
|
||||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function((void**)&args, grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
|
||||||
bench.push_back(tflops(triton_ns));
|
bench.push_back(tflops(triton_ns));
|
||||||
|
|
||||||
// cublas
|
// cublas
|
||||||
@@ -198,7 +179,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
stream->write(&*da, true, 0, ha);
|
stream->write(&*da, true, 0, ha);
|
||||||
stream->write(&*db, true, 0, hb);
|
stream->write(&*db, true, 0, hb);
|
||||||
// run kernel
|
// run kernel
|
||||||
function((void**)&args, grid, stream);
|
function((void**)&args, sizeof(args), grid, stream);
|
||||||
// write back
|
// write back
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
// compare with CPU
|
// compare with CPU
|
||||||
|
@@ -13,6 +13,15 @@
|
|||||||
namespace drv = triton::driver;
|
namespace drv = triton::driver;
|
||||||
namespace rt = triton::runtime;
|
namespace rt = triton::runtime;
|
||||||
|
|
||||||
|
struct reduce_arg_t{
|
||||||
|
CUdeviceptr X;
|
||||||
|
CUdeviceptr Y;
|
||||||
|
int S0;
|
||||||
|
int S1;
|
||||||
|
int S2;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
void cc_reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
|
void cc_reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
|
||||||
assert(axis <= shapes.size() - 1);
|
assert(axis <= shapes.size() - 1);
|
||||||
@@ -123,16 +132,16 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
|
|||||||
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
|
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
|
||||||
|
|
||||||
// grid
|
// grid
|
||||||
std::vector<rt::arg> args = {&*dx, &*dy};
|
reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
|
||||||
for(int32_t d: shape_x)
|
if(shape_x.size() > 1) args.S1 = shape_x[1];
|
||||||
args.push_back(d);
|
if(shape_x.size() > 2) args.S2 = shape_x[2];
|
||||||
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
|
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
|
||||||
auto grid = grid_nd(shape_x, ts);
|
auto grid = grid_nd(shape_x, ts);
|
||||||
|
|
||||||
// metrics
|
// metrics
|
||||||
if(mode == BENCH){
|
if(mode == BENCH){
|
||||||
auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
|
auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
|
||||||
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
|
double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
|
||||||
bench.push_back(gbps(triton_ns));
|
bench.push_back(gbps(triton_ns));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +153,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector<int32_t>& shape_x,
|
|||||||
init_zeros(hy);
|
init_zeros(hy);
|
||||||
init_rand(hx);
|
init_rand(hx);
|
||||||
stream->write(&*dx, true, 0, hx);
|
stream->write(&*dx, true, 0, hx);
|
||||||
function(args, grid, stream);
|
function((void**)&args, sizeof(args), grid, stream);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
stream->read(&*dy, true, 0, hy);
|
stream->read(&*dy, true, 0, hy);
|
||||||
cc_reduce_nd(ry, hx, op, axis, shape_x);
|
cc_reduce_nd(ry, hx, op, axis, shape_x);
|
||||||
|
@@ -7,7 +7,9 @@ R"(
|
|||||||
TYPE *C __noalias __aligned(16),
|
TYPE *C __noalias __aligned(16),
|
||||||
float alpha,
|
float alpha,
|
||||||
// equivalent matmul
|
// equivalent matmul
|
||||||
int M, int N, int K,
|
int M __retune,
|
||||||
|
int N __retune,
|
||||||
|
int K __retune,
|
||||||
// convolution properties
|
// convolution properties
|
||||||
int pad_h, int pad_w, int stride_h, int stride_w,
|
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||||
// pointer increment
|
// pointer increment
|
||||||
@@ -16,7 +18,7 @@ R"(
|
|||||||
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
|
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
|
||||||
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
|
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
|
||||||
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
|
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
|
||||||
// prologue
|
// prologue
|
||||||
int ridx = get_program_id(0);
|
int ridx = get_program_id(0);
|
||||||
int ridy = get_program_id(1);
|
int ridy = get_program_id(1);
|
||||||
int ridz = get_program_id(2);
|
int ridz = get_program_id(2);
|
||||||
|
@@ -7,7 +7,7 @@ namespace src {
|
|||||||
R"(
|
R"(
|
||||||
void copy1d(TYPE * X __noalias __readonly __aligned(16),
|
void copy1d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
TYPE * Y __noalias __readonly __aligned(16),
|
TYPE * Y __noalias __readonly __aligned(16),
|
||||||
int S0) {
|
int S0 __retune) {
|
||||||
int pid0 = get_program_id(0);
|
int pid0 = get_program_id(0);
|
||||||
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
||||||
TYPE* px[TS0] = X + rs0;
|
TYPE* px[TS0] = X + rs0;
|
||||||
@@ -20,8 +20,8 @@ void copy1d(TYPE * X __noalias __readonly __aligned(16),
|
|||||||
R"(
|
R"(
|
||||||
void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
TYPE * Y __noalias __writeonly __aligned(16),
|
TYPE * Y __noalias __writeonly __aligned(16),
|
||||||
int S0 __multipleof(8),
|
int S0 __multipleof(8) __retune,
|
||||||
int S1 __multipleof(8)) {
|
int S1 __multipleof(8) __retune) {
|
||||||
int pid0 = get_program_id(0);
|
int pid0 = get_program_id(0);
|
||||||
int pid1 = get_program_id(1);
|
int pid1 = get_program_id(1);
|
||||||
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
||||||
@@ -37,9 +37,9 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
|||||||
R"(
|
R"(
|
||||||
void copy3d(TYPE * X __noalias __readonly __aligned(16),
|
void copy3d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
TYPE * Y __noalias __writeonly __aligned(16),
|
TYPE * Y __noalias __writeonly __aligned(16),
|
||||||
int S0 __multipleof(8),
|
int S0 __multipleof(8) __retune,
|
||||||
int S1 __multipleof(8),
|
int S1 __multipleof(8) __retune,
|
||||||
int S2 __multipleof(8)) {
|
int S2 __multipleof(8) __retune) {
|
||||||
// program id
|
// program id
|
||||||
int pid0 = get_program_id(0);
|
int pid0 = get_program_id(0);
|
||||||
int pid1 = get_program_id(1);
|
int pid1 = get_program_id(1);
|
||||||
|
Reference in New Issue
Block a user