This commit is contained in:
Philippe Tillet
2019-08-06 16:44:16 -07:00
parent 5efdb7978e
commit cf256a636c
7 changed files with 42 additions and 39 deletions

View File

@@ -194,7 +194,6 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && !x.second->has_value()){
// std::cout << i->get_name() << " " << x.first << std::endl;
result.push_back(x.second);
}
@@ -291,28 +290,29 @@ void tune::run(ir::module &mod) {
}
// initialize grids
// for(ir::instruction *i: grids_){
// auto shapes = i->get_type()->get_tile_shapes();
// for(size_t k = 0; k < shapes.size(); k++)
// if(shapes[k]->get_value() == 1) {
// if(fragments_.at({i, k}) == STRIDED_SCAN){
// params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
// }
// if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
// params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
// }
// }
// }
}
void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
for(ir::instruction *i: grids_){
auto shapes = i->get_type()->get_tile_shapes();
for(size_t k = 0; k < shapes.size(); k++)
if(shapes[k]->get_value() == 1) {
if(fragments_.at({i, k}) == STRIDED_SCAN){
params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
}
if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
}
}
}
}
void tune::init(ir::module &mod) {
num_threads_ = get_req_num_threads(grids_.front());
}
@@ -407,7 +407,9 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
else {
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
multiple = fpw->get_value()*wpt->get_value()*8;
multiple = fpw->get_value()*wpt->get_value();
if(k < 2)
multiple *= 8;
}
if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"