From 5f292630442cc2ffdda120eccd31dd3d831ee511 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 9 Mar 2019 12:05:12 -0500 Subject: [PATCH] [code generation] now using ir::metaparameter* for all tunable metaparameters --- include/triton/codegen/selection.h | 6 ++-- include/triton/codegen/tune.h | 17 +++++---- include/triton/ir/builder.h | 5 +++ include/triton/ir/constant.h | 4 ++- lib/codegen/selection.cpp | 16 ++++----- lib/codegen/tune.cpp | 57 ++++++++++++++++-------------- lib/codegen/vectorize.cpp | 2 +- lib/ir/builder.cpp | 15 ++++++++ lib/ir/constant.cpp | 2 +- lib/jit.cpp | 5 +-- 10 files changed, 80 insertions(+), 49 deletions(-) diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 5c81ca8a0..291fbf827 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -117,9 +117,9 @@ private: // grid construction void create_grids(std::vector &grids, - std::map &references, + std::map &references, ir::function *fn); - void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map &references, std::set &seen, llvm::Value *sh_mem_ptr); + void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map &references, std::set &seen, llvm::Value *sh_mem_ptr); void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id); void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr); @@ -139,7 +139,7 @@ private: allocation *alloc_; tune *params_; buffer_info_pass *buffer_info_; - std::map axes_; + std::map axes_; }; } diff --git a/include/triton/codegen/tune.h b/include/triton/codegen/tune.h index cb1d5b509..5979290fa 100644 --- a/include/triton/codegen/tune.h +++ b/include/triton/codegen/tune.h @@ -12,6 +12,7 @@ namespace ir{ class module; class instruction; class function; + class metaparameter; } namespace codegen{ @@ -24,24 +25,28 @@ private: void add_constraint(node_t x, node_t y); void init_c_phi(ir::instruction *i); void init_c_graph(ir::instruction *v); - void connected_components(node_t x, const std::vector vals, std::set &nodes, graph_t &graph); - void create_grids(std::vector &grids, std::map &references, ir::function *fn); + void connected_components(node_t x, const std::vector mps, std::set &nodes, graph_t &graph); + void create_grids(std::vector &grids, std::map &references, ir::function *fn); public: - std::vector get_params(ir::module& mod); - std::map get_params(ir::instruction* i); - unsigned *get_param(ir::value *value, const std::string &key) { return params_[value][key]; } + std::vector get_params(ir::module& mod); + std::map get_params(ir::instruction* i); + ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; } void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; } bool check_constraints(ir::module &fn, std::map> &errors); void run(ir::module &mod); + ir::metaparameter* get_num_threads(); + ir::metaparameter* get_global_range_size(unsigned axis); private: - std::map> params_; std::vector pool_; graph_t dependencies_; std::set nodes_; std::map static_params_; + std::map> params_; + ir::metaparameter *num_threads_; + std::vector global_range_sizes_; }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 509ae8e47..852f55aa9 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -35,6 +35,11 @@ public: // Constants value *get_int32(unsigned val); // Types + type *get_int1_ty(); + type *get_int8_ty(); + type *get_int16_ty(); + type *get_int32_ty(); + type *get_int64_ty(); type *get_float_ty(); type *get_double_ty(); // Insert diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 9f2baf618..e3bd2ab24 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -49,11 +49,13 @@ class metaparameter: public constant_int{ public: static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi); - void set_value(uint64_t value) { value_ = value; } + void set_value(uint64_t value) { has_value_ = true; value_ = value; } + bool has_value() { return has_value_; } private: unsigned lo_; unsigned hi_; + bool has_value_; }; /* constant range */ diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 3f79c9375..a9f7e8524 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id std::vector n_warps(dim); for(unsigned i = 0; i < shapes.size(); i++){ std::string str_i = std::to_string(i); - contiguous[i] = *params_->get_param(v, "p0.d" + str_i); - warp_size[i] = *params_->get_param(v, "p1.d" + str_i); - n_warps[i] = *params_->get_param(v, "p2.d" + str_i); + contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value(); + warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value(); + n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value(); } std::vector thread_id_in_warp = delinearize(u_thread_id, warp_size, builder); std::vector warp_id = delinearize(u_warp_id, n_warps, builder); @@ -404,7 +404,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id } void selection::create_grids(std::vector &grids, - std::map &references, + std::map &references, ir::function *fn) { // get number of dimensions greater than 1 auto get_tile_gt1_dim = [&](ir::value *v){ @@ -432,7 +432,7 @@ void selection::create_grids(std::vector &grids, for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d]->get_value() == 1) continue; - unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d)); + ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d)); ir::value *&r = references[x]; if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r)) r = v; @@ -457,7 +457,7 @@ bool static inline has_phi_user(ir::value *v) { return false; } void selection::create_tile(ir::value *v, IRBuilder<> &builder, - const std::map& references, + const std::map& references, std::set &seen, Value *sh_mem_ptr) { if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) return; @@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, std::vector axes(shapes.size()); for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d]->get_value() > 1){ - unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d)); + ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d)); axes[d] = axes_.at(x); } else{ @@ -549,7 +549,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size); // create grid std::vector grids; - std::map references; + std::map references; create_grids(grids, references, fn); for(ir::value* i: grids){ if(auto *instr = dynamic_cast(i)) diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 0c64401de..4972abd71 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -4,6 +4,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/context_impl.h" +#include "triton/ir/constant.h" #include @@ -77,43 +78,44 @@ void tune::init_c_graph(ir::instruction *v) { } } -void tune::connected_components(node_t x, const std::vector vals, std::set &nodes, graph_t &graph) { +void tune::connected_components(node_t x, const std::vector mps, std::set &nodes, graph_t &graph) { if(nodes.find(x) != nodes.end()){ nodes.erase(x); std::string suffix = ".d" + std::to_string(x.second); - params_[x.first].insert({"p0" + suffix, vals[0]}); - params_[x.first].insert({"p1" + suffix, vals[1]}); - params_[x.first].insert({"p2" + suffix, vals[2]}); + params_[x.first].insert({"p0" + suffix, mps[0]}); + params_[x.first].insert({"p1" + suffix, mps[1]}); + params_[x.first].insert({"p2" + suffix, mps[2]}); if(static_params_.find(x) != static_params_.end()){ - *vals[0] = static_params_.at(x); - *vals[1] = static_params_.at(x); - *vals[2] = static_params_.at(x); + mps[0]->set_value(static_params_.at(x)); + mps[1]->set_value(static_params_.at(x)); + mps[2]->set_value(static_params_.at(x)); } for(const node_t &y: graph[x]) - connected_components(y, vals, nodes, graph); + connected_components(y, mps, nodes, graph); } } -std::vector tune::get_params(ir::module &mod) { - std::vector result; - std::set seen; +std::vector tune::get_params(ir::module &mod) { + std::vector result; + std::set seen; for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) - if(seen.insert(x.second).second && *x.second == 0){ + if(seen.insert(x.second).second && !x.second->has_value()){ result.push_back(x.second); } return result; } -std::map tune::get_params(ir::instruction* i) { +std::map tune::get_params(ir::instruction* i) { return params_.at(i); } void tune::run(ir::module &mod) { + ir::context &ctx = mod.get_context(); for(ir::function *fn: mod.get_function_list()){ // Build constraints graph for(ir::basic_block *block: fn->blocks()) @@ -128,16 +130,17 @@ void tune::run(ir::module &mod) { init_c_phi(i); // Layout parameters while(!nodes_.empty()){ - unsigned *v0 = new unsigned(0); - unsigned *v1 = new unsigned(0); - unsigned *v2 = new unsigned(0); - connected_components(*nodes_.begin(), {v0, v1, v2}, nodes_, dependencies_); + ir::type *ty = mod.get_builder().get_int32_ty(); + ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 1, 4); + ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 32); + ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 4, 32); + connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_); } } } void tune::create_grids(std::vector &grids, - std::map &references, + std::map &references, ir::function *fn) { // get number of dimensions greater than 1 auto get_tile_gt1_dim = [&](ir::value *v){ @@ -154,7 +157,7 @@ void tune::create_grids(std::vector &grids, if(!i->get_type()->is_tile_ty()) continue; for(auto ¶m: params_.at(i)){ - if(*param.second == 1) + if(param.second->get_value() == 1) continue; ir::instruction *&r = references[param.second]; if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r)) @@ -173,14 +176,14 @@ for(ir::function *fn: mod.get_function_list()){ using std::to_string; // initialize grids - std::map references; + std::map references; std::vector grids; create_grids(grids, references, fn); // number of warps int num_warps = 1; for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++) - num_warps *= *params_[grids.front()]["p2.d" + to_string(k)]; + num_warps *= params_[grids.front()]["p2.d" + to_string(k)]->get_value(); // check constraints for(ir::instruction *i: grids){ @@ -190,10 +193,10 @@ for(ir::function *fn: mod.get_function_list()){ // must device the shape for(size_t k = 0; k < shapes.size(); k++) { std::string strk = to_string(k); - unsigned *s0 = params_[i]["p0.d" + strk]; - unsigned *s1 = params_[i]["p1.d" + strk]; - unsigned *s2 = params_[i]["p2.d" + strk]; - unsigned multiple = (*s0)*(*s1)*(*s2); + ir::metaparameter *mp0 = params_[i]["p0.d" + strk]; + ir::metaparameter *mp1 = params_[i]["p1.d" + strk]; + ir::metaparameter *mp2 = params_[i]["p2.d" + strk]; + unsigned multiple = mp0->get_value()*mp1->get_value()*mp2->get_value(); if(shapes[k]->get_value() % multiple != 0) errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")" " is not a multiple of layout (" + to_string(multiple) + ")"); @@ -201,14 +204,14 @@ for(ir::function *fn: mod.get_function_list()){ // the number of thread per warp must be 32 int num_threads = 1; for(size_t k = 0; k < shapes.size(); k++) - num_threads *= *params_[i]["p1.d" + to_string(k)]; + num_threads *= params_[i]["p1.d" + to_string(k)]->get_value(); if(num_threads != 32) errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32"); // The number of warps required by the layout is the same // for all tiles in the function int required_num_warps = 1; for(size_t k = 0; k < shapes.size(); k++) - required_num_warps *= *params_[i]["p2.d" + to_string(k)]; + required_num_warps *= params_[i]["p2.d" + to_string(k)]->get_value(); if(required_num_warps != num_warps) errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps)); } diff --git a/lib/codegen/vectorize.cpp b/lib/codegen/vectorize.cpp index 57c2142c9..672e97dc1 100644 --- a/lib/codegen/vectorize.cpp +++ b/lib/codegen/vectorize.cpp @@ -16,7 +16,7 @@ void vectorize::run(ir::module &mod) { for(ir::instruction *i: block->get_inst_list()) if(dynamic_cast(i)){ ir::value *x = i->get_operand(0); - if(*params_->get_param(x, "p0.d0") == 1) + if(params_->get_param(x, "p0.d0")->get_value() == 1) continue; builder.set_insert_point(i); ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x); diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index b3c1174ce..db0ae9e94 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -41,6 +41,21 @@ value *builder::get_int32(unsigned val) { return constant_int::get(type::get_int32_ty(ctx_), val); } +type *builder::get_int1_ty() +{ return type::get_int1_ty(ctx_); } + +type *builder::get_int8_ty() +{ return type::get_int8_ty(ctx_); } + +type *builder::get_int16_ty() +{ return type::get_int16_ty(ctx_); } + +type *builder::get_int32_ty() +{ return type::get_int32_ty(ctx_); } + +type *builder::get_int64_ty() +{ return type::get_int64_ty(ctx_); } + type *builder::get_float_ty() { return type::get_float_ty(ctx_); } diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp index 314714c04..bfb6fdb9b 100644 --- a/lib/ir/constant.cpp +++ b/lib/ir/constant.cpp @@ -99,7 +99,7 @@ constant *constant_fp::get(context &ctx, double v){ // metaparameter metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi) - : constant_int(ty, 0), lo_(lo), hi_(hi){ } + : constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ } metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) { context_impl *impl = ctx.p_impl.get(); diff --git a/lib/jit.cpp b/lib/jit.cpp index 17787c352..9db98ca32 100644 --- a/lib/jit.cpp +++ b/lib/jit.cpp @@ -65,8 +65,9 @@ std::unique_ptr jit::make_llvm_module(ir::module &module, const st triton_context_.p_impl->mp_constants_[0]->set_value(params[0]); triton_context_.p_impl->mp_constants_[1]->set_value(params[1]); triton_context_.p_impl->mp_constants_[2]->set_value(params[2]); - for(unsigned *x: tune.get_params(module)) - *x = params[3 + i++]; + for(ir::metaparameter *x: tune.get_params(module)){ + x->set_value(params[3 + i++]); + } // constraints std::map> errors; tune.check_constraints(module, errors);