[CODEGEN] Various bugfixes and stability improvements in compiler backend (#240)
This commit is contained in:
@@ -586,7 +586,7 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_block_ty()){
|
||||
if(op->get_type()->is_block_ty() && op->get_type()->get_tile_rank() > 1){
|
||||
auto ord = ords_.at(op);
|
||||
size_t aln = alignment_->get(op, ord[0]);
|
||||
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
|
||||
@@ -626,10 +626,10 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
// -----
|
||||
std::ostringstream asm_oss;
|
||||
asm_oss << "@$" << n_words; // predicate
|
||||
if(force_nc_cache_)
|
||||
asm_oss << " ld.global.nc";
|
||||
else
|
||||
asm_oss << " ld.global.cg";
|
||||
// if(force_nc_cache_)
|
||||
asm_oss << " ld.global";
|
||||
// else
|
||||
// asm_oss << " ld.global.cg";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
@@ -1058,7 +1058,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
/* --------------------------------- */
|
||||
BasicBlock* curr_bb = builder_->GetInsertBlock();
|
||||
BasicBlock* entry = &curr_bb->getParent()->getEntryBlock();
|
||||
builder_->SetInsertPoint(entry->getTerminator());
|
||||
if(entry != curr_bb)
|
||||
builder_->SetInsertPoint(entry->getTerminator());
|
||||
Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c];
|
||||
Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c];
|
||||
Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a));
|
||||
@@ -1116,8 +1117,8 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
for(indices_t idx: idxs_.at(C))
|
||||
acc.push_back(vals_[D][idx]);
|
||||
|
||||
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
|
||||
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
|
||||
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0);
|
||||
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1);
|
||||
|
||||
// create mma & unpack result
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
|
||||
@@ -1333,7 +1334,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
if(FirstBB != CurrBB)
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *lane = urem(thread, i32(32));
|
||||
@@ -1396,8 +1398,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
"{$10, $11, $12, $13};",
|
||||
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
|
||||
|
||||
unsigned num_rep_0 = shapes[0] / layout->spt(0);
|
||||
unsigned num_rep_1 = shapes[1] / layout->spt(1);
|
||||
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
|
||||
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
|
||||
|
||||
// create mma & unpack result
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
|
||||
@@ -1626,8 +1628,8 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
|
||||
std::map<std::pair<int, int>, Value*> has, hbs;
|
||||
for(unsigned k = 0; k < NK; k++){
|
||||
int z = 0;
|
||||
for(unsigned m = 0; m < shape_c[0]; m+=layout_c->mts(0)*layout_c->nts(0))
|
||||
for(unsigned n = 0; n < shape_c[1]; n+=layout_c->mts(1)*layout_c->nts(1))
|
||||
for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
|
||||
for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
|
||||
for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
|
||||
for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
|
||||
{
|
||||
@@ -1818,6 +1820,7 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
add_barrier();
|
||||
// update accumulator
|
||||
acc = do_acc(acc, load(read_ptr));
|
||||
add_barrier();
|
||||
store(acc, write_ptr);
|
||||
}
|
||||
}
|
||||
@@ -1884,54 +1887,74 @@ void generator::visit_select_inst(ir::select_inst* x) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `recoalesce`
|
||||
*/
|
||||
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
ir::value *op = rc->get_operand(0);
|
||||
ir::block_type::block_shapes_t shape = rc->get_type()->get_block_shapes();
|
||||
|
||||
|
||||
void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
||||
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = cvt(rc->get_type()->get_scalar_ty());
|
||||
// layout
|
||||
analysis::mma_layout* in_layout = layouts_->get(op)->to_mma();
|
||||
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
|
||||
Type *ty = cvt(out->get_type()->get_scalar_ty());
|
||||
// Orders
|
||||
auto ord = layouts_->get(rc)->to_scanline()->get_order();
|
||||
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
||||
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
||||
auto in_ord = in_layout->get_order();
|
||||
auto out_ord = out_layout->get_order();
|
||||
Value *base;
|
||||
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(rc)))));
|
||||
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
|
||||
base = bit_cast(base, ptr_ty(ty, 3));
|
||||
Value *ld = i32(shape[ord[0]]);
|
||||
auto in_ord0 = axes_.at(a_axes_->get(op, ord[0])).values;
|
||||
auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values;
|
||||
auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values;
|
||||
auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values;
|
||||
int in_spt0 = in_layout->spt(ord[0]);
|
||||
int in_spt1 = in_layout->spt(ord[1]);
|
||||
int out_spt0 = out_layout->mts(ord[0])*out_layout->nts(ord[0]);
|
||||
int out_spt1 = out_layout->mts(ord[1])*out_layout->nts(ord[1]);
|
||||
int max_spt1 = std::max(in_spt1, out_spt1);
|
||||
indices_t idx(2);
|
||||
int num_packs = shape[ord[1]]/max_spt1;
|
||||
for(size_t j = 0; j < num_packs; j++){
|
||||
add_barrier();
|
||||
for(size_t k = 0; k < in_ord1.size()/num_packs; k++)
|
||||
for(size_t i = 0; i < in_ord0.size(); i++){
|
||||
idx[ord[0]] = in_ord0[i];
|
||||
idx[ord[1]] = in_ord1[j*in_ord1.size()/num_packs + k];
|
||||
Value *off = add(idx[ord[0]], mul(in_ord1[k], ld));
|
||||
Value *ptr = gep(base, off);
|
||||
store(vals_[op][idx], ptr);
|
||||
}
|
||||
add_barrier();
|
||||
for(size_t k = 0; k < out_ord1.size()/num_packs; k++)
|
||||
for(size_t i = 0; i < out_ord0.size(); i++){
|
||||
idx[ord[0]] = out_ord0[i];
|
||||
idx[ord[1]] = out_ord1[j*out_ord1.size()/num_packs + k];
|
||||
Value *off = add(idx[ord[0]], mul(out_ord1[k], ld));
|
||||
Value *ptr = gep(base, off);
|
||||
vals_[rc][idx] = load(ptr);
|
||||
}
|
||||
std::vector<int> n_reps;
|
||||
for(int i = 0; i < shape.size(); i++){
|
||||
int in_per_cta = in_layout->shape_per_cta(i);
|
||||
int out_per_cta = out_layout->shape_per_cta(i);
|
||||
int max_per_cta = std::max(in_per_cta, out_per_cta);
|
||||
n_reps.push_back(shape[i]/max_per_cta);
|
||||
}
|
||||
std::vector<std::vector<Value*>> in_ax;
|
||||
std::vector<std::vector<Value*>> out_ax;
|
||||
for(int d = 0; d < shape.size(); d++){
|
||||
in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
|
||||
out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
|
||||
}
|
||||
in_ord = in_layout->to_mma() ? out_ord : in_ord;
|
||||
out_ord = out_layout->to_mma() ? in_ord : out_ord;
|
||||
Value *in_ld = i32(shape[in_ord[0]]);
|
||||
Value *out_ld = i32(shape[out_ord[0]]);
|
||||
for(int i = 0; i < n_reps[0]; i++)
|
||||
for(int j = 0; j < n_reps[1]; j++){
|
||||
int max_ii, max_jj;
|
||||
add_barrier();
|
||||
max_ii = in_ax[0].size()/n_reps[0];
|
||||
max_jj = in_ax[1].size()/n_reps[1];
|
||||
for(int ii = 0; ii < max_ii; ii++)
|
||||
for(int jj = 0; jj < max_jj; jj++){
|
||||
// shared mem pointer
|
||||
indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
|
||||
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
|
||||
Value *ptr = gep(base, off);
|
||||
// stash value to shared mem
|
||||
indices_t idxs = {in_ax[0][i*max_ii + ii],
|
||||
in_ax[1][j*max_jj + jj]};
|
||||
store(vals_[in][idxs], ptr);
|
||||
}
|
||||
add_barrier();
|
||||
max_ii = out_ax[0].size()/n_reps[0];
|
||||
max_jj = out_ax[1].size()/n_reps[1];
|
||||
for(int ii = 0; ii < max_ii; ii++)
|
||||
for(int jj = 0; jj < max_jj; jj++){
|
||||
// shared mem pointer
|
||||
indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
|
||||
Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
|
||||
Value *ptr = gep(base, off);
|
||||
// load value from shared mem
|
||||
indices_t idxs = {out_ax[0][i*max_ii + ii],
|
||||
out_ax[1][j*max_jj + jj]};
|
||||
vals_[out][idxs] = load(ptr);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * \brief Code Generation for `cvt_layout`
 *
 * Thin wrapper: passes the instruction itself (as the output value) and its
 * first operand (as the input value) to visit_layout_convert.
 */
void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
|
||||
visit_layout_convert(rc, rc->get_operand(0));
|
||||
}
|
||||
|
||||
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
@@ -2325,12 +2348,12 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
offset_b_k_[layout] = and_(lane, _3);
|
||||
// i indices
|
||||
Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]);
|
||||
for(unsigned m = 0; m < shape[0]; m+=layout->spt(0))
|
||||
for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0))
|
||||
for(unsigned mm = 0; mm < layout->rep(0); mm++)
|
||||
idx_m.push_back(add(offset_c_m, i32(m + mm*2)));
|
||||
// j indices
|
||||
Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n));
|
||||
for(unsigned n = 0; n < shape[1]; n+=layout->spt(1))
|
||||
for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1))
|
||||
for(unsigned nn = 0; nn < layout->rep(1); nn++){
|
||||
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1))));
|
||||
idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1)));
|
||||
@@ -2366,11 +2389,11 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
// c offset
|
||||
Value *off_c_m = add(udiv(lane, _4), off_warp_m);
|
||||
Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n);
|
||||
for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)){
|
||||
for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){
|
||||
idx_m.push_back(add(off_c_m, i32(m)));
|
||||
idx_m.push_back(add(off_c_m, i32(m + 8)));
|
||||
}
|
||||
for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)){
|
||||
for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){
|
||||
idx_n.push_back(add(off_c_n, i32(n)));
|
||||
idx_n.push_back(add(off_c_n, i32(n + 1)));
|
||||
}
|
||||
@@ -2406,11 +2429,11 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = i32(nts);
|
||||
Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
|
||||
unsigned per_block = nts * mts;
|
||||
unsigned per_thread = nts * shape[k] / per_block;
|
||||
unsigned per_cta = layout->shape_per_cta(k);
|
||||
unsigned per_thread = nts * shape[k] / per_cta;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts * per_block + n % nts;
|
||||
unsigned offset = n / nts * per_cta + n % nts;
|
||||
idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
|
||||
|
Reference in New Issue
Block a user