This commit is contained in:
Philippe Tillet
2019-06-06 19:36:41 -07:00
parent cdf5a0d011
commit 81eba3e1ec
2 changed files with 52 additions and 14 deletions

View File

@@ -33,6 +33,7 @@ private:
fragment_t get_fragmentation_type(node_t x, graph_t &graph); fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id); void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn); void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
unsigned get_req_num_threads(ir::instruction *i);
public: public:

View File

@@ -100,8 +100,9 @@ void tune::init_c_graph(ir::instruction *v) {
else if(dynamic_cast<ir::user*>(v)) { else if(dynamic_cast<ir::user*>(v)) {
for(unsigned k = 0; k < v->get_num_results(); k++) for(unsigned k = 0; k < v->get_num_results(); k++)
for(unsigned i = 0; i < shapes.size(); i ++){ for(unsigned i = 0; i < shapes.size(); i ++){
ir::value *result = v->get_result(k);
for(ir::value* op: v->ops()){ for(ir::value* op: v->ops()){
add_constraint({v->get_result(k), i}, {op, i}); add_constraint({result, i}, {op, i});
} }
} }
} }
@@ -199,20 +200,23 @@ void tune::run(ir::module &mod) {
init_c_phi(i); init_c_phi(i);
// Layout parameters // Layout parameters
unsigned group_id = 0; unsigned group_id = 0;
// for(auto x: nodes_){
// fragments_[x] = STRIDED_SCAN;
// }
while(!nodes_.empty()) { while(!nodes_.empty()) {
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
node_t node = *nodes_.begin(); node_t node = *nodes_.begin();
fragment_t fragment = get_fragmentation_type(node, dependencies_); // if(fragments_[node] == STRIDED_SCAN) {
if(fragment == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1); ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++); connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1); nts->set_value(1);
} // }
else { // else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4); // ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw}, {"fpw"}, nodes_, dependencies_, group_id++); // ir::metaparameter *wpb = ir::metaparameter::create(ctx, ty, 1, 4);
} // connected_components(node, {fpw, wpb}, {"fpw", "wpb"}, nodes_, dependencies_, group_id++);
// }
} }
} }
@@ -220,6 +224,8 @@ void tune::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){ for(ir::instruction *i : block->get_inst_list()){
// if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
// continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
@@ -250,6 +256,23 @@ void tune::init(ir::module &mod) {
} }
} }
unsigned tune::get_req_num_threads(ir::instruction *i){
// if(fragments_.at({i, 0}) == STRIDED_SCAN) {
// unsigned result = 1;
// for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
// std::string suffix = ".d" + std::to_string(k);
// result *= params_.at(i).at("mts" + suffix)->get_value();
// }
// }
// else {
unsigned result = 32;
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
result *= params_.at(i).at("wpt" + suffix)->get_value();
}
// }
}
void tune::create_grids(std::vector<ir::instruction*> &grids, void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<ir::metaparameter*, ir::instruction*> &references, std::map<ir::metaparameter*, ir::instruction*> &references,
ir::function *fn) { ir::function *fn) {
@@ -307,16 +330,30 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
// must device the shape // must device the shape
for(size_t k = 0; k < shapes.size(); k++) { for(size_t k = 0; k < shapes.size(); k++) {
std::string strk = to_string(k); std::string strk = to_string(k);
ir::metaparameter *mts = params_[i]["mts.d" + strk]; unsigned multiple;
ir::metaparameter *nts = params_[i]["nts.d" + strk]; // if(fragments_.at({i, 0}) == STRIDED_SCAN) {
unsigned multiple = mts->get_value()*nts->get_value(); ir::metaparameter *mts = params_[i]["mts.d" + strk];
ir::metaparameter *nts = params_[i]["nts.d" + strk];
multiple = mts->get_value()*nts->get_value();
// }
// else {
// ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
// ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
// multiple = fpw->get_value()*wpt->get_value();
// }
if(shapes[k]->get_value() % multiple != 0) if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")" errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")"); " is not a multiple of layout (" + to_string(multiple) + ")");
} }
int num_threads = 1; // the product of mma fragments per warp must be 4
for(size_t k = 0; k < shapes.size(); k++) // if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value(); // unsigned prod = 1;
// for(size_t k = 0; k < shapes.size(); k++)
// prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
// if(prod != 4)
// errors[i].push_back("HMMA must have only 4 fragments per warp");
// }
int num_threads = get_req_num_threads(i);
if(num_threads % 32 != 0) if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size"); errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_) if(num_threads != num_threads_)