ugh
This commit is contained in:
@@ -33,6 +33,7 @@ private:
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
|
||||
unsigned get_req_num_threads(ir::instruction *i);
|
||||
|
||||
|
||||
public:
|
||||
|
@@ -100,8 +100,9 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned k = 0; k < v->get_num_results(); k++)
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
ir::value *result = v->get_result(k);
|
||||
for(ir::value* op: v->ops()){
|
||||
add_constraint({v->get_result(k), i}, {op, i});
|
||||
add_constraint({result, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,20 +200,23 @@ void tune::run(ir::module &mod) {
|
||||
init_c_phi(i);
|
||||
// Layout parameters
|
||||
unsigned group_id = 0;
|
||||
// for(auto x: nodes_){
|
||||
// fragments_[x] = STRIDED_SCAN;
|
||||
// }
|
||||
while(!nodes_.empty()) {
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
node_t node = *nodes_.begin();
|
||||
fragment_t fragment = get_fragmentation_type(node, dependencies_);
|
||||
if(fragment == STRIDED_SCAN) {
|
||||
// if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw}, {"fpw"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
// }
|
||||
// else {
|
||||
// ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
// ir::metaparameter *wpb = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
// connected_components(node, {fpw, wpb}, {"fpw", "wpb"}, nodes_, dependencies_, group_id++);
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +224,8 @@ void tune::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()){
|
||||
// if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
|
||||
// continue;
|
||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||
@@ -250,6 +256,23 @@ void tune::init(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
|
||||
unsigned tune::get_req_num_threads(ir::instruction *i){
|
||||
// if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
// unsigned result = 1;
|
||||
// for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
// std::string suffix = ".d" + std::to_string(k);
|
||||
// result *= params_.at(i).at("mts" + suffix)->get_value();
|
||||
// }
|
||||
// }
|
||||
// else {
|
||||
unsigned result = 32;
|
||||
for(unsigned k = 0; k < i->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
result *= params_.at(i).at("wpt" + suffix)->get_value();
|
||||
}
|
||||
// }
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<ir::metaparameter*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
@@ -307,16 +330,30 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
// must device the shape
|
||||
for(size_t k = 0; k < shapes.size(); k++) {
|
||||
std::string strk = to_string(k);
|
||||
ir::metaparameter *mts = params_[i]["mts.d" + strk];
|
||||
ir::metaparameter *nts = params_[i]["nts.d" + strk];
|
||||
unsigned multiple = mts->get_value()*nts->get_value();
|
||||
unsigned multiple;
|
||||
// if(fragments_.at({i, 0}) == STRIDED_SCAN) {
|
||||
ir::metaparameter *mts = params_[i]["mts.d" + strk];
|
||||
ir::metaparameter *nts = params_[i]["nts.d" + strk];
|
||||
multiple = mts->get_value()*nts->get_value();
|
||||
// }
|
||||
// else {
|
||||
// ir::metaparameter *fpw = params_[i]["fpw.d" + strk];
|
||||
// ir::metaparameter *wpt = params_[i]["wpt.d" + strk];
|
||||
// multiple = fpw->get_value()*wpt->get_value();
|
||||
// }
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||
}
|
||||
int num_threads = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
||||
// the product of mma fragments per warp must be 4
|
||||
// if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
// unsigned prod = 1;
|
||||
// for(size_t k = 0; k < shapes.size(); k++)
|
||||
// prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
|
||||
// if(prod != 4)
|
||||
// errors[i].push_back("HMMA must have only 4 fragments per warp");
|
||||
// }
|
||||
int num_threads = get_req_num_threads(i);
|
||||
if(num_threads % 32 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
|
Reference in New Issue
Block a user