[examples] removed dependency on isaac for auto-tuning

This commit is contained in:
Philippe Tillet
2019-03-11 22:22:43 -04:00
parent 87c85ed50d
commit b73c3bdd25
6 changed files with 86 additions and 32 deletions

View File

@@ -81,13 +81,37 @@ private:
high_resolution_clock::time_point _start;
};
// Returns a copy of the smallest element of `x`, as ordered by operator<.
// The vector is taken by const reference so that calling this helper does
// not duplicate the whole container.
// Precondition: `x` must not be empty; otherwise the end iterator would be
// dereferenced, which is undefined behavior.
template<class T>
T min(const std::vector<T>& x)
{
  return *std::min_element(x.begin(), x.end());
}
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = (float)device.current_sm_clock()/device.max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters
size_t M = 512, N = 512, K = 512;
size_t bound = 8;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -112,6 +136,22 @@ int main() {
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned TK = jit.get_int("TK");
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set argument
kernel.setArg(0, da);
kernel.setArg(1, db);
kernel.setArg(2, dc);
@@ -119,39 +159,33 @@ int main() {
kernel.setArg(4, N);
kernel.setArg(5, K);
kernel.setArg(6, bound);
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
timer t;
t.start();
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
// dry run
stream.enqueue(kernel, grid, {nthreads, 1, 1});
stream.synchronize();
double ts = t.get().count()*1e-9;
// benchmark
double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream.synchronize(); },
context.device());
ts = ts * 1e-9;
double tflops = 2*M*N*K / ts * 1e-12;
std::cout << tflops << std::endl;
return ts;
return tflops;
};
// just-in-time compile source-code
std::vector<unsigned> params = {
// a0
8, 2, 16,
// b0
4, 4, 16,
// c
8, 4, 2, 4,
// a1
4, 2, 8,
// b1
8, 1
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 1, 8,
4, 1
};
triton::jit jit(context);
jit.autotune(src, benchmark);
// jit.autotune(src, benchmark);
jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
benchmark(kernel, info);
std::cout << benchmark(kernel, info) << std::endl;
stream.read(dc, true, 0, hc);
simple_gemm(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++)