[BACKEND][CODEGEN] Faster reduction for scanline layout (#516)
This commit is contained in:
@@ -177,6 +177,7 @@ public:
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
|
@@ -2311,60 +2311,69 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (1D case)
|
||||
* \brief Code Generation for `reduce` (ND case, fast path for a scanline layout reduced along its leading axis)
|
||||
*/
|
||||
void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
std::map<indices_t, Value*> partial;
|
||||
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
|
||||
//
|
||||
ir::value *arg = x->get_operand(0);
|
||||
analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline();
|
||||
std::vector<unsigned> shapes = layout->get_shape();
|
||||
std::vector<int> order = layout->get_order();
|
||||
unsigned mts = layout->mts(order[0]);
|
||||
unsigned nts = layout->nts(order[0]);
|
||||
unsigned col_per_thread = shapes[order[0]] / mts;
|
||||
auto idxs = idxs_.at(arg);
|
||||
size_t n_elts = idxs.size();
|
||||
//
|
||||
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
|
||||
Value *acc = nullptr;
|
||||
|
||||
// reduce within thread
|
||||
for(indices_t idx: idxs_.at(arg)){
|
||||
Value *val = vals_[arg][idx];
|
||||
acc = !acc ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within warp
|
||||
for(int i = 16; i > 0; i >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, i));
|
||||
// pointers
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* warp = udiv(thread, i32(32));
|
||||
Value* lane = urem(thread, i32(32));
|
||||
// store warp result in shared memory
|
||||
add_barrier();
|
||||
store(neutral, gep(base, lane));
|
||||
add_barrier();
|
||||
store(acc, gep(base, warp));
|
||||
add_barrier();
|
||||
size_t warps_per_inner = std::max<int>(mts/32, 1);
|
||||
Value* warp_i = udiv(warp, i32(warps_per_inner));
|
||||
unsigned row_per_thread = std::max<int>(32/mts, 1);
|
||||
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
Value *current = shfl_sync(ret, i);
|
||||
ret = do_acc(ret, current);
|
||||
for(size_t i = 0; i < n_elts/col_per_thread; i++){
|
||||
Value* acc;
|
||||
// reduce within thread
|
||||
for(size_t j = 0; j < col_per_thread; j++){
|
||||
Value* val = vals_[arg][idxs[i*col_per_thread + j]];
|
||||
acc = (j == 0) ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within warp
|
||||
for(int k = std::min<int>(mts, 32)/2 ; k > 0; k >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, k));
|
||||
// store warp result in shared memory
|
||||
Value* ret = acc;
|
||||
if(mts >= 32){
|
||||
add_barrier();
|
||||
store(neutral, gep(base, lane));
|
||||
add_barrier();
|
||||
store(acc, gep(base, warp));
|
||||
add_barrier();
|
||||
// reduce across warps
|
||||
Value *cond = icmp_eq(warp, i32(0));
|
||||
Instruction *barrier = add_barrier();
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(term);
|
||||
ret = load(gep(base, thread));
|
||||
for(int k = (mts/32)/2; k > 0; k >>= 1){
|
||||
Value *current = shfl_sync(ret, k);
|
||||
ret = do_acc(ret, current);
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
ret = load(gep(base, warp));
|
||||
}
|
||||
vals_[x][idxs_[x][i]] = ret;
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
|
||||
// store first warp done
|
||||
builder_->SetInsertPoint(barrier->getParent());
|
||||
ret = load(base);
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
vals_[x][idx] = ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (ND case)
|
||||
*/
|
||||
void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
ir::value *arg = x->get_operand(0);
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
@@ -2462,8 +2471,9 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
ir::value *arg = x->get_operand(0);
|
||||
if(arg->get_type()->get_tile_rank() == 1)
|
||||
visit_reduce1d_inst(x, do_acc, neutral);
|
||||
analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
|
||||
if(scanline && scanline->get_order()[0] == x->get_axis())
|
||||
visit_reducend_inst_fast(x, do_acc, neutral);
|
||||
else
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
}
|
||||
|
@@ -676,9 +676,16 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, axis", [
|
||||
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
|
||||
])
|
||||
reduce_configs1 = [
|
||||
(dtype, (1, 1024), axis) for dtype in ['float32', 'uint32']
|
||||
for axis in [1]
|
||||
]
|
||||
reduce_configs2 = [
|
||||
('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
|
||||
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
|
Reference in New Issue
Block a user