[python] basic tensorflow wrapper working

This commit is contained in:
Philippe Tillet
2019-08-26 16:53:49 -07:00
parent 0e0399f866
commit 4075949f80
26 changed files with 702 additions and 968 deletions

View File

@@ -153,8 +153,8 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
opt.defines.push_back({"AT", {""}});
if(BT)
opt.defines.push_back({"BT", {""}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TM", {"32"}});
opt.defines.push_back({"TN", {"32"}});
opt.defines.push_back({"TK", {"32"}});
opt.num_warps = {1, 2, 4, 8};
rt::function function(src, opt);
@@ -208,7 +208,7 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, false, 128, 128, 128}
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},