diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 0548075e7..19cf9d036 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -64,6 +64,11 @@ void simple_gemm(std::vector &c, const std::vector &a, const std::vector params = { // a0 2, 8, 1, 16, @@ -76,13 +81,7 @@ int main() { // b1 1, 8, 1 }; - - - auto context = triton::driver::backend::contexts::get_default(); - triton::jit jit(context); jit.add_module(src, params); - triton::driver::kernel kernel = jit.get_function("matmul"); - triton::jit::launch_information info = jit.get_launch_info("matmul"); size_t M = 128, N = 128, K = 128; size_t bound = 8; @@ -104,6 +103,7 @@ int main() { stream.write(da, true, 0, ha); stream.write(db, true, 0, hb); stream.write(dc, true, 0, hc); + triton::driver::kernel kernel = jit.get_function("matmul"); kernel.setArg(0, da); kernel.setArg(1, db); kernel.setArg(2, dc); @@ -111,6 +111,7 @@ int main() { kernel.setArg(4, N); kernel.setArg(5, K); kernel.setArg(6, bound); + triton::jit::launch_information info = jit.get_launch_info("matmul"); unsigned TM = info.global_range_size[0]; unsigned TN = info.global_range_size[1]; unsigned nthreads = info.num_threads;