[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for Ampere GPUs:
* Make the layout pass recognize the multistage pipelined pattern.
* The pipeline pass can now automate the multistage pipelining transformation.
* Remove extra barriers (from the prefetch pass & WAR) on Ampere.
* Update the code generator (generator.cc) so that Triton generates n-buffered shared memory loads/stores.
committed by Philippe Tillet
parent 5a51f3e529
commit d8d6b715c8
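For context, "n-buffered" here generalizes double buffering: with N shared-memory stages, the prologue issues N-1 tile loads, and each steady-state iteration overlaps the math on one stage with the load of a later tile. A runnable toy of that schedule follows; it is a sketch only, and all names and sizes are illustrative, not Triton's API.

    // Toy model of N-stage software pipelining; `stages` and `tiles` are made up.
    #include <cstdio>

    int main() {
      const int stages = 3, tiles = 8;
      // prologue: fill buffers 0..N-2 before any compute runs
      for (int s = 0; s < stages - 1; ++s)
        std::printf("prologue: load tile %d -> buffer %d\n", s, s);
      // steady state: compute on one buffer while refilling another
      for (int i = 0; i < tiles; ++i) {
        if (i + stages - 1 < tiles)   // refill for tile i+N-1, N-1 iterations ahead
          std::printf("iter %d: load tile %d -> buffer %d\n",
                      i, i + stages - 1, (i + stages - 1) % stages);
        std::printf("iter %d: mma on buffer %d\n", i, i % stages);
      }
    }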
@@ -212,18 +212,41 @@ void generator::visit_value(ir::value* v) {
     return;
   if(v->get_type()->is_block_ty()){
     if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
-      auto double_buffer = layout->get_double_buffer();
+      analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
+      analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();

       // offset
       Value *offset = nullptr;
-      if(double_buffer && v == double_buffer->phi)
-        offset = shared_off_[layout];
       // base pointer
       Value *ptr = shared_ptr_[layout];
-      if(double_buffer && v == double_buffer->latch)
-        ptr = shared_next_ptr_[layout];
-      else if(double_buffer && v == double_buffer->first)
-        ptr = shared_pre_ptr_[layout];
+
+      if (n_buffer) {
+        // ptr = base (shared_ptr_[layout]) + smem_idx * size
+        // read_smem_idx
+        if (v == n_buffer->phi) {
+          ptr = shared_ptr_[layout];
+        }
+        // write_smem_idx
+        if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
+          int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
+          int elements = write_smem_idx * layout->get_per_stage_elements();
+          ptr = gep(shared_pre_ptr_[layout], i32(elements));
+        } else if (v == n_buffer->latch) {
+          Value* write_smem_idx = write_smem_idx_[layout];
+          Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
+          ptr = gep(shared_pre_ptr_[layout], elements);
+        }
+      } else if (double_buffer) {
+        if(v == double_buffer->phi)
+          offset = shared_off_[layout];
+        if(v == double_buffer->latch)
+          ptr = shared_next_ptr_[layout];
+        else if(v == double_buffer->first)
+          ptr = shared_pre_ptr_[layout];
+      } // else do nothing
+      // what visit_dot & visit_cts & ... see
       shmems_[v] = ptr;
+      // now only latches have offset (PHINode), only used by finalize_shared_layout()
       shoffs_[v] = offset;
     }
   }
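The n-buffer branch above distinguishes two write-pointer cases: the "firsts" (prologue stores, one per stage) get a compile-time stage index, while the "latch" (the steady-state store) indexes with the runtime write_smem_idx PHI. A minimal sketch of that address math, with illustrative sizes standing in for get_per_stage_elements():

    #include <cstdio>

    // mirrors ptr = gep(shared_pre_ptr_[layout], smem_idx * per_stage_elements)
    float* stage_ptr(float* base, int smem_idx, int per_stage_elements) {
      return base + smem_idx * per_stage_elements;
    }

    int main() {
      const int stages = 3, per_stage = 4;
      float smem[stages * per_stage] = {};
      for (int s = 0; s < stages - 1; ++s)   // "firsts": constant stage index per store
        std::printf("prologue store at offset %td\n", stage_ptr(smem, s, per_stage) - smem);
      int write_smem_idx = stages - 1;       // "latch": rotating runtime index
      std::printf("loop store at offset %td\n", stage_ptr(smem, write_smem_idx, per_stage) - smem);
    }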
@@ -1223,24 +1246,21 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
  * \brief Code Generation for `mma.16816` (A100)
  */
 //TODO: clean-up
-void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
-  const auto& shapes = dot->get_type()->get_block_shapes();
+void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
+  const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
   std::map<std::vector<Value*>, std::vector<Value*>> fcs;

-  for(indices_t idx: idxs_.at(dot)){
+  for(indices_t idx: idxs_.at(C)){
     std::vector<Value*> key(idx.size() - 2);
     std::copy(idx.begin() + 2, idx.end(), key.begin());
     fcs[key].push_back(vals_[D][idx]);
   };

   auto shape_a = A->get_type()->get_block_shapes();
   auto shape_b = B->get_type()->get_block_shapes();
   auto ord_a = layouts_->get(A)->get_order();
   auto ord_b = layouts_->get(B)->get_order();
-  analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
-  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0));
-  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1));
+  analysis::mma_layout* layout = layouts_->get(C)->to_mma();
+  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
+  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
   bool is_a_row = ord_a[0] == 1;
   bool is_b_row = ord_b[0] == 1;
   std::string a_trans = is_a_row ? "" : ".trans";
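The fcs map above groups accumulator fragments: each index tuple is keyed by its tail (everything past the first two dimensions), so all m/n fragments belonging to one output slice land in one vector. A stand-in using plain ints (the real keys hold LLVM Value*):

    #include <cstdio>
    #include <map>
    #include <vector>

    int main() {
      using Key = std::vector<int>;
      std::map<Key, std::vector<float>> fcs;
      // four index tuples; the third element plays the role of idx[2:]
      std::vector<std::vector<int>> idxs = {{0,0,7}, {0,1,7}, {1,0,9}, {1,1,9}};
      for (auto& idx : idxs) {
        Key key(idx.begin() + 2, idx.end());   // drop the first two (m/n) dims
        fcs[key].push_back(0.0f);              // one accumulator slot per index
      }
      for (auto& [k, v] : fcs)
        std::printf("key tail %d -> %zu accumulators\n", k[0], v.size());
    }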
@@ -1264,8 +1284,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
   int vec_a = 8;
   int vec_b = 8;
-
-
   Type *fp32_ty = f32_ty;
   Type *fp16x2_ty = vec_ty(f16_ty, 2);
   Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
@@ -1276,7 +1294,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
   std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
   std::map<std::pair<unsigned, unsigned>, Value*> hb;
-

   BasicBlock* CurrBB = builder_->GetInsertBlock();
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   builder_->SetInsertPoint(FirstBB->getTerminator());
@@ -1339,66 +1356,167 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
                                    "{$0, $1, $2, $3}, "
                                    "{$4, $5, $6, $7}, "
                                    "{$8, $9}, "
-                                   "{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false);
+                                   "{$10, $11, $12, $13};",
+                                   "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);

   unsigned num_rep_0 = shapes[0] / layout->spt(0);
   unsigned num_rep_1 = shapes[1] / layout->spt(1);
-  for(unsigned K = 0; K < NK; K += 16)
-  for(unsigned m = 0; m < num_rep_0; m++)
-  for(unsigned n = 0; n < num_rep_1; n++){
-    if(ha.find({m, K}) == ha.end()){
-      Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a];
+
+  // create mma & unpack result
+  auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
+    unsigned cols_per_thread = num_rep_0 * 2;
+    std::vector<size_t> idx = {
+      (m*2 + 0) + (n*2 + 0)*cols_per_thread,
+      (m*2 + 0) + (n*2 + 1)*cols_per_thread,
+      (m*2 + 1) + (n*2 + 0)*cols_per_thread,
+      (m*2 + 1) + (n*2 + 1)*cols_per_thread
+    };
+    Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second, ha[{m, K+8}].first, ha[{m, K+8}].second,
+                                      hb[{n, K}], hb[{n, K+8}],
+                                      fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
+    fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
+    fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
+    fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
+    fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
+  };
+
+  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
+  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
+
+  auto register_lds =
+    [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
+      if (K <= 8 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
+      } else
+        vals[{m, K}] = {val0, val1};
+    };
+
+  auto register_lds2 =
+    [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
+      if (K <= 8 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
+      } else
+        vals[{m, K}] = val;
+    };
+
+  auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
+    int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
+    Value* ptra;
+    if(K == 0 && is_prefetch){
+      if(inc == 0)
+        ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
+      else
+        ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
+    }
+    else
+      ptra = ptrs_a[offidx];
     int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
     int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
-    InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
-                                                   "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false);
+    InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
+                                                   "{$0, $1, $2, $3}, [$4 + " +
+                                                   std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
+                                                   "=r,=r,=r,=r,r", true);
     Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
+    if(K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
     Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
     Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
     Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
     Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
-      ha[{m, K}] = std::make_pair(ha0, ha1);
-      ha[{m, K+8}] = std::make_pair(ha2, ha3);
-    }
-    if(hb.find({n, K})==hb.end()){
-      Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b];
+    register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
+    register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
+  };
+
+  auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
+    int offidx = (is_b_row ? n : K/16) % num_ptr_b;
+    Value* ptrb;
+    if(K == 0 && is_prefetch){
+      if(inc == 0)
+        ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
+      else
+        ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
+    }
+    else
+      ptrb = ptrs_b[offidx];
     int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
     int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
-    InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
-                                                  "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false);
+    InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
                                                  "{$0, $1, $2, $3}, [$4 + " +
+                                                  std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
+                                                  "=r,=r,=r,=r,r", true);
     Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
+    if(K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
     Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
     Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
     Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
     Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
-      hb[{n, K}] = hb0;
-      hb[{n+1, K}] = hb2;
-      hb[{n, K+8}] = hb1;
-      hb[{n+1, K+8}] = hb3;
-    }
-    unsigned cols_per_thread = num_rep_0 * 2;
-    std::vector<size_t> idx = {
-      (m*2 + 0) + (n*2 + 0)*cols_per_thread,
-      (m*2 + 0) + (n*2 + 1)*cols_per_thread,
-      (m*2 + 1) + (n*2 + 0)*cols_per_thread,
-      (m*2 + 1) + (n*2 + 1)*cols_per_thread
-    };
-    Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second, ha[{m, K+8}].first, ha[{m, K+8}].second,
-                                      hb[{n, K}], hb[{n, K+8}],
-                                      fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
-    fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
-    fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
-    fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
-    fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
-  }
+    register_lds2(hb, n, K, inc, hb0, is_prefetch);
+    register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
+    register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
+    register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
+  };
+
+  if (C->is_prefetched()) {
+    // create phis
+    builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
+    for(unsigned m = 0; m < num_rep_0; m++){
+      ha[{m, 0}].first = phi(fp16x2_ty, 2);
+      ha[{m, 0}].second = phi(fp16x2_ty, 2);
+      ha[{m, 8}].first = phi(fp16x2_ty, 2);
+      ha[{m, 8}].second = phi(fp16x2_ty, 2);
+    }
+    for(unsigned n = 0; n < num_rep_1; n+=2){
+      hb[{n, 0}] = phi(fp16x2_ty, 2);
+      hb[{n+1, 0}] = phi(fp16x2_ty, 2);
+      hb[{n, 8}] = phi(fp16x2_ty, 2);
+      hb[{n+1, 8}] = phi(fp16x2_ty, 2);
+    }
+    // insert prefetched lds at the end of loop header
+    builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
+    for(unsigned m = 0; m < num_rep_0; m++)
+      load_a(m, 0, 0, true);
+    for(unsigned n = 0; n < num_rep_1; n+=2)
+      load_b(n, 0, 0, true);
+    // update accumulators
+    builder_->SetInsertPoint(CurrBB);
+    for(unsigned K = 0; K < NK; K += 16){
+      int NEXTK = (K + 16) % NK;
+      // prefetch A
+      for(unsigned m = 0; m < num_rep_0; m++)
+        load_a(m, NEXTK, 1, true);
+      // prefetch B
+      for(unsigned n = 0; n < num_rep_1; n+=2)
+        load_b(n, NEXTK, 1, true);
+      // tensor core ops
+      for(unsigned m = 0; m < num_rep_0; m++)
+        for(unsigned n = 0; n < num_rep_1; n++){
+          call_mma(m, n, K);
+        }
+    }
+  }
+  else{
+    for(unsigned K = 0; K < NK; K += 16)
+      for(unsigned m = 0; m < num_rep_0; m++)
+        for(unsigned n = 0; n < num_rep_1; n++){
+          if(ha.find({m, K}) == ha.end())
+            load_a(m, K, 0, false);
+          if(hb.find({n, K})==hb.end())
+            load_b(n, K, 0, false);
+          call_mma(m, n, K);
+        }
+  }
   // write back
   unsigned i = 0;
-  for(indices_t idx: idxs_.at(dot)){
+  for(indices_t idx: idxs_.at(C)){
     std::vector<Value*> key(idx.size() - 2);
     std::copy(idx.begin() + 2, idx.end(), key.begin());
     if(i >= fcs.at(key).size())
       i = 0;
-    vals_[dot][idx] = fcs.at(key)[i++];
+    vals_[C][idx] = fcs.at(key)[i++];
   };
 }
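The prefetched path above is register-level double buffering across the K loop: each slice's ldmatrix loads are issued one iteration early (into the PHIs created at the top of the block), and the mma ops then consume the already-loaded slice. A runnable skeleton of that schedule, with load_a/load_b/call_mma reduced to prints (shapes illustrative):

    #include <cstdio>

    int main() {
      const unsigned NK = 64, num_rep_0 = 2, num_rep_1 = 4;
      auto load_a   = [](unsigned m, unsigned K) { std::printf("  load A m=%u K=%u\n", m, K); };
      auto load_b   = [](unsigned n, unsigned K) { std::printf("  load B n=%u K=%u\n", n, K); };
      auto call_mma = [](unsigned m, unsigned n, unsigned K) { std::printf("  mma m=%u n=%u K=%u\n", m, n, K); };
      // the loop header issues the K=0 loads; the body then stays one slice ahead
      for (unsigned m = 0; m < num_rep_0; m++) load_a(m, 0);
      for (unsigned n = 0; n < num_rep_1; n += 2) load_b(n, 0);
      for (unsigned K = 0; K < NK; K += 16) {
        unsigned NEXTK = (K + 16) % NK;  // wraps so the last iteration feeds the next loop trip
        for (unsigned m = 0; m < num_rep_0; m++) load_a(m, NEXTK);     // prefetch A
        for (unsigned n = 0; n < num_rep_1; n += 2) load_b(n, NEXTK);  // prefetch B
        for (unsigned m = 0; m < num_rep_0; m++)                       // tensor core ops
          for (unsigned n = 0; n < num_rep_1; n++)
            call_mma(m, n, K);
      }
    }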
@@ -2252,8 +2370,35 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
 void generator::visit_layout_shared(analysis::shared_layout* layout) {
   Type* ty = cvt(layout->get_type());
   PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
-  // double-buffered
-  if(layout->get_double_buffer()) {
+  if (layout->get_N_buffer()) {
+    // create pointers
+    shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
+    shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
+
+    BasicBlock *current = builder_->GetInsertBlock();
+
+    auto info = *layout->get_N_buffer();
+    ir::phi_node *phi = info.phi;
+    BasicBlock *parent = bbs_.at(phi->get_parent());
+    if(parent->empty())
+      builder_->SetInsertPoint(parent);
+    else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
+      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
+    } else
+      builder_->SetInsertPoint(parent);
+
+    // create smem_idx
+    read_smem_idx_[layout] = phi(i32_ty, 2);
+    write_smem_idx_[layout] = phi(i32_ty, 2);
+
+    // create pointers
+    // ptr of the current iteration
+    shared_ptr_[layout] = phi(ptr_ty, 2);
+    // ptr of the next iteration
+    shared_next_ptr_[layout] = phi(ptr_ty, 2);
+
+    builder_->SetInsertPoint(current);
+  } else if(layout->get_double_buffer()) {
     BasicBlock *current = builder_->GetInsertBlock();
     auto info = *layout->get_double_buffer();
     ir::phi_node *phi = info.phi;
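The PHIs created above are deliberately left without incoming edges; finalize_shared_layout() wires them once every basic block has been emitted. The same create-then-wire pattern in minimal standalone LLVM API form (a sketch, not Triton code; compile with clang++ demo.cpp $(llvm-config --cxxflags --ldflags --libs core)):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IR/Verifier.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    int main() {
      LLVMContext ctx;
      Module mod("phi_demo", ctx);
      Function *fn = Function::Create(FunctionType::get(Type::getInt32Ty(ctx), false),
                                      Function::ExternalLinkage, "f", &mod);
      BasicBlock *header = BasicBlock::Create(ctx, "header", fn);
      BasicBlock *loop   = BasicBlock::Create(ctx, "loop", fn);
      BasicBlock *exit   = BasicBlock::Create(ctx, "exit", fn);
      IRBuilder<> b(header);
      b.CreateBr(loop);
      // phase 1: placeholder PHI with two reserved slots, like phi(i32_ty, 2)
      b.SetInsertPoint(loop);
      PHINode *idx = b.CreatePHI(b.getInt32Ty(), 2, "read_smem_idx");
      Value *next = b.CreateAdd(idx, b.getInt32(1));
      b.CreateCondBr(b.CreateICmpEQ(next, b.getInt32(4)), exit, loop);
      b.SetInsertPoint(exit);
      b.CreateRet(next);
      // phase 2: wire the incoming edges later, as finalize_shared_layout() does
      idx->addIncoming(b.getInt32(2), header);  // init value from the loop header
      idx->addIncoming(next, loop);             // updated value along the back edge
      verifyFunction(*fn, &errs());
      mod.print(outs(), nullptr);
    }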
@@ -2269,8 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
     shared_off_[layout] = phi(i32_ty, 2);
     shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
     builder_->SetInsertPoint(current);
-  }
-  else{
+  } else{
     size_t offset = alloc_->offset(layout);
     shared_ptr_[layout] = gep(shmem_, i32(offset));
     shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
@@ -2354,7 +2498,67 @@ void generator::init_idx(ir::value *v) {
 }

 void generator::finalize_shared_layout(analysis::shared_layout *shared) {
-  if(shared->get_double_buffer()) {
+  if (auto n_buffer = shared->get_N_buffer()) {
+    // if (*_smem_idx == #stages-1) {
+    //   *_smem_idx = 0;
+    // } else *_smem_idx++;
+    auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
+      // insert point
+      Value *idx = smem_idx[shared];
+      builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
+      Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
+      PHINode *_ret = phi(i32_ty, 2);
+      Instruction *then_term = nullptr;
+      Instruction *else_term = nullptr;
+      Instruction *dummy = builder_->CreateRet(nullptr);
+      llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
+      dummy->removeFromParent();
+      builder_->SetInsertPoint(then_term);
+      Value *zero_smem_idx = i32(0);
+      builder_->SetInsertPoint(else_term);
+      Value *inc_smem_idx = add(idx, i32(1));
+      builder_->SetInsertPoint(_ret->getParent());
+      _ret->addIncoming(zero_smem_idx, then_term->getParent());
+      _ret->addIncoming(inc_smem_idx, else_term->getParent());
+      // update ir::bb -> llvm::bb mapping
+      bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
+      // idx = init_stage;
+      // loop: ...
+      if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
+        idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
+        idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
+      } else
+        throw std::runtime_error("Should be PHINode");
+    };
+
+    // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
+    finalize_smem_idx(read_smem_idx_, 2);
+    finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
+
+    // finalize pointers
+    ir::phi_node *pn = n_buffer->phi;
+    BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
+    BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
+    // %curr_ptr = phi %shared_pre_ptr, %next_ptr
+    // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
+    if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
+      curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
+      curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
+    } else
+      throw std::runtime_error("Should be PHINode");
+
+    BasicBlock *current = builder_->GetInsertBlock();
+    builder_->SetInsertPoint(header->getTerminator());
+    Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
+    builder_->SetInsertPoint(current->getTerminator());
+
+    assert(isa<PHINode>(shared_next_ptr_[shared]));
+    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);
+
+    Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
+    Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
+    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
+  } else if(shared->get_double_buffer()) {
     auto info = *shared->get_double_buffer();
     ir::phi_node *phi = info.phi;
     PHINode *ptr = (PHINode*)shmems_[phi];
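Why read_smem_idx is initialized to 2: the next_ptr for iteration i+1 is computed from iteration i's read index, so the index must run two stages ahead of curr_ptr (with 3 stages). Tracing the PHI rotation above in plain C++ (stage count and per-stage size illustrative):

    #include <cstdio>

    int main() {
      const int stages = 3, per_stage = 1024;
      static float smem[stages * per_stage];       // stands in for shared_pre_ptr_
      int read_idx = 2;                            // finalize_smem_idx(read_smem_idx_, 2)
      float *curr = smem;                          // shared_ptr_:      stage used now
      float *next = smem + per_stage;              // shared_next_ptr_: stage used next
      for (int iter = 0; iter < 6; ++iter) {
        std::printf("iter %d: curr=stage %td next=stage %td read_idx=%d\n",
                    iter, (curr - smem) / per_stage, (next - smem) / per_stage, read_idx);
        float *incoming = smem + read_idx * per_stage;  // gep(pre_ptr, read_idx*per_stage)
        curr = next;                                    // curr_ptr <- next_ptr on the back edge
        next = incoming;
        read_idx = (read_idx == stages - 1) ? 0 : read_idx + 1;  // wrap-around update
      }
    }

The trace cycles curr through stages 0, 1, 2, 0, ... with next always one stage ahead, which is exactly the invariant the two pointer PHIs and finalize_smem_idx maintain.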