Now find correct tuning configuration

This commit is contained in:
Philippe Tillet
2019-06-06 20:13:26 -07:00
parent 0a0b48e9a2
commit 6045209d5b
3 changed files with 40 additions and 27 deletions

View File

@@ -4,6 +4,7 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -125,30 +126,36 @@ class BlockSparseGemmOp : public OpKernel {
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// set argument
kernel->setArg(0, *da.cu());
kernel->setArg(1, *db.cu());
kernel->setArg(2, *dc.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, *dlocks.cu());
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, nullptr);
return 2.*M*N*K / ts * 1e-3;
};
// just-in-time compile source-code
jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// set argument
kernel->setArg(0, *da.cu());
kernel->setArg(1, *db.cu());
kernel->setArg(2, *dc.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, *dlocks.cu());
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
jit.autotune("matmul", src, benchmark);
}
private: