diff --git a/examples/matrix.cpp b/examples/matrix.cpp index dfa64e5c3..ef5dbf36d 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -81,13 +81,37 @@ private: high_resolution_clock::time_point _start; }; +template +T min(std::vector x) +{ return *std::min_element(x.begin(), x.end()); } + + +template +double bench(OP const & op, SYNC const & sync, triton::driver::device const & device) +{ + timer tmr; + std::vector times; + double total_time = 0; + op(); + sync(); + while(total_time*1e-9 < 1e-3){ + float norm = (float)device.current_sm_clock()/device.max_sm_clock(); + tmr.start(); + op(); + sync(); + times.push_back(norm*tmr.get().count()); + total_time+=times.back(); + } + return min(times); +} + int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); + triton::jit jit(context); // matrix multiplication parameters size_t M = 512, N = 512, K = 512; - size_t bound = 8; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); @@ -112,6 +136,22 @@ int main() { // benchmark a given matrix multiplication kernel auto benchmark = [&](triton::driver::kernel kernel, triton::jit::launch_information info) { + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned nthreads = info.num_threads; + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}; + // fast bounds-checking + unsigned TK = jit.get_int("TK"); + unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1; + unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1; + unsigned lastk = TK - 1; + bool AT = false; + bool BT = true; + unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk; + unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk; + int32_t bound = std::max(1, std::max(K - last_safe_a, K - last_safe_b)); + // set argument kernel.setArg(0, da); kernel.setArg(1, db); kernel.setArg(2, dc); @@ -119,39 +159,33 @@ int main() { kernel.setArg(4, N); kernel.setArg(5, K); kernel.setArg(6, bound); - unsigned TM = info.global_range_size[0]; - unsigned TN = info.global_range_size[1]; - unsigned nthreads = info.num_threads; - timer t; - t.start(); - stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1}); + // dry run + stream.enqueue(kernel, grid, {nthreads, 1, 1}); stream.synchronize(); - double ts = t.get().count()*1e-9; + // benchmark + double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});}, + [&](){ stream.synchronize(); }, + context.device()); + ts = ts * 1e-9; double tflops = 2*M*N*K / ts * 1e-12; - std::cout << tflops << std::endl; - return ts; + return tflops; }; // just-in-time compile source-code std::vector params = { - // a0 - 8, 2, 16, - // b0 - 4, 4, 16, - // c - 8, 4, 2, 4, - // a1 - 4, 2, 8, - // b1 - 8, 1 + 16, 2, 64, + 32, 2, 64, + 16, 8, 2, 2, + 8, 1, 8, + 4, 1 }; - triton::jit jit(context); - jit.autotune(src, benchmark); + +// jit.autotune(src, benchmark); jit.add_module(src, params); triton::driver::kernel kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); - benchmark(kernel, info); + std::cout << benchmark(kernel, info) << std::endl; stream.read(dc, true, 0, hc); simple_gemm(rc, ha, hb, M, N, K); for(size_t i = 0; i < M*N; i++) diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 4ec681f67..3d2d5afb9 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -74,12 +74,15 @@ public: functions_list_t &get_function_list() { return functions_; } function *get_or_insert_function(const std::string &name, function_type *ty); // Scope - void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } - void pop_scope() { scopes_.pop(); } - scope& get_scope() { return scopes_.top(); } + void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } + void pop_scope() { scopes_.pop(); } + scope& get_scope() { return scopes_.top(); } // Const allocation - void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } - const std::vector& allocs() { return allocs_; } + void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } + const std::vector& allocs() { return allocs_; } + // Register global + void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } + const std::map& globals() const { return globals_; } private: std::string name_; @@ -96,6 +99,7 @@ private: std::map current_phi_; std::stack scopes_; std::vector allocs_; + std::map globals_; }; } diff --git a/include/triton/jit.h b/include/triton/jit.h index a01c43685..0d90d63b0 100644 --- a/include/triton/jit.h +++ b/include/triton/jit.h @@ -39,7 +39,7 @@ public: std::vector global_range_size; unsigned num_threads; }; - typedef std::function benchmark_t; + typedef std::function benchmark_t; struct passes_wrapper { passes_wrapper(): shared(&buffer_info), liveness(&buffer_info), @@ -80,6 +80,7 @@ public: void add_module(const std::string &src, const std::vector& params = {}); driver::kernel get_function(const std::string &name); launch_information get_launch_info(const std::string &name); + unsigned get_int(const std::string &name); private: std::vector modules_; @@ -87,6 +88,7 @@ private: llvm::LLVMContext llvm_context_; ir::context triton_context_; std::map launch_info_map_; + std::map global_ints_; }; diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 5dda59ce9..04d03aa99 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -412,7 +412,8 @@ ir::value* initializer::codegen(ir::module * mod) const{ if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ assert(expr_ == nullptr); //TODO: implement ranges - value = ir::metaparameter::create(mod->get_context(), ty, 8, 64); + value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64); + mod->register_global(name, value); } if(expr_){ value = expr_->codegen(mod); diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index dcf817ec8..f3a9cedfb 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -144,7 +144,7 @@ void tune::run(ir::module &mod) { // Layout parameters while(!nodes_.empty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2); + ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_); } diff --git a/lib/jit.cpp b/lib/jit.cpp index 150ff40a6..64e0865fa 100644 --- a/lib/jit.cpp +++ b/lib/jit.cpp @@ -111,6 +111,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) { } // iterate over parameters unsigned i; + double best = 0; loop_nest(ranges, [&](const std::vector params){ std::map> errors; i = 0; @@ -142,7 +143,12 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) { launch_information info = launch_info_map_.at("matmul"); for(unsigned p: params) std::cout << p << " " << std::flush; - benchmark(kernel, info); + // add globals + for(auto x: tt_module.globals()) + global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); + double perf = benchmark(kernel, info); + best = std::max(perf, best); + std::cout << perf << " [ " << best << " ] " << std::endl; }); } @@ -166,6 +172,9 @@ void jit::add_module(ir::module &tt_module, const std::vector ¶ms) auto ll_module = make_llvm_module(tt_module, passes); // llvm module -> machine code modules_.push_back(driver::module(driver_context_, &*ll_module)); + // add globals + for(auto x: tt_module.globals()) + global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); } void jit::add_module(const std::string &src, const std::vector ¶ms) { @@ -181,4 +190,8 @@ jit::launch_information jit::get_launch_info(const std::string &name) { return launch_info_map_.at(name); } +unsigned jit::get_int(const std::string &name){ + return global_ints_.at(name); +} + }