From 6045209d5bfa0226a8f31ee9383c0be1dc665521 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 6 Jun 2019 20:13:26 -0700
Subject: [PATCH] Now find correct tuning configuration

---
 examples/python/tensorflow/dot.cpp | 53 +++++++++++++++++-------------
 lib/codegen/tune.cpp               |  8 +++--
 lib/runtime/jit.cpp                |  6 +++-
 3 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index d02c8a56e..01d18c435 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -4,6 +4,7 @@
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/runtime/jit.h"
+#include "triton/tools/bench.hpp"
 
 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -125,30 +126,36 @@ class BlockSparseGemmOp : public OpKernel {
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false);
     triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat().data(), false);
     stream->synchronize();
+    // benchmark a given matrix multiplication kernel
+    auto benchmark = [&](triton::driver::kernel* kernel,
+                         triton::jit::launch_information info) {
+      // launch info
+      unsigned TM = info.global_range_size[0];
+      unsigned TN = info.global_range_size[1];
+      unsigned nthreads = info.num_threads;
+      unsigned GZ = jit.get_int("GZ");
+      std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
+      // set argument
+      kernel->setArg(0, *da.cu());
+      kernel->setArg(1, *db.cu());
+      kernel->setArg(2, *dc.cu());
+      kernel->setArg(3, M);
+      kernel->setArg(4, N);
+      kernel->setArg(5, K);
+      kernel->setArg(6, M);
+      kernel->setArg(7, N);
+      kernel->setArg(8, M);
+      kernel->setArg(9, *dlocks.cu());
+      kernel->setArg(10, grid[0]);
+      kernel->setArg(11, grid[1]);
+      stream->enqueue(kernel, grid, {nthreads, 1, 1});
+      stream->synchronize();
+      double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
+                                       [&](){ stream->synchronize(); }, nullptr);
+      return 2.*M*N*K / ts * 1e-3;
+    };
     // just-in-time compile source-code
-    jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
-    triton::driver::kernel* kernel = jit.get_function("matmul");
-    triton::jit::launch_information info = jit.get_launch_info("matmul");
-    // launch info
-    unsigned TM = info.global_range_size[0];
-    unsigned TN = info.global_range_size[1];
-    unsigned nthreads = info.num_threads;
-    unsigned GZ = jit.get_int("GZ");
-    std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
-    // set argument
-    kernel->setArg(0, *da.cu());
-    kernel->setArg(1, *db.cu());
-    kernel->setArg(2, *dc.cu());
-    kernel->setArg(3, M);
-    kernel->setArg(4, N);
-    kernel->setArg(5, K);
-    kernel->setArg(6, M);
-    kernel->setArg(7, N);
-    kernel->setArg(8, M);
-    kernel->setArg(9, *dlocks.cu());
-    kernel->setArg(10, grid[0]);
-    kernel->setArg(11, grid[1]);
-    stream->enqueue(kernel, grid, {nthreads, 1, 1});
+    jit.autotune("matmul", src, benchmark);
   }
 
 private:
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 4a9940400..a995a8a7c 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -213,8 +213,8 @@ void tune::run(ir::module &mod) {
         nts->set_value(1);
       }
       else {
-        ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
-        ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+        ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
+        ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
         connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
       }
     }
@@ -266,6 +266,7 @@ unsigned tune::get_req_num_threads(ir::instruction *i){
       std::string suffix = ".d" + std::to_string(k);
       result *= params_.at(i).at("wpt" + suffix)->get_value();
     }
+    return result;
   }
 }
 
@@ -349,8 +350,9 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
     // the product of mma fragments per warp must be 4
     if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
       unsigned prod = 1;
-      for(size_t k = 0; k < shapes.size(); k++)
+      for(size_t k = 0; k < shapes.size(); k++){
         prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
+      }
       if(prod != 4)
         errors[i].push_back("HMMA must have only 4 fragments per warp");
     }
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index e03d51c63..5f5b161ea 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -120,7 +120,11 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark)
       mp->set_value(params[i++]);
     passes.target_independent(tt_module);
     passes.tune.init(tt_module);
-    if(!passes.tune.check_constraints(errors))
+    passes.tune.check_constraints(errors);
+    for(auto x: errors)
+    for(auto err: x.second)
+      std::cout << err << std::endl;
+    if(!errors.empty())
       return;
     // Deep copy of the module and tuner
     auto ptt_module = make_triton_module(name, src);
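
For readers following the change in dot.cpp: the op now only runs jit.autotune("matmul", src, benchmark) and no longer launches a kernel for the op's output itself. A minimal sketch of how the tuning result could be consumed afterwards is shown below, assuming this revision's jit::autotune returns a jit::tune_res_t whose params member holds the winning parameter vector and that add_module/get_function behave as in the code removed above; the helper name compile_best is illustrative only, not part of the commit.

    // Sketch only -- relies on the assumptions stated above.
    #include "triton/driver/backend.h"
    #include "triton/runtime/jit.h"

    triton::driver::kernel* compile_best(triton::jit &jit,
                                         const char *src,
                                         triton::jit::benchmark_t benchmark) {
      // Explore the tuning space; `benchmark` is the same kind of callback
      // that dot.cpp now builds (launch the kernel, return a throughput).
      triton::jit::tune_res_t best = jit.autotune("matmul", src, benchmark);
      // Re-compile with the winning parameters, replacing the hard-coded
      // configuration previously passed to jit.add_module.
      jit.add_module("matmul", src, best.params);
      return jit.get_function("matmul");
    }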