#include "triton/codegen/alignment_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"

// NOTE(review): in this copy of the file the contents of every angle-bracket
// list appear to be missing (`template inline T`, `std::map &map`,
// `std::set attributes`, and each `dynamic_cast(v)` has no target type).
// Those tokens are kept verbatim below; restore them from the canonical
// source before building. Where the code itself hints at the kind of value a
// cast is meant to match, a hedged comment says so.

namespace triton {
namespace codegen{

// Stores `value` under key `i` in `map` and returns the stored value.
template inline T add_to_cache(ir::value *i, T value, std::map &map) {
  return map[i] = value;
}

// Returns true when `x` is not a tile, or when the first entry of its tile
// shape has value 1.
bool alignment_info::is_first_axis_unit(ir::value *x){
  if(x->get_type()->is_tile_ty())
    return x->get_type()->get_tile_shapes()[0]->get_value() == 1;
  else
    return true;
}

// Computes (and memoizes in is_constant_) the "is constant" flag of `v`,
// recursing through its operands. Returns the cached flag when one exists.
bool alignment_info::populate_is_constant(ir::value *v) {
  if(is_constant_.find(v) != is_constant_.end())
    return is_constant_.at(v);
  // helper for the cache
  auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
  // populate
  // Single-operand node: constant when the operand's first axis is unit.
  // The operand is analyzed first so that it gets a cache entry of its own.
  if(auto *x = dynamic_cast(v)){
    ir::value *op = x->get_operand(0);
    populate_is_constant(op);
    if(is_first_axis_unit(op))
      return cache(true);
  }
  // Unconditionally constant kind of value (cast target not visible here).
  if(auto *x = dynamic_cast(v))
    return cache(true);
  // Two-operand node: constant only when both operands are.
  if(auto *x = dynamic_cast(v)){
    bool lhs = populate_is_constant(x->get_operand(0));
    bool rhs = populate_is_constant(x->get_operand(1));
    return cache(lhs && rhs);
  }
  // Node exposing get_value_true()/get_value_false(): constant only when both
  // alternatives are.
  if(auto *x = dynamic_cast(v)){
    bool value_true = populate_is_constant(x->get_value_true());
    bool value_false = populate_is_constant(x->get_value_false());
    return cache(value_true && value_false);
  }
  // Any other tile-typed value is treated as non-constant.
  if(v->get_type()->is_tile_ty())
    return cache(false);
  // Node with incoming values (presumably a phi node - confirm cast target).
  if(auto *x = dynamic_cast(v)){
    // put a conservative initial value in phi node to avoid infinite recursion
    bool result = true;
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      // the last already-cached incoming value determines the seed
      if(is_constant_.find(inc) != is_constant_.end())
        result = is_constant_.at(inc);
    }
    // seed the cache before recursing so that a cycle back to `v` terminates
    cache(result);
    // recurse
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      result = result && populate_is_constant(inc);
    }
    return cache(result);
  }
  // scalars are always constant in the contiguous dimension
  return cache(true);
}

// Computes (and memoizes in max_contiguous_) the "max contiguous" value of
// `v`. Non-tile values and unrecognized tiles get 1.
unsigned alignment_info::populate_max_contiguous(ir::value *v){
  if(max_contiguous_.find(v) != max_contiguous_.end())
    return max_contiguous_.at(v);
  // helper for the cache
  auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
  // populate
  if(!v->get_type()->is_tile_ty())
    return cache(1);
  auto shapes = v->get_type()->get_tile_shapes();
  // Two kinds of value whose result is the size of the first tile axis
  // (cast targets not visible here).
  if(dynamic_cast(v))
    return cache(shapes[0]->get_value());
  if(dynamic_cast(v))
    return cache(shapes[0]->get_value());
  // Single-operand node: inherits the operand's value only when the operand
  // is a tile whose first shape entry is the same object as ours; otherwise 1.
  if(auto *x = dynamic_cast(v)){
    ir::value *op = x->get_operand(0);
    if(op->get_type()->is_tile_ty()){
      auto op_shapes = op->get_type()->get_tile_shapes();
      if(op_shapes[0] == shapes[0])
        return cache(populate_max_contiguous(op));
    }
    return cache(1);
  }
  // Two-operand node with is_int_add_sub(): adding/subtracting a constant
  // operand keeps the other operand's value. Anything else falls through.
  if(auto *x = dynamic_cast(v)){
    ir::value* lhs = x->get_operand(0);
    ir::value* rhs = x->get_operand(1);
    // both operands are always analyzed so that they get cache entries
    unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
    unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
    bool lhs_has_cst = populate_is_constant(lhs);
    bool rhs_has_cst = populate_is_constant(rhs);
    if(x->is_int_add_sub()){
      if(lhs_has_cst)
        return cache(rhs_max_contiguous);
      if(rhs_has_cst)
        return cache(lhs_max_contiguous);
    }
  }
  // Node exposing get_value_true()/get_value_false(): the smaller of the two.
  if(auto *x = dynamic_cast(v)){
    int value_true = populate_max_contiguous(x->get_value_true());
    int value_false = populate_max_contiguous(x->get_value_false());
    return cache(std::min(value_true, value_false));
  }
  // Another two-operand node (cast target not visible here): a constant
  // operand keeps the other operand's value. Anything else falls through.
  if(auto *x = dynamic_cast(v)){
    ir::value* lhs = x->get_operand(0);
    ir::value* rhs = x->get_operand(1);
    unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
    unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
    bool lhs_has_cst = populate_is_constant(lhs);
    bool rhs_has_cst = populate_is_constant(rhs);
    if(lhs_has_cst)
      return cache(rhs_max_contiguous);
    if(rhs_has_cst)
      return cache(lhs_max_contiguous);
  }
  // Node with incoming values (presumably a phi node - confirm cast target).
  if(auto *x = dynamic_cast(v)){
    // put a conservative initial value in phi node to avoid infinite recursion
    unsigned result = 1;
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      // the last already-cached incoming value determines the seed; with no
      // cached incoming value the seed stays 1 and so bounds the final result
      if(max_contiguous_.find(inc) != max_contiguous_.end())
        result = max_contiguous_.at(inc);
    }
    // seed the cache before recursing so that a cycle back to `v` terminates
    cache(result);
    // recurse
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      result = std::min(result, populate_max_contiguous(inc));
    }
    return cache(result);
  }
  return cache(1);
}

// Greatest common divisor of |a| and |b|; gcd(0, b) == |b|, gcd(a, 0) == |a|.
// Iterative Euclid: no recursion, a logarithmic number of steps, and it
// terminates for negative inputs (which can appear when a large unsigned
// multiple is narrowed to int by a caller).
inline int gcd(int a, int b) {
  // work on magnitudes in unsigned arithmetic so negation cannot overflow
  unsigned x = a < 0 ? 0u - static_cast<unsigned>(a) : static_cast<unsigned>(a);
  unsigned y = b < 0 ? 0u - static_cast<unsigned>(b) : static_cast<unsigned>(b);
  while(y != 0){
    unsigned r = x % y;
    x = y;
    y = r;
  }
  return static_cast<int>(x);
}

// Computes (and memoizes in starting_multiple_) the "starting multiple" of
// `v`. Returns the cached value when one exists.
unsigned alignment_info::populate_starting_multiple(ir::value *v){
  if(starting_multiple_.find(v) != starting_multiple_.end())
    return starting_multiple_.at(v);
  auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
  // has metadata
  if(auto *x = dynamic_cast(v)){
    unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(multiple_of > 0)
      return cache(multiple_of);
  }
  // arguments
  if(auto *x = dynamic_cast(v)){
    std::set attributes = x->get_parent()->get_attributes(x);
    for(auto attr: attributes){
      if(attr.get_kind() == ir::multiple_of)
        return cache(attr.get_value());
      if(attr.get_kind() == ir::aligned){
        ir::type* ty = x->get_type()->get_pointer_element_ty();
        int nbits = ty->get_primitive_size_in_bits();
        // clamp to 1 so an element type narrower than 8 bits cannot cause a
        // division by zero
        int nbytes = std::max(nbits / 8, 1);
        return cache(attr.get_value() / nbytes);
      }
    }
  }
  // Two-operand arithmetic node: combine the operands' multiples according
  // to the operation. Unhandled operations fall through.
  if(auto *x = dynamic_cast(v)){
    int lhs = populate_starting_multiple(x->get_operand(0));
    int rhs = populate_starting_multiple(x->get_operand(1));
    if(x->is_int_mult())
      return cache(lhs * rhs);
    if(x->is_int_add_sub())
      return cache(gcd(lhs, rhs));
    // rhs can be 0 (e.g. an operand whose cached multiple is 0); fall back to
    // the conservative multiple 1 instead of dividing by zero
    if(x->is_int_div())
      return cache(rhs == 0 ? 1 : std::max(lhs / rhs, 1));
    if(x->is_int_rem())
      return cache(rhs == 0 ? 1 : std::max(lhs % rhs, 1));
    // NOTE(review): rhs is not range-checked; a shift count >= the bit width
    // of int is undefined behavior.
    if(x->is_shl())
      return cache(lhs << rhs);
    if(x->is_shr())
      return cache(std::max(lhs >> rhs, 1));
  }
  // Value exposing get_value(): the multiple is the value itself.
  if(auto *x = dynamic_cast(v))
    return cache(x->get_value());
  // Value exposing get_first(): the multiple is the value of its first element.
  if(auto *x = dynamic_cast(v)){
    return cache(x->get_first()->get_value());
  }
  // Another two-operand node (cast target not visible here): gcd of both.
  if(auto *x = dynamic_cast(v)){
    int lhs = populate_starting_multiple(x->get_operand(0));
    int rhs = populate_starting_multiple(x->get_operand(1));
    return cache(gcd(lhs, rhs));
  }
  // Single-operand node: inherits the operand's multiple.
  if(auto *x = dynamic_cast(v)){
    int op = populate_starting_multiple(x->get_operand(0));
    return cache(op);
  }
  // Kind of value whose multiple is the size of its first tile axis.
  if(auto *x = dynamic_cast(v)){
    return cache(v->get_type()->get_tile_shapes()[0]->get_value());
  }
  // Node exposing get_value_true()/get_value_false(): gcd of both alternatives.
  if(auto *x = dynamic_cast(v)){
    int value_true = populate_starting_multiple(x->get_value_true());
    int value_false = populate_starting_multiple(x->get_value_false());
    return cache(gcd(value_true, value_false));
  }
  // Node with incoming values (presumably a phi node - confirm cast target).
  if(auto *x = dynamic_cast(v)){
    // put a conservative initial value in phi node to avoid infinite recursion
    unsigned result = 1;
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      // the last already-cached incoming value determines the seed
      if(starting_multiple_.find(inc) != starting_multiple_.end())
        result = starting_multiple_.at(inc);
    }
    // seed the cache before recursing so that a cycle back to `v` terminates
    cache(result);
    // recurse
    for(unsigned n = 0; n < x->get_num_incoming(); n++){
      ir::value* inc = x->get_incoming_value(n);
      result = gcd(result, populate_starting_multiple(inc));
    }
    return cache(result);
  }
  // scalars
  if(!v->get_type()->is_tile_ty())
    return cache(1);
  // tiles
  // product of every shape entry except the last; `i + 1 < size()` avoids the
  // unsigned wrap-around of `size() - 1` on an empty shape list
  auto shapes = v->get_type()->get_tile_shapes();
  unsigned result = 1;
  for(unsigned i = 0; i + 1 < shapes.size(); i++)
    result *= shapes[i]->get_value();
  return cache(result);
}

// Returns the starting multiple recorded for `v`. Uses at(), so `v` must
// already have an entry in starting_multiple_.
unsigned alignment_info::get_starting_multiple(ir::value* v) const {
  return starting_multiple_.at(v);
}

// Returns the max-contiguous value recorded for `v`. Uses at(), so `v` must
// already have an entry in max_contiguous_.
unsigned alignment_info::get_max_contiguous(ir::value* v) const {
  return max_contiguous_.at(v);
}

///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
// Runs the three analyses over every instruction of every basic block of
// every function in `mod`, one full sweep per analysis, in the order:
// is-constant, starting multiple, max contiguous.
void alignment_info::run(ir::module &mod) {
  // populate constant
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()){
    populate_is_constant(i);
  }

  // populate starting multiple
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()){
    populate_starting_multiple(i);
  }

  // populate maximum contiguous
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *i: block->get_inst_list()){
    populate_max_contiguous(i);
    // std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
  }
}

}
}