From 81eba3e1ec5d1532b88143b548b8708b1a942fe8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 6 Jun 2019 19:36:41 -0700 Subject: [PATCH] ugh --- include/triton/codegen/tune.h | 1 + lib/codegen/tune.cpp | 68 +++++++++++++++++++++++++++-------- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/include/triton/codegen/tune.h b/include/triton/codegen/tune.h index 43a731a32..6c08f2ea0 100644 --- a/include/triton/codegen/tune.h +++ b/include/triton/codegen/tune.h @@ -33,6 +33,7 @@ private: fragment_t get_fragmentation_type(node_t x, graph_t &graph); void connected_components(node_t x, const std::vector mps, const std::vector prefixes, std::set &nodes, graph_t &graph, unsigned group_id); void create_grids(std::vector &grids, std::map &references, ir::function *fn); + unsigned get_req_num_threads(ir::instruction *i); public: diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index f567128f0..4b8e405bc 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -100,8 +100,9 @@ void tune::init_c_graph(ir::instruction *v) { else if(dynamic_cast(v)) { for(unsigned k = 0; k < v->get_num_results(); k++) for(unsigned i = 0; i < shapes.size(); i ++){ + ir::value *result = v->get_result(k); for(ir::value* op: v->ops()){ - add_constraint({v->get_result(k), i}, {op, i}); + add_constraint({result, i}, {op, i}); } } } @@ -199,20 +200,23 @@ void tune::run(ir::module &mod) { init_c_phi(i); // Layout parameters unsigned group_id = 0; +// for(auto x: nodes_){ +// fragments_[x] = STRIDED_SCAN; +// } while(!nodes_.empty()) { ir::type *ty = mod.get_builder().get_int32_ty(); node_t node = *nodes_.begin(); - fragment_t fragment = get_fragmentation_type(node, dependencies_); - if(fragment == STRIDED_SCAN) { +// if(fragments_[node] == STRIDED_SCAN) { ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, 
group_id++); nts->set_value(1); - } - else { - ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4); - connected_components(node, {fpw}, {"fpw"}, nodes_, dependencies_, group_id++); - } +// } +// else { +// ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4); +// ir::metaparameter *wpb = ir::metaparameter::create(ctx, ty, 1, 4); +// connected_components(node, {fpw, wpb}, {"fpw", "wpb"}, nodes_, dependencies_, group_id++); +// } } } @@ -220,6 +224,8 @@ void tune::run(ir::module &mod) { for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()){ +// if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN) +// continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 2)); @@ -250,6 +256,26 @@ void tune::init(ir::module &mod) { } } +// Thread count required by instruction i: 32 multiplied by the value of every "wpt.dK" parameter of i, one per tile dimension K. +// NOTE(review): the visible part of tune::run only creates "nts"/"mts" parameters; confirm "wpt.dK" exists for i, otherwise the .at() lookups below throw. +unsigned tune::get_req_num_threads(ir::instruction *i){ +// if(fragments_.at({i, 0}) == STRIDED_SCAN) { +// unsigned result = 1; +// for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){ +// std::string suffix = ".d" + std::to_string(k); +// result *= params_.at(i).at("mts" + suffix)->get_value(); +// } +// } +// else { + unsigned result = 32; + for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){ + std::string suffix = ".d" + std::to_string(k); + result *= params_.at(i).at("wpt" + suffix)->get_value(); + } + return result; +// } +} + void tune::create_grids(std::vector &grids, std::map &references, ir::function *fn) { @@ -307,16 +333,30 @@ bool tune::check_constraints(std::map> &er // must device the shape for(size_t k = 0; k < shapes.size(); k++) { std::string strk = to_string(k); - ir::metaparameter *mts = params_[i]["mts.d" + strk]; - ir::metaparameter *nts = params_[i]["nts.d" + strk]; - unsigned multiple = mts->get_value()*nts->get_value(); + unsigned multiple; +// 
if(fragments_.at({i, 0}) == STRIDED_SCAN) { + ir::metaparameter *mts = params_[i]["mts.d" + strk]; + ir::metaparameter *nts = params_[i]["nts.d" + strk]; + multiple = mts->get_value()*nts->get_value(); +// } +// else { +// ir::metaparameter *fpw = params_[i]["fpw.d" + strk]; +// ir::metaparameter *wpt = params_[i]["wpt.d" + strk]; +// multiple = fpw->get_value()*wpt->get_value(); +// } if(shapes[k]->get_value() % multiple != 0) errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")" " is not a multiple of layout (" + to_string(multiple) + ")"); } - int num_threads = 1; - for(size_t k = 0; k < shapes.size(); k++) - num_threads *= params_[i]["mts.d" + to_string(k)]->get_value(); + // the product of mma fragments per warp must be 4 +// if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){ +// unsigned prod = 1; +// for(size_t k = 0; k < shapes.size(); k++) +// prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value(); +// if(prod != 4) +// errors[i].push_back("HMMA must have only 4 fragments per warp"); +// } + int num_threads = get_req_num_threads(i); if(num_threads % 32 != 0) errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size"); if(num_threads != num_threads_)