[CODEGEN] Performance improvement on A100 (#125)
Improve code generation for Ampere GPUs.
    * Make the layout pass recognize the multistage pipelined pattern.
    * Now the pipeline pass can automate the multistage pipelining transformation.
    * Remove extra barriers (from the prefetch pass & WAR) on Ampere.
    * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores.
			
			
This commit is contained in:
		
				
					committed by
					
						
						Philippe Tillet
					
				
			
			
				
	
			
			
			
						parent
						
							5a51f3e529
						
					
				
				
					commit
					d8d6b715c8
				
			@@ -212,18 +212,41 @@ void generator::visit_value(ir::value* v) {
 | 
			
		||||
    return;
 | 
			
		||||
  if(v->get_type()->is_block_ty()){
 | 
			
		||||
    if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
 | 
			
		||||
      auto double_buffer = layout->get_double_buffer();
 | 
			
		||||
      analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
 | 
			
		||||
      analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
 | 
			
		||||
 | 
			
		||||
      // offset
 | 
			
		||||
      Value *offset = nullptr;
 | 
			
		||||
      if(double_buffer && v == double_buffer->phi)
 | 
			
		||||
        offset = shared_off_[layout];
 | 
			
		||||
      // base pointer
 | 
			
		||||
      Value *ptr = shared_ptr_[layout];
 | 
			
		||||
      if(double_buffer && v == double_buffer->latch)
 | 
			
		||||
        ptr = shared_next_ptr_[layout];
 | 
			
		||||
      else if(double_buffer && v == double_buffer->first)
 | 
			
		||||
        ptr = shared_pre_ptr_[layout];
 | 
			
		||||
 | 
			
		||||
      if (n_buffer) {
 | 
			
		||||
        // ptr = base (shared_ptr_[layout]) + smem_idx * size
 | 
			
		||||
        // read_smem_idx
 | 
			
		||||
        if (v == n_buffer->phi) {
 | 
			
		||||
          ptr = shared_ptr_[layout];
 | 
			
		||||
        }
 | 
			
		||||
        // write_smem_idx
 | 
			
		||||
        if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
 | 
			
		||||
          int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
 | 
			
		||||
          int elements = write_smem_idx * layout->get_per_stage_elements();
 | 
			
		||||
          ptr = gep(shared_pre_ptr_[layout], i32(elements));
 | 
			
		||||
        } else if (v == n_buffer->latch) {
 | 
			
		||||
          Value* write_smem_idx = write_smem_idx_[layout];
 | 
			
		||||
          Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
 | 
			
		||||
          ptr = gep(shared_pre_ptr_[layout], elements);
 | 
			
		||||
        }
 | 
			
		||||
      } else if (double_buffer) {
 | 
			
		||||
        if(v == double_buffer->phi)
 | 
			
		||||
          offset = shared_off_[layout];
 | 
			
		||||
        if(v == double_buffer->latch)
 | 
			
		||||
          ptr = shared_next_ptr_[layout];
 | 
			
		||||
        else if(v == double_buffer->first)
 | 
			
		||||
          ptr = shared_pre_ptr_[layout];
 | 
			
		||||
      } // else do nothing
 | 
			
		||||
      // what visit_dot & visit_cts & ... see
 | 
			
		||||
      shmems_[v] = ptr;
 | 
			
		||||
      // now only latches have offset (PHINode), only used by finalize_shared_layout()
 | 
			
		||||
      shoffs_[v] = offset;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@@ -1223,24 +1246,21 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
 | 
			
		||||
 * \brief Code Generation for `mma.16816` (A100)
 | 
			
		||||
 */
 | 
			
		||||
//TODO: clean-up
 | 
			
		||||
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
 | 
			
		||||
  const auto& shapes = dot->get_type()->get_block_shapes();
 | 
			
		||||
 | 
			
		||||
void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
 | 
			
		||||
  const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
 | 
			
		||||
  std::map<std::vector<Value*>, std::vector<Value*>> fcs;
 | 
			
		||||
 | 
			
		||||
  for(indices_t idx: idxs_.at(dot)){
 | 
			
		||||
  for(indices_t idx: idxs_.at(C)){
 | 
			
		||||
    std::vector<Value*> key(idx.size() - 2);
 | 
			
		||||
    std::copy(idx.begin() + 2, idx.end(), key.begin());
 | 
			
		||||
    fcs[key].push_back(vals_[D][idx]);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  auto shape_a = A->get_type()->get_block_shapes();
 | 
			
		||||
  auto shape_b = B->get_type()->get_block_shapes();
 | 
			
		||||
  auto ord_a = layouts_->get(A)->get_order();
 | 
			
		||||
  auto ord_b = layouts_->get(B)->get_order();
 | 
			
		||||
  analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
 | 
			
		||||
  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0));
 | 
			
		||||
  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1));
 | 
			
		||||
  analysis::mma_layout* layout = layouts_->get(C)->to_mma();
 | 
			
		||||
  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
 | 
			
		||||
  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
 | 
			
		||||
  bool is_a_row = ord_a[0] == 1;
 | 
			
		||||
  bool is_b_row = ord_b[0] == 1;
 | 
			
		||||
  std::string a_trans = is_a_row ? "" : ".trans";
 | 
			
		||||
@@ -1264,8 +1284,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
 | 
			
		||||
  int vec_a = 8;
 | 
			
		||||
  int vec_b = 8;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  Type *fp32_ty = f32_ty;
 | 
			
		||||
  Type *fp16x2_ty = vec_ty(f16_ty, 2);
 | 
			
		||||
  Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
 | 
			
		||||
@@ -1276,7 +1294,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
 | 
			
		||||
  std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
 | 
			
		||||
  std::map<std::pair<unsigned, unsigned>, Value*> hb;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  BasicBlock* CurrBB = builder_->GetInsertBlock();
 | 
			
		||||
  BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
 | 
			
		||||
  builder_->SetInsertPoint(FirstBB->getTerminator());
 | 
			
		||||
@@ -1339,66 +1356,167 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
 | 
			
		||||
                                             "{$0, $1, $2, $3}, "
 | 
			
		||||
                                             "{$4, $5, $6, $7}, "
 | 
			
		||||
                                             "{$8, $9}, "
 | 
			
		||||
                                             "{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false);
 | 
			
		||||
                                             "{$10, $11, $12, $13};", 
 | 
			
		||||
                                             "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
 | 
			
		||||
 | 
			
		||||
  unsigned num_rep_0 = shapes[0] / layout->spt(0);
 | 
			
		||||
  unsigned num_rep_1 = shapes[1] / layout->spt(1);
 | 
			
		||||
  for(unsigned K = 0; K < NK; K += 16)
 | 
			
		||||
  for(unsigned m = 0; m < num_rep_0; m++)
 | 
			
		||||
  for(unsigned n = 0; n < num_rep_1; n++){
 | 
			
		||||
    if(ha.find({m, K}) == ha.end()){
 | 
			
		||||
      Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a];
 | 
			
		||||
 | 
			
		||||
  // create mma & unpack result
 | 
			
		||||
  auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
 | 
			
		||||
      unsigned cols_per_thread = num_rep_0 * 2;
 | 
			
		||||
      std::vector<size_t> idx = {
 | 
			
		||||
        (m*2 + 0) + (n*2 + 0)*cols_per_thread,
 | 
			
		||||
        (m*2 + 0) + (n*2 + 1)*cols_per_thread,
 | 
			
		||||
        (m*2 + 1) + (n*2 + 0)*cols_per_thread,
 | 
			
		||||
        (m*2 + 1) + (n*2 + 1)*cols_per_thread
 | 
			
		||||
      };
 | 
			
		||||
      Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
 | 
			
		||||
                                                        hb[{n, K}], hb[{n, K+8}],
 | 
			
		||||
                                                        fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
 | 
			
		||||
      fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
 | 
			
		||||
      fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
 | 
			
		||||
      fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
 | 
			
		||||
      fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
 | 
			
		||||
  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
 | 
			
		||||
 | 
			
		||||
  auto register_lds =
 | 
			
		||||
    [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
 | 
			
		||||
      if (K <= 8 && is_prefetch) {
 | 
			
		||||
        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
 | 
			
		||||
        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
 | 
			
		||||
        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
 | 
			
		||||
      } else
 | 
			
		||||
        vals[{m, K}] = {val0, val1};
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  auto register_lds2 =
 | 
			
		||||
    [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
 | 
			
		||||
      if (K <= 8 && is_prefetch) {
 | 
			
		||||
        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
 | 
			
		||||
        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
 | 
			
		||||
      } else
 | 
			
		||||
        vals[{m, K}] = val;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
 | 
			
		||||
      int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
 | 
			
		||||
      Value* ptra;
 | 
			
		||||
      if(K == 0 && is_prefetch){
 | 
			
		||||
        if(inc == 0)
 | 
			
		||||
          ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
 | 
			
		||||
        else
 | 
			
		||||
          ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
 | 
			
		||||
      }
 | 
			
		||||
      else
 | 
			
		||||
        ptra = ptrs_a[offidx];
 | 
			
		||||
      int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
 | 
			
		||||
      int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
 | 
			
		||||
      InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
 | 
			
		||||
                                                "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false);
 | 
			
		||||
                                                "{$0, $1, $2, $3}, [$4 + " + 
 | 
			
		||||
                                                std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", 
 | 
			
		||||
                                                "=r,=r,=r,=r,r", true);
 | 
			
		||||
      Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
 | 
			
		||||
      if(K == 0 && inc == 1 && is_prefetch)
 | 
			
		||||
          prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
 | 
			
		||||
      Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
 | 
			
		||||
      Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
 | 
			
		||||
      Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
 | 
			
		||||
      Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
 | 
			
		||||
      ha[{m, K}] = std::make_pair(ha0, ha1);
 | 
			
		||||
      ha[{m, K+8}] = std::make_pair(ha2, ha3);
 | 
			
		||||
    }
 | 
			
		||||
    if(hb.find({n, K})==hb.end()){
 | 
			
		||||
      Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b];
 | 
			
		||||
      register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
 | 
			
		||||
      register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
 | 
			
		||||
      int offidx = (is_b_row ? n : K/16) % num_ptr_b;
 | 
			
		||||
      Value* ptrb;
 | 
			
		||||
      if(K == 0 && is_prefetch){
 | 
			
		||||
        if(inc == 0)
 | 
			
		||||
          ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
 | 
			
		||||
        else
 | 
			
		||||
          ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
 | 
			
		||||
      }
 | 
			
		||||
      else
 | 
			
		||||
        ptrb = ptrs_b[offidx];
 | 
			
		||||
      int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
 | 
			
		||||
      int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
 | 
			
		||||
      InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
 | 
			
		||||
                                                "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false);
 | 
			
		||||
                                                    "{$0, $1, $2, $3}, [$4 + " + 
 | 
			
		||||
                                                    std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", 
 | 
			
		||||
                                                    "=r,=r,=r,=r,r", true);
 | 
			
		||||
      Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
 | 
			
		||||
      if(K == 0 && inc == 1 && is_prefetch)
 | 
			
		||||
          prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
 | 
			
		||||
      Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
 | 
			
		||||
      Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
 | 
			
		||||
      Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
 | 
			
		||||
      Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
 | 
			
		||||
      hb[{n, K}] = hb0;
 | 
			
		||||
      hb[{n+1, K}] = hb2;
 | 
			
		||||
      hb[{n, K+8}] = hb1;
 | 
			
		||||
      hb[{n+1, K+8}] = hb3;
 | 
			
		||||
    }
 | 
			
		||||
    unsigned cols_per_thread = num_rep_0 * 2;
 | 
			
		||||
    std::vector<size_t> idx = {
 | 
			
		||||
      (m*2 + 0) + (n*2 + 0)*cols_per_thread,
 | 
			
		||||
      (m*2 + 0) + (n*2 + 1)*cols_per_thread,
 | 
			
		||||
      (m*2 + 1) + (n*2 + 0)*cols_per_thread,
 | 
			
		||||
      (m*2 + 1) + (n*2 + 1)*cols_per_thread
 | 
			
		||||
    };
 | 
			
		||||
    Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
 | 
			
		||||
                                                      hb[{n, K}], hb[{n, K+8}],
 | 
			
		||||
                                                      fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
 | 
			
		||||
    fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
 | 
			
		||||
    fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
 | 
			
		||||
    fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
 | 
			
		||||
    fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
 | 
			
		||||
  }
 | 
			
		||||
      register_lds2(hb, n, K, inc, hb0, is_prefetch);
 | 
			
		||||
      register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
 | 
			
		||||
      register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
 | 
			
		||||
      register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  if (C->is_prefetched()) {
 | 
			
		||||
      // create phis
 | 
			
		||||
      builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
 | 
			
		||||
      for(unsigned m = 0; m < num_rep_0; m++){
 | 
			
		||||
        ha[{m, 0}].first = phi(fp16x2_ty, 2);
 | 
			
		||||
        ha[{m, 0}].second = phi(fp16x2_ty, 2);
 | 
			
		||||
        ha[{m, 8}].first = phi(fp16x2_ty, 2);
 | 
			
		||||
        ha[{m, 8}].second = phi(fp16x2_ty, 2);
 | 
			
		||||
      }
 | 
			
		||||
      for(unsigned n = 0; n < num_rep_1; n+=2){
 | 
			
		||||
        hb[{n, 0}] = phi(fp16x2_ty, 2);
 | 
			
		||||
        hb[{n+1, 0}] = phi(fp16x2_ty, 2);
 | 
			
		||||
        hb[{n, 8}] = phi(fp16x2_ty, 2);
 | 
			
		||||
        hb[{n+1, 8}] = phi(fp16x2_ty, 2);
 | 
			
		||||
      }
 | 
			
		||||
      // insert prefetched lds at the end of loop header
 | 
			
		||||
      builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
 | 
			
		||||
      for(unsigned m = 0; m < num_rep_0; m++)
 | 
			
		||||
        load_a(m, 0, 0, true);
 | 
			
		||||
      for(unsigned n = 0; n < num_rep_1; n+=2)
 | 
			
		||||
        load_b(n, 0, 0, true);
 | 
			
		||||
      // update accumulators
 | 
			
		||||
      builder_->SetInsertPoint(CurrBB);
 | 
			
		||||
      for(unsigned K = 0; K < NK; K += 16){
 | 
			
		||||
        int NEXTK = (K + 16) % NK;
 | 
			
		||||
        // prefetch A
 | 
			
		||||
        for(unsigned m = 0; m < num_rep_0; m++)
 | 
			
		||||
          load_a(m, NEXTK, 1, true);
 | 
			
		||||
        // prefetch B
 | 
			
		||||
        for(unsigned n = 0; n < num_rep_1; n+=2)
 | 
			
		||||
          load_b(n, NEXTK, 1, true);
 | 
			
		||||
        // tensor core ops
 | 
			
		||||
        for(unsigned m = 0; m < num_rep_0; m++)
 | 
			
		||||
        for(unsigned n = 0; n < num_rep_1; n++){
 | 
			
		||||
          call_mma(m, n, K);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
  }
 | 
			
		||||
  else{
 | 
			
		||||
      for(unsigned K = 0; K < NK; K += 16)
 | 
			
		||||
      for(unsigned m = 0; m < num_rep_0; m++)
 | 
			
		||||
      for(unsigned n = 0; n < num_rep_1; n++){
 | 
			
		||||
        if(ha.find({m, K}) == ha.end())
 | 
			
		||||
          load_a(m, K, 0, false);
 | 
			
		||||
        if(hb.find({n, K})==hb.end())
 | 
			
		||||
          load_b(n, K, 0, false);
 | 
			
		||||
        call_mma(m, n, K);
 | 
			
		||||
      }
 | 
			
		||||
  }
 | 
			
		||||
  // write back
 | 
			
		||||
  unsigned i = 0;
 | 
			
		||||
  for(indices_t idx: idxs_.at(dot)){
 | 
			
		||||
  for(indices_t idx: idxs_.at(C)){
 | 
			
		||||
    std::vector<Value*> key(idx.size() - 2);
 | 
			
		||||
    std::copy(idx.begin() + 2, idx.end(), key.begin());
 | 
			
		||||
    if(i >= fcs.at(key).size())
 | 
			
		||||
      i = 0;
 | 
			
		||||
    vals_[dot][idx] = fcs.at(key)[i++];
 | 
			
		||||
    vals_[C][idx] = fcs.at(key)[i++];
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -2252,8 +2370,35 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
 | 
			
		||||
void generator::visit_layout_shared(analysis::shared_layout* layout) {
 | 
			
		||||
  Type* ty = cvt(layout->get_type());
 | 
			
		||||
  PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
 | 
			
		||||
  // double-buffered
 | 
			
		||||
  if(layout->get_double_buffer()) {
 | 
			
		||||
  if (layout->get_N_buffer()) {
 | 
			
		||||
    // create pointers
 | 
			
		||||
    shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
 | 
			
		||||
    shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
 | 
			
		||||
 | 
			
		||||
    BasicBlock *current = builder_->GetInsertBlock();
 | 
			
		||||
 | 
			
		||||
    auto info = *layout->get_N_buffer();
 | 
			
		||||
    ir::phi_node *phi = info.phi;
 | 
			
		||||
    BasicBlock *parent = bbs_.at(phi->get_parent());
 | 
			
		||||
    if(parent->empty())
 | 
			
		||||
      builder_->SetInsertPoint(parent);
 | 
			
		||||
    else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
 | 
			
		||||
      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
 | 
			
		||||
    } else 
 | 
			
		||||
      builder_->SetInsertPoint(parent);
 | 
			
		||||
 | 
			
		||||
    // create smem_idx
 | 
			
		||||
    read_smem_idx_[layout] = phi(i32_ty, 2);
 | 
			
		||||
    write_smem_idx_[layout] = phi(i32_ty, 2);
 | 
			
		||||
 | 
			
		||||
    // create pointers
 | 
			
		||||
    // ptr of the current iteration
 | 
			
		||||
    shared_ptr_[layout] = phi(ptr_ty, 2);
 | 
			
		||||
    // ptr of the next iteration
 | 
			
		||||
    shared_next_ptr_[layout] = phi(ptr_ty, 2);
 | 
			
		||||
 | 
			
		||||
    builder_->SetInsertPoint(current);
 | 
			
		||||
  } else if(layout->get_double_buffer()) {
 | 
			
		||||
    BasicBlock *current = builder_->GetInsertBlock();
 | 
			
		||||
    auto info = *layout->get_double_buffer();
 | 
			
		||||
    ir::phi_node *phi = info.phi;
 | 
			
		||||
@@ -2269,8 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
 | 
			
		||||
    shared_off_[layout] = phi(i32_ty, 2);
 | 
			
		||||
    shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
 | 
			
		||||
    builder_->SetInsertPoint(current);
 | 
			
		||||
  }
 | 
			
		||||
  else{
 | 
			
		||||
  } else{
 | 
			
		||||
    size_t offset = alloc_->offset(layout);
 | 
			
		||||
    shared_ptr_[layout] = gep(shmem_, i32(offset));
 | 
			
		||||
    shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
 | 
			
		||||
@@ -2354,7 +2498,67 @@ void generator::init_idx(ir::value *v) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
 | 
			
		||||
  if(shared->get_double_buffer()) {
 | 
			
		||||
  if (auto n_buffer = shared->get_N_buffer()) {
 | 
			
		||||
    // if (*_smem_idx == #stages-1) {
 | 
			
		||||
    //   *_smem_idx = 0;
 | 
			
		||||
    // } else *_smem_idx++;
 | 
			
		||||
    auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
 | 
			
		||||
      // insert point
 | 
			
		||||
      Value *idx = smem_idx[shared];
 | 
			
		||||
      builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
 | 
			
		||||
      Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
 | 
			
		||||
      PHINode *_ret = phi(i32_ty, 2);      
 | 
			
		||||
      Instruction *then_term = nullptr;
 | 
			
		||||
      Instruction *else_term = nullptr;
 | 
			
		||||
      Instruction *dummy = builder_->CreateRet(nullptr);
 | 
			
		||||
      llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
 | 
			
		||||
      dummy->removeFromParent();
 | 
			
		||||
      builder_->SetInsertPoint(then_term);
 | 
			
		||||
      Value *zero_smem_idx = i32(0);
 | 
			
		||||
      builder_->SetInsertPoint(else_term);
 | 
			
		||||
      Value *inc_smem_idx = add(idx, i32(1));
 | 
			
		||||
      builder_->SetInsertPoint(_ret->getParent());
 | 
			
		||||
      _ret->addIncoming(zero_smem_idx, then_term->getParent());
 | 
			
		||||
      _ret->addIncoming(inc_smem_idx, else_term->getParent());
 | 
			
		||||
      // update ir::bb -> llvm::bb mapping
 | 
			
		||||
      bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
 | 
			
		||||
      // idx = init_stage;
 | 
			
		||||
      // loop: ...
 | 
			
		||||
      if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
 | 
			
		||||
        idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
 | 
			
		||||
        idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
 | 
			
		||||
      } else
 | 
			
		||||
        throw std::runtime_error("Should be PHINode");
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
 | 
			
		||||
    finalize_smem_idx(read_smem_idx_, 2);
 | 
			
		||||
    finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
 | 
			
		||||
 | 
			
		||||
    // finalize pointers
 | 
			
		||||
    ir::phi_node *pn = n_buffer->phi;
 | 
			
		||||
    BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
 | 
			
		||||
    BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
 | 
			
		||||
    // %curr_ptr = phi %shared_pre_ptr, %next_ptr
 | 
			
		||||
    // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
 | 
			
		||||
    if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
 | 
			
		||||
      curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
 | 
			
		||||
      curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
 | 
			
		||||
    } else 
 | 
			
		||||
      throw std::runtime_error("Should be PHINode");
 | 
			
		||||
 | 
			
		||||
    BasicBlock *current = builder_->GetInsertBlock();
 | 
			
		||||
    builder_->SetInsertPoint(header->getTerminator());
 | 
			
		||||
    Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
 | 
			
		||||
    builder_->SetInsertPoint(current->getTerminator());
 | 
			
		||||
 | 
			
		||||
    assert(isa<PHINode>(shared_next_ptr_[shared]));
 | 
			
		||||
    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);
 | 
			
		||||
 | 
			
		||||
    Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
 | 
			
		||||
    Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
 | 
			
		||||
    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
 | 
			
		||||
  } else if(shared->get_double_buffer()) {
 | 
			
		||||
    auto info = *shared->get_double_buffer();
 | 
			
		||||
    ir::phi_node *phi = info.phi;
 | 
			
		||||
    PHINode *ptr = (PHINode*)shmems_[phi];
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user