[CODEGEN] Fixed bug in recoalesce_inst LLVM codegen
This commit is contained in:
@@ -98,21 +98,35 @@ data_layout::data_layout(id_t id,
|
||||
extract_io_use(v, ptr);
|
||||
order_.resize(axes_.size());
|
||||
std::iota(order_.begin(), order_.end(), 0);
|
||||
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
||||
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||
return xx < yy;
|
||||
});
|
||||
if(*largest){
|
||||
auto max_contiguous = align->contiguous(*largest);
|
||||
std::vector<unsigned> max_contiguous;
|
||||
for(ir::value* p: ptr){
|
||||
std::vector<unsigned> curr = align->contiguous(p);
|
||||
if(curr.size() > max_contiguous.size())
|
||||
max_contiguous = curr;
|
||||
else if(curr.size() == max_contiguous.size()){
|
||||
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
|
||||
max_contiguous = curr;
|
||||
}
|
||||
}
|
||||
bool is_recoalesce = false;
|
||||
for(ir::value* v: values)
|
||||
is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
|
||||
if(max_contiguous.size() > 0){
|
||||
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b];
|
||||
});
|
||||
// std::cout << "===" << std::endl;
|
||||
// std::cout << (*largest)->get_name() << std::endl;
|
||||
// for(ir::value* x: ptr)
|
||||
// std::cout << x->get_name() << std::endl;
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// std::cout << order_[0] << " " << order_[1] << std::endl;
|
||||
}
|
||||
if(is_recoalesce){
|
||||
if(ptr.size() > 0){
|
||||
// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// if(order_[0] == 0)
|
||||
// exit(1);
|
||||
}
|
||||
}
|
||||
// std::cout << "---" << std::endl;
|
||||
}
|
||||
|
||||
int data_layout::find_axis(int to_find) const {
|
||||
@@ -136,16 +150,16 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
if(tgt->as_nvidia()->sm() < 80){
|
||||
fpw_ = {1, 1, 1};
|
||||
std::vector<int> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw_;
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
}while(fpw_nm1 != fpw_);
|
||||
fpw_ = {2, 2, 1};
|
||||
// std::vector<int> fpw_nm1;
|
||||
// unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
// do {
|
||||
// fpw_nm1 = fpw_;
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
// }while(fpw_nm1 != fpw_);
|
||||
auto ord_a = layout_a->get_order();
|
||||
auto ord_b = layout_b->get_order();
|
||||
bool is_a_row = ord_a[0] != 0;
|
||||
@@ -154,8 +168,9 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
|
||||
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
|
||||
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
|
||||
pack_size_0 = pack_size_1 = 1;
|
||||
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
|
||||
}
|
||||
else{
|
||||
fpw_ = {1, 1, 1};
|
||||
|
Reference in New Issue
Block a user