fixup
This commit is contained in:
@@ -194,7 +194,6 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && !x.second->has_value()){
|
||||
// std::cout << i->get_name() << " " << x.first << std::endl;
|
||||
result.push_back(x.second);
|
||||
}
|
||||
|
||||
@@ -291,28 +290,29 @@ void tune::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// initialize grids
|
||||
|
||||
// for(ir::instruction *i: grids_){
|
||||
// auto shapes = i->get_type()->get_tile_shapes();
|
||||
// for(size_t k = 0; k < shapes.size(); k++)
|
||||
// if(shapes[k]->get_value() == 1) {
|
||||
// if(fragments_.at({i, k}) == STRIDED_SCAN){
|
||||
// params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
|
||||
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
|
||||
// }
|
||||
// if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
|
||||
// params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
|
||||
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
for(ir::instruction *i: grids_){
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
if(shapes[k]->get_value() == 1) {
|
||||
if(fragments_.at({i, k}) == STRIDED_SCAN){
|
||||
params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
|
||||
params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
|
||||
}
|
||||
if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
|
||||
params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
|
||||
params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
num_threads_ = get_req_num_threads(grids_.front());
|
||||
}
|
||||
|
||||
@@ -407,7 +407,9 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
else {
|
||||
ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
|
||||
ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
|
||||
multiple = fpw->get_value()*wpt->get_value()*8;
|
||||
multiple = fpw->get_value()*wpt->get_value();
|
||||
if(k < 2)
|
||||
multiple *= 8;
|
||||
}
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
|
Reference in New Issue
Block a user