changed auto-tuner parameter ranges

This commit is contained in:
Philippe Tillet
2019-06-25 19:27:49 -07:00
parent d945ce5e1b
commit 25e9a10917
2 changed files with 7 additions and 7 deletions

View File

@@ -69,14 +69,14 @@ class BlockSparseGemmOp : public OpKernel {
[&](){ stream->synchronize(); }, ctx->device());
return 2.*M*N*K / ts * 1e-3;
};
std::string src = triton::dnn::gemm::src(false, false, "fp16", "fp16", 1, 1);
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
// just-in-time compile source-code
// jit.autotune("matmul", src.c_str(), benchmark);
jit.autotune("matmul", src.c_str(), benchmark);
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;;