[BACKEND][CODEGEN] Faster reduction for scanline layout (#516)
This commit is contained in:
@@ -2311,60 +2311,69 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (1D case)
|
||||
* \brief Code Generation for `reduce` (ND case)
|
||||
*/
|
||||
void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
std::map<indices_t, Value*> partial;
|
||||
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
|
||||
//
|
||||
ir::value *arg = x->get_operand(0);
|
||||
analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline();
|
||||
std::vector<unsigned> shapes = layout->get_shape();
|
||||
std::vector<int> order = layout->get_order();
|
||||
unsigned mts = layout->mts(order[0]);
|
||||
unsigned nts = layout->nts(order[0]);
|
||||
unsigned col_per_thread = shapes[order[0]] / mts;
|
||||
auto idxs = idxs_.at(arg);
|
||||
size_t n_elts = idxs.size();
|
||||
//
|
||||
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
|
||||
Value *acc = nullptr;
|
||||
|
||||
// reduce within thread
|
||||
for(indices_t idx: idxs_.at(arg)){
|
||||
Value *val = vals_[arg][idx];
|
||||
acc = !acc ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within wrap
|
||||
for(int i = 16; i > 0; i >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, i));
|
||||
// pointers
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* warp = udiv(thread, i32(32));
|
||||
Value* lane = urem(thread, i32(32));
|
||||
// store warp result in shared memory
|
||||
add_barrier();
|
||||
store(neutral, gep(base, lane));
|
||||
add_barrier();
|
||||
store(acc, gep(base, warp));
|
||||
add_barrier();
|
||||
size_t warps_per_inner = std::max<int>(mts/32, 1);
|
||||
Value* warp_i = udiv(warp, i32(warps_per_inner));
|
||||
unsigned row_per_thread = std::max<int>(32/mts, 1);
|
||||
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
Value *current = shfl_sync(ret, i);
|
||||
ret = do_acc(ret, current);
|
||||
for(size_t i = 0; i < n_elts/col_per_thread; i++){
|
||||
Value* acc;
|
||||
// reduce within thread
|
||||
for(size_t j = 0; j < col_per_thread; j++){
|
||||
Value* val = vals_[arg][idxs[i*col_per_thread + j]];
|
||||
acc = (j == 0) ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within warp
|
||||
for(int k = std::min<int>(mts, 32)/2 ; k > 0; k >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, k));
|
||||
// store warp result in shared memory
|
||||
Value* ret = acc;
|
||||
if(mts >= 32){
|
||||
add_barrier();
|
||||
store(neutral, gep(base, lane));
|
||||
add_barrier();
|
||||
store(acc, gep(base, warp));
|
||||
add_barrier();
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
ret = load(gep(base, thread));
|
||||
for(int k = (mts/32)/2; k > 0; k >>= 1){
|
||||
Value *current = shfl_sync(ret, k);
|
||||
ret = do_acc(ret, current);
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
ret = load(gep(base, warp));
|
||||
}
|
||||
vals_[x][idxs_[x][i]] = ret;
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
|
||||
// store first warp done
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
ret = load(base);
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
vals_[x][idx] = ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (ND case)
|
||||
*/
|
||||
void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
ir::value *arg = x->get_operand(0);
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
@@ -2462,8 +2471,9 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
ir::value *arg = x->get_operand(0);
|
||||
if(arg->get_type()->get_tile_rank() == 1)
|
||||
visit_reduce1d_inst(x, do_acc, neutral);
|
||||
analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
|
||||
if(scanline && scanline->get_order()[0] == x->get_axis())
|
||||
visit_reducend_inst_fast(x, do_acc, neutral);
|
||||
else
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
}
|
||||
|
Reference in New Issue
Block a user