[CODEGEN] Fixed bug in recoalesce_inst LLVM codegen
This commit is contained in:
@@ -98,21 +98,35 @@ data_layout::data_layout(id_t id,
|
||||
extract_io_use(v, ptr);
|
||||
order_.resize(axes_.size());
|
||||
std::iota(order_.begin(), order_.end(), 0);
|
||||
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
||||
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||
return xx < yy;
|
||||
});
|
||||
if(*largest){
|
||||
auto max_contiguous = align->contiguous(*largest);
|
||||
std::vector<unsigned> max_contiguous;
|
||||
for(ir::value* p: ptr){
|
||||
std::vector<unsigned> curr = align->contiguous(p);
|
||||
if(curr.size() > max_contiguous.size())
|
||||
max_contiguous = curr;
|
||||
else if(curr.size() == max_contiguous.size()){
|
||||
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
|
||||
max_contiguous = curr;
|
||||
}
|
||||
}
|
||||
bool is_recoalesce = false;
|
||||
for(ir::value* v: values)
|
||||
is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
|
||||
if(max_contiguous.size() > 0){
|
||||
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b];
|
||||
});
|
||||
// std::cout << "===" << std::endl;
|
||||
// std::cout << (*largest)->get_name() << std::endl;
|
||||
// for(ir::value* x: ptr)
|
||||
// std::cout << x->get_name() << std::endl;
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// std::cout << order_[0] << " " << order_[1] << std::endl;
|
||||
}
|
||||
if(is_recoalesce){
|
||||
if(ptr.size() > 0){
|
||||
// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// if(order_[0] == 0)
|
||||
// exit(1);
|
||||
}
|
||||
}
|
||||
// std::cout << "---" << std::endl;
|
||||
}
|
||||
|
||||
int data_layout::find_axis(int to_find) const {
|
||||
@@ -136,16 +150,16 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
if(tgt->as_nvidia()->sm() < 80){
|
||||
fpw_ = {1, 1, 1};
|
||||
std::vector<int> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw_;
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
}while(fpw_nm1 != fpw_);
|
||||
fpw_ = {2, 2, 1};
|
||||
// std::vector<int> fpw_nm1;
|
||||
// unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
// do {
|
||||
// fpw_nm1 = fpw_;
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
// }while(fpw_nm1 != fpw_);
|
||||
auto ord_a = layout_a->get_order();
|
||||
auto ord_b = layout_b->get_order();
|
||||
bool is_a_row = ord_a[0] != 0;
|
||||
@@ -154,8 +168,9 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
|
||||
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
|
||||
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
|
||||
pack_size_0 = pack_size_1 = 1;
|
||||
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
|
||||
}
|
||||
else{
|
||||
fpw_ = {1, 1, 1};
|
||||
|
@@ -1386,29 +1386,29 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values;
|
||||
auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values;
|
||||
auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values;
|
||||
int in_outer = in_layout->spt(ord[1]);
|
||||
int in_rep = in_layout->rep(ord[1]);
|
||||
int out_outer = out_layout->mts(ord[1]) * out_layout->nts(ord[1]);
|
||||
int max_outer = std::max(in_outer, out_outer);
|
||||
int out_ratio = std::max(out_outer/in_outer, 1);
|
||||
int in_ratio = std::max(in_outer/out_outer, 1);
|
||||
int in_spt0 = in_layout->spt(ord[0]);
|
||||
int in_spt1 = in_layout->spt(ord[1]);
|
||||
int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]);
|
||||
int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]);
|
||||
int max_spt1 = std::max(in_spt1, out_spt1);
|
||||
indices_t idx(2);
|
||||
for(size_t j = 0; j < shape[ord[1]]/max_outer; j++){
|
||||
int num_packs = shape[ord[1]]/max_spt1;
|
||||
for(size_t j = 0; j < num_packs; j++){
|
||||
add_barrier();
|
||||
for(size_t k = 0; k < in_rep*out_ratio; k++)
|
||||
for(size_t k = 0; k < in_ord1.size()/num_packs; k++)
|
||||
for(size_t i = 0; i < in_ord0.size(); i++){
|
||||
idx[ord[0]] = in_ord0[i];
|
||||
idx[ord[1]] = in_ord1[j*in_rep*out_ratio + k];
|
||||
idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k];
|
||||
Value *off = add(idx[ord[0]], mul(in_ord1[k], ld));
|
||||
Value *ptr = gep(base, off);
|
||||
store(vals_[op][idx], ptr);
|
||||
}
|
||||
add_barrier();
|
||||
for(size_t k = 0; k < in_ratio; k++)
|
||||
for(size_t k = 0; k < out_ord1.size()/num_packs; k++)
|
||||
for(size_t i = 0; i < out_ord0.size(); i++){
|
||||
idx[ord[0]] = out_ord0[i];
|
||||
idx[ord[1]] = out_ord1[j*in_ratio + k];
|
||||
Value *off = add(out_ord0[i], mul(out_ord1[k], ld));
|
||||
idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k];
|
||||
Value *off = add(idx[ord[0]], mul(out_ord1[k], ld));
|
||||
Value *ptr = gep(base, off);
|
||||
vals_[rc][idx] = load(ptr);
|
||||
}
|
||||
|
@@ -241,6 +241,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
layouts.run(module);
|
||||
swizzle.run(module);
|
||||
liveness.run(module);
|
||||
@@ -248,7 +249,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
|
||||
if(allocation.allocated_size() > device->max_shared_memory())
|
||||
throw exception::out_of_shared_memory();
|
||||
barriers.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
isel.visit(module, *llvm);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
|
||||
// if(res->spilled() > 256)
|
||||
|
Reference in New Issue
Block a user