diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 989963206..550061975 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -98,21 +98,35 @@ data_layout::data_layout(id_t id, extract_io_use(v, ptr); order_.resize(axes_.size()); std::iota(order_.begin(), order_.end(), 0); - auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){ - std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; - std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; - return xx < yy; - }); - if(*largest){ - auto max_contiguous = align->contiguous(*largest); + std::vector max_contiguous; + for(ir::value* p: ptr){ + std::vector curr = align->contiguous(p); + if(curr.size() > max_contiguous.size()) + max_contiguous = curr; + else if(curr.size() == max_contiguous.size()){ + if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end())) + max_contiguous = curr; + } + } + bool is_recoalesce = false; + for(ir::value* v: values) + is_recoalesce = is_recoalesce || dynamic_cast(v); + if(max_contiguous.size() > 0){ std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; }); -// std::cout << "===" << std::endl; -// std::cout << (*largest)->get_name() << std::endl; -// for(ir::value* x: ptr) -// std::cout << x->get_name() << std::endl; +// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl; +// std::cout << order_[0] << " " << order_[1] << std::endl; } + if(is_recoalesce){ + if(ptr.size() > 0){ +// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl; +// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl; +// if(order_[0] == 0) +// exit(1); + } + } +// std::cout << "---" << std::endl; } int data_layout::find_axis(int to_find) const { @@ -136,16 +150,16 @@ mma_layout::mma_layout(size_t num_warps, /* fragments per warp */ // try to make things as square as possible to maximize data re-use if(tgt->as_nvidia()->sm() < 80){ - fpw_ = {1, 1, 1}; - std::vector fpw_nm1; - unsigned num_fragments = std::min((shape_[0]/8)*(shape_[1]/8), 4); - do { - fpw_nm1 = fpw_; - if(fpw_[0]*fpw_[1] < num_fragments) - fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); - if(fpw_[0]*fpw_[1] < num_fragments) - fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); - }while(fpw_nm1 != fpw_); + fpw_ = {2, 2, 1}; +// std::vector fpw_nm1; +// unsigned num_fragments = std::min((shape_[0]/8)*(shape_[1]/8), 4); +// do { +// fpw_nm1 = fpw_; +// if(fpw_[0]*fpw_[1] < num_fragments) +// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); +// if(fpw_[0]*fpw_[1] < num_fragments) +// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); +// }while(fpw_nm1 != fpw_); auto ord_a = layout_a->get_order(); auto ord_b = layout_b->get_order(); bool is_a_row = ord_a[0] != 0; @@ -154,8 +168,9 @@ mma_layout::mma_layout(size_t num_warps, bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16); int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2; int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; + pack_size_0 = pack_size_1 = 1; rep_ = {2*pack_size_0, 2*pack_size_1, 1}; - spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1}; + spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; } else{ fpw_ = {1, 1, 1}; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 939ba25d1..8c900df70 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1386,29 +1386,29 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values; auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values; auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values; - int in_outer = in_layout->spt(ord[1]); - int in_rep = in_layout->rep(ord[1]); - int out_outer = out_layout->mts(ord[1]) * out_layout->nts(ord[1]); - int max_outer = std::max(in_outer, out_outer); - int out_ratio = std::max(out_outer/in_outer, 1); - int in_ratio = std::max(in_outer/out_outer, 1); + int in_spt0 = in_layout->spt(ord[0]); + int in_spt1 = in_layout->spt(ord[1]); + int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]); + int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]); + int max_spt1 = std::max(in_spt1, out_spt1); indices_t idx(2); - for(size_t j = 0; j < shape[ord[1]]/max_outer; j++){ + int num_packs = shape[ord[1]]/max_spt1; + for(size_t j = 0; j < num_packs; j++){ add_barrier(); - for(size_t k = 0; k < in_rep*out_ratio; k++) + for(size_t k = 0; k < in_ord1.size()/num_packs; k++) for(size_t i = 0; i < in_ord0.size(); i++){ idx[ord[0]] = in_ord0[i]; - idx[ord[1]] = in_ord1[j*in_rep*out_ratio + k]; + idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k]; Value *off = add(idx[ord[0]], mul(in_ord1[k], ld)); Value *ptr = gep(base, off); store(vals_[op][idx], ptr); } add_barrier(); - for(size_t k = 0; k < in_ratio; k++) + for(size_t k = 0; k < out_ord1.size()/num_packs; k++) for(size_t i = 0; i < out_ord0.size(); i++){ idx[ord[0]] = out_ord0[i]; - idx[ord[1]] = out_ord1[j*in_ratio + k]; - Value *off = add(out_ord0[i], mul(out_ord1[k], ld)); + idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k]; + Value *off = add(idx[ord[0]], mul(out_ord1[k], ld)); Value *ptr = gep(base, off); vals_[rc][idx] = load(ptr); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 3fcf74e66..c6e4a5032 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -241,6 +241,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::d dce.run(module); align.run(module); axes.run(module); +// ir::print(module, std::cout); layouts.run(module); swizzle.run(module); liveness.run(module); @@ -248,7 +249,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::d if(allocation.allocated_size() > device->max_shared_memory()) throw exception::out_of_shared_memory(); barriers.run(module); -// ir::print(module, std::cout); isel.visit(module, *llvm); std::unique_ptr res(driver::module::create(device, std::move(llvm))); // if(res->spilled() > 256)