Improvements w/ Auto-Tuning and standard benchmarks (#57)

[PYTHON] Bug fixes in the auto-tuning module and improvements to its existing API
This commit is contained in:
Philippe Tillet
2021-02-03 13:37:21 -08:00
committed by Philippe Tillet
parent ad005d49ac
commit 6fb4800f57
12 changed files with 215 additions and 149 deletions

View File

@@ -158,21 +158,17 @@ float triton_dot(drv::context* context, drv::stream* stream,
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
// macros
rt::options_space_t opts;
// A access patterns
opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }});
opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }});
// B access patterns
opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }});
opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }});
// data-type
opts.defines.push_back({"TYPE", {ty}});
// tile sizes
opts.defines.push_back({"TM", {"128"}});
opts.defines.push_back({"TN", {"128"}});
opts.defines.push_back({"TK", {"32"}});
opts.defines.push_back({"TZ", {"1"}});
opts.num_warps = {4};
rt::options_t opt;
opt.defines["STRIDE_AK"] = AT? "1" : "lda";
opt.defines["STRIDE_AM"] = AT? "lda" : "1";
opt.defines["STRIDE_BK"] = BT? "ldb" : "1";
opt.defines["STRIDE_BN"] = BT? "1" : "ldb";
opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128";
opt.defines["TN"] = "128";
opt.defines["TK"] = "32" ;
opt.defines["TZ"] = "1";
opt.num_warps = 4;
// arguments
std::stringstream oss;
rt::add_arg(oss, *da->cu());
@@ -187,7 +183,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
rt::add_arg(oss, ldc);
rt::add_arg(oss, *dlocks->cu());
// function
rt::function function(src::dot, opts, device);
rt::function function(src::dot, opt, device);
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
// grid
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };