[examples] removed dependency on isaac for auto-tuning
This commit is contained in:
@@ -81,13 +81,37 @@ private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
T min(std::vector<T> x)
|
||||
{ return *std::min_element(x.begin(), x.end()); }
|
||||
|
||||
|
||||
template<class OP, class SYNC>
|
||||
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
||||
{
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
op();
|
||||
sync();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = (float)device.current_sm_clock()/device.max_sm_clock();
|
||||
tmr.start();
|
||||
op();
|
||||
sync();
|
||||
times.push_back(norm*tmr.get().count());
|
||||
total_time+=times.back();
|
||||
}
|
||||
return min(times);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
|
||||
// matrix multiplication parameters
|
||||
size_t M = 512, N = 512, K = 512;
|
||||
size_t bound = 8;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
@@ -112,6 +136,22 @@ int main() {
|
||||
// benchmark a given matrix multiplication kernel
|
||||
auto benchmark = [&](triton::driver::kernel kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
|
||||
// fast bounds-checking
|
||||
unsigned TK = jit.get_int("TK");
|
||||
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
|
||||
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
|
||||
unsigned lastk = TK - 1;
|
||||
bool AT = false;
|
||||
bool BT = true;
|
||||
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
|
||||
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
|
||||
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
|
||||
// set argument
|
||||
kernel.setArg(0, da);
|
||||
kernel.setArg(1, db);
|
||||
kernel.setArg(2, dc);
|
||||
@@ -119,39 +159,33 @@ int main() {
|
||||
kernel.setArg(4, N);
|
||||
kernel.setArg(5, K);
|
||||
kernel.setArg(6, bound);
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
timer t;
|
||||
t.start();
|
||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
||||
// dry run
|
||||
stream.enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream.synchronize();
|
||||
double ts = t.get().count()*1e-9;
|
||||
// benchmark
|
||||
double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream.synchronize(); },
|
||||
context.device());
|
||||
ts = ts * 1e-9;
|
||||
double tflops = 2*M*N*K / ts * 1e-12;
|
||||
std::cout << tflops << std::endl;
|
||||
return ts;
|
||||
return tflops;
|
||||
};
|
||||
|
||||
|
||||
// just-in-time compile source-code
|
||||
std::vector<unsigned> params = {
|
||||
// a0
|
||||
8, 2, 16,
|
||||
// b0
|
||||
4, 4, 16,
|
||||
// c
|
||||
8, 4, 2, 4,
|
||||
// a1
|
||||
4, 2, 8,
|
||||
// b1
|
||||
8, 1
|
||||
16, 2, 64,
|
||||
32, 2, 64,
|
||||
16, 8, 2, 2,
|
||||
8, 1, 8,
|
||||
4, 1
|
||||
};
|
||||
triton::jit jit(context);
|
||||
jit.autotune(src, benchmark);
|
||||
|
||||
// jit.autotune(src, benchmark);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
benchmark(kernel, info);
|
||||
std::cout << benchmark(kernel, info) << std::endl;
|
||||
stream.read(dc, true, 0, hc);
|
||||
simple_gemm(rc, ha, hb, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
|
Reference in New Issue
Block a user