Files
triton/examples/cpp/dot.cpp
2019-08-16 15:56:58 -07:00

224 lines
7.2 KiB
C++

#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/dot.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "cuda.h"
template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
for(size_t i = 0; i < x.size(); i++)
if(std::isnan(x[i]) || std::abs(x[i] - y[i])/std::max(x[i], y[i]) > 1e-4){
std::cout << i << " " << x[i] << " " << y[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
c[m + n*M] = static_cast<T>(acc);
}
}
template<class T>
void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
cpu_ref<T, true, true>(c, a, b, M, N, K);
else if(AT_ && !BT_)
cpu_ref<T, true, false>(c, a, b, M, N, K);
else if(!AT_ && BT_)
cpu_ref<T, false, true>(c, a, b, M, N, K);
else
cpu_ref<T, false, false>(c, a, b, M, N, K);
}
std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::string c_ty, int align_lda, int align_ldb) {
std::string ZS = "1";
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
std::swap(XAS1, XAS2);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT){
std::swap(BS0, BS1);
std::swap(XBS1, XBS2);
std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
std::string XCS = "TM, TN";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string res =
R"(
const tunable int TM = {128};
const tunable int TN = {128};
const tunable int TK = {32};
void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
restrict read_only align(16) )" + b_ty + R"( *B,
restrict read_only align(16) )" + c_ty + R"( *C,
int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
float xc[)" + XCS + R"(] = 0;
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
}
)";
return res;
}
struct perf_t {
double triton;
double cublas;
};
namespace drv = triton::driver;
namespace rt = triton::runtime;
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef half NumericT;
std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
std::vector<NumericT> hc(M*N);
std::vector<NumericT> ha(M*K);
std::vector<NumericT> hb(K*N);
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = M;
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
for(size_t i = 0; i < hb.size(); i++)
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
for(size_t i = 0; i < hc.size(); i++)
hc[i] = static_cast<NumericT>((double)0);
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// run
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
perf_t res;
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
res.cublas = 0;
// test
// stream->synchronize();
// stream->read(dc, true, 0, hc);
// std::vector<NumericT> rc(hc.size());
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
// clean-up
delete dc;
delete da;
delete db;
return res;
}
int main() {
struct config_t{
bool AT;
bool BT;
int32_t M;
int32_t N;
int32_t K;
std::string repr() {
std::ostringstream oss;
oss << AT << " " << BT << " " << M << " " << N << " " << K;
return oss.str();
}
perf_t perf(triton::driver::stream *stream){
return do_bench(stream, AT, BT, M, N, K);
}
};
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},
// {true, true, 128, 128, 128}
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// does the work
for(config_t c: configs){
perf_t perf = c.perf(stream);
std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
}
}