better benchmarking
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include "cuda.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
@@ -8,12 +9,20 @@
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/external/half.hpp"
|
||||
|
||||
double do_bench(triton::driver::context* context,
|
||||
struct perf_t {
|
||||
double triton;
|
||||
double cublas;
|
||||
};
|
||||
|
||||
perf_t do_bench(triton::driver::stream *stream,
|
||||
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
|
||||
triton::dnn::op_t op, triton::dnn::layout_t layout,
|
||||
std::string numeric_t) {
|
||||
typedef float NumericT;
|
||||
|
||||
// driver variables
|
||||
triton::driver::context* context = stream->context();
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
std::vector<int32_t> shift_w(C);
|
||||
@@ -44,7 +53,6 @@ double do_bench(triton::driver::context* context,
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*sizeof(NumericT));
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// initialize host
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
@@ -58,8 +66,29 @@ double do_bench(triton::driver::context* context,
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream);
|
||||
return shift.num_flops() / nanosec * 1e-3;
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
cublasGemmAlgo_t fastest;
|
||||
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(), &beta,
|
||||
dc, shift.ldc(), &fastest);
|
||||
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(),
|
||||
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
|
||||
perf_t result;
|
||||
result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
delete da;
|
||||
delete db;
|
||||
delete dc;
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -86,13 +115,15 @@ int main() {
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, R, S, B, F, H, W, C, op, layout, ty);
|
||||
perf_t perf(triton::driver::stream *stream){
|
||||
return do_bench(stream, R, S, B, F, H, W, C, op, layout, ty);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs;
|
||||
std::vector<config_t> resnet18 = {
|
||||
std::vector<config_t> resnet18 =
|
||||
{
|
||||
{128, 128, 32, 32, 3, 3, 128, 1, 1},
|
||||
{128, 128, 32, 32, 3, 3, 128, 1, 1},
|
||||
{128, 128, 32, 32, 3, 3, 256, 2, 2},
|
||||
{128, 256, 16, 16, 3, 3, 256, 1, 1},
|
||||
@@ -108,7 +139,11 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
for(config_t c: configs)
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
triton::driver::stream *stream = triton::driver::stream::create(context);
|
||||
|
||||
for(config_t c: configs){
|
||||
std::string repr = c.repr();
|
||||
perf_t perf = c.perf(stream);
|
||||
std::cout << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user