[examples] some cleaning

This commit is contained in:
Philippe Tillet
2019-03-09 17:17:55 -05:00
parent 9a3537662d
commit 9e2cfddf4c

View File

@@ -64,6 +64,11 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
}
int main() {
// initialize JIT on default device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// add module
std::vector<unsigned> params = {
// a0
2, 8, 1, 16,
@@ -76,13 +81,7 @@ int main() {
// b1
1, 8, 1
};
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
size_t M = 128, N = 128, K = 128;
size_t bound = 8;
@@ -104,6 +103,7 @@ int main() {
stream.write(da, true, 0, ha);
stream.write(db, true, 0, hb);
stream.write(dc, true, 0, hc);
triton::driver::kernel kernel = jit.get_function("matmul");
kernel.setArg(0, da);
kernel.setArg(1, db);
kernel.setArg(2, dc);
@@ -111,6 +111,7 @@ int main() {
kernel.setArg(4, N);
kernel.setArg(5, K);
kernel.setArg(6, bound);
triton::jit::launch_information info = jit.get_launch_info("matmul");
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;