[python] basic tensorflow wrapper working
This commit is contained in:
@@ -153,8 +153,8 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TM", {"32"}});
|
||||
opt.defines.push_back({"TN", {"32"}});
|
||||
opt.defines.push_back({"TK", {"32"}});
|
||||
opt.num_warps = {1, 2, 4, 8};
|
||||
rt::function function(src, opt);
|
||||
@@ -208,7 +208,7 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, false, 128, 128, 128}
|
||||
{false, true, 128, 128, 128}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
|
Reference in New Issue
Block a user