[runtime] overhaul of the run-time API
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
#include <cstring>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/dot.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/external/half.hpp"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "cuda.h"
|
||||
|
||||
template<class T>
|
||||
@@ -19,20 +19,125 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
||||
// Host-side reference matrix product with compile-time operand layouts.
//
// For every (row, col) pair with row < M and col < N, accumulates K products
// in a float and stores the result, converted to T, at c[row + col*M].
//   - A element: a[red + row*K] when AT is true, a[row + red*M] otherwise.
//   - B element: b[col + red*N] when BT is true, b[red + col*K] otherwise.
// The caller must size c, a and b so that all of these indices are valid.
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                    size_t M, size_t N, size_t K){
  for(size_t row = 0; row < M; row++){
    for(size_t col = 0; col < N; col++){
      float sum = 0;
      for(size_t red = 0; red < K; red++){
        // AT / BT are template constants, so each select is resolved per instantiation.
        const size_t a_idx = AT ? (red + row*K) : (row + red*M);
        const size_t b_idx = BT ? (col + red*N) : (red + col*K);
        sum = sum + a[a_idx] * b[b_idx];
      }
      c[row + col*M] = static_cast<T>(sum);
    }
  }
}
|
||||
|
||||
// Runtime front-end for the reference matrix product: forwards to the
// cpu_ref<T, AT, BT> instantiation that matches the two layout flags.
//
// AT_, BT_ : runtime layout flags for a and b, mapped to the template constants.
// M, N, K  : problem sizes, passed through unchanged.
// c        : output buffer written by the selected instantiation.
// a, b     : input buffers.
template<class T>
void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
             std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
  // Pack both flags into one selector: bit 1 = AT_, bit 0 = BT_.
  const int layout = (AT_ ? 2 : 0) | (BT_ ? 1 : 0);
  switch(layout){
    case 3:
      cpu_ref<T, true, true>(c, a, b, M, N, K);
      break;
    case 2:
      cpu_ref<T, true, false>(c, a, b, M, N, K);
      break;
    case 1:
      cpu_ref<T, false, true>(c, a, b, M, N, K);
      break;
    default:
      cpu_ref<T, false, false>(c, a, b, M, N, K);
      break;
  }
}
|
||||
|
||||
|
||||
|
||||
// Assembles the text of the `matmul` kernel for one operand-layout combination.
//
// AT, BT               : when true, the corresponding operand is wrapped in
//                        `trans(...)` in the `dot` call, its tile shape is
//                        swapped, its broadcasting suffixes are swapped, and the
//                        leading-dimension factor moves to the other index term.
// a_ty, b_ty, c_ty     : element type names spliced verbatim into the kernel text.
// align_lda, align_ldb : emitted as `multiple_of(<value>)` in front of the
//                        `lda` / `ldb` kernel parameters.
// Returns the full kernel source as a string.
//
// NOTE(review): `C` is declared `read_only` in the emitted signature although the
// kernel stores through `pc` — confirm that qualifier is intended.
std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::string c_ty, int align_lda, int align_ldb) {
  // Tile shapes of the A and B blocks (non-transposed defaults).
  std::string AS0 = "TM", AS1 = "TK";
  std::string BS0 = "TK", BS1 = "TN";
  // Broadcasting suffixes applied to the reduction index (0) and the row/column index (1).
  std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
  std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
  // Leading-dimension factor applied to the reduction index (0) and the row/column index (1).
  std::string lda0 = "*lda", lda1 = "";
  std::string ldb0 = "", ldb1 = "*ldb";
  std::string usea = AT ? "trans(a)" : "a";
  std::string useb = BT ? "trans(b)" : "b";
  if(AT){
    std::swap(AS0, AS1);
    std::swap(bca0, bca1);
    std::swap(lda0, lda1);
  }
  if(BT){
    std::swap(BS0, BS1);
    std::swap(bcb0, bcb1);
    std::swap(ldb0, ldb1);
  }
  std::string AS = AS0 + ", " + AS1;
  std::string BS = BS0 + ", " + BS1;
  std::string XCS = "TM, TN";
  std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
  std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
  // The raw-string pieces below are emitted verbatim; no double quote may appear
  // inside them, since it would end up in the kernel's parameter list.
  std::string res =
R"(
const tunable int TM = {128};
const tunable int TN = {128};
const tunable int TK = {32};

void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
restrict read_only align(16) )" + b_ty + R"( *B,
restrict read_only align(16) )" + c_ty + R"( *C,
int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"( int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
float xc[)" + XCS + R"(] = 0;
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
}
)";
  return res;
}
|
||||
|
||||
// Pair of performance figures returned by do_bench, one slot per implementation.
struct perf_t {
  // Figure for the Triton kernel (do_bench fills it from its `tflops` helper).
  double triton;
  // Figure for the cuBLAS baseline.
  double cublas;
};
|
||||
|
||||
namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef half NumericT;
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
triton::driver::context* context = stream->context();
|
||||
drv::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
std::vector<NumericT> ha(M*K);
|
||||
std::vector<NumericT> hb(K*N);
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
@@ -40,54 +145,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = static_cast<NumericT>((double)0);
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
|
||||
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
|
||||
drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
|
||||
drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
// int32_t lda = AT ? K : M;
|
||||
// int32_t ldb = BT ? N : K;
|
||||
// int32_t ldc = M;
|
||||
// cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
// dc, ldc, &fastest);
|
||||
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
// dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
|
||||
// run
|
||||
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };
|
||||
|
||||
perf_t result;
|
||||
// result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
perf_t res;
|
||||
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
|
||||
res.cublas = 0;
|
||||
|
||||
// test
|
||||
stream->read(dc, true, 0, hc);
|
||||
std::vector<NumericT> rc(hc.size());
|
||||
dot.cpu_ref(rc, ha, hb);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
// stream->synchronize();
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// std::vector<NumericT> rc(hc.size());
|
||||
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
||||
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
// exit(EXIT_FAILURE);
|
||||
// }
|
||||
// std::cout << "Pass!" << std::endl;
|
||||
|
||||
// clean-up
|
||||
delete dc;
|
||||
delete da;
|
||||
delete db;
|
||||
return result;
|
||||
return res;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -111,12 +202,11 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 128, 128, 128}
|
||||
{false, true, 8192, 8192, 8192}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
// {true, true, 128, 128, 128}
|
||||
|
||||
// {false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
// {true, true, 8192, 512, 512}
|
||||
|
Reference in New Issue
Block a user