[BACKEND][CODEGEN] Faster reduction for scanline layout (#516)

This commit is contained in:
Philippe Tillet
2022-05-14 15:26:13 -07:00
committed by GitHub
parent d1a22a94e6
commit d35617bea1
3 changed files with 65 additions and 47 deletions

View File

@@ -177,6 +177,7 @@ public:
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);

View File

@@ -2311,35 +2311,48 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
}
/**
* \brief Code Generation for `reduce` (1D case)
* \brief Code Generation for `reduce` (ND case)
*/
void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
std::map<indices_t, Value*> partial;
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
//
ir::value *arg = x->get_operand(0);
analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline();
std::vector<unsigned> shapes = layout->get_shape();
std::vector<int> order = layout->get_order();
unsigned mts = layout->mts(order[0]);
unsigned nts = layout->nts(order[0]);
unsigned col_per_thread = shapes[order[0]] / mts;
auto idxs = idxs_.at(arg);
size_t n_elts = idxs.size();
//
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
Value *acc = nullptr;
// reduce within thread
for(indices_t idx: idxs_.at(arg)){
Value *val = vals_[arg][idx];
acc = !acc ? val : do_acc(acc, val);
}
// reduce within warp
for(int i = 16; i > 0; i >>= 1)
acc = do_acc(acc, shfl_sync(acc, i));
// pointers
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* warp = udiv(thread, i32(32));
Value* lane = urem(thread, i32(32));
size_t warps_per_inner = std::max<int>(mts/32, 1);
Value* warp_i = udiv(warp, i32(warps_per_inner));
unsigned row_per_thread = std::max<int>(32/mts, 1);
for(size_t i = 0; i < n_elts/col_per_thread; i++){
Value* acc;
// reduce within thread
for(size_t j = 0; j < col_per_thread; j++){
Value* val = vals_[arg][idxs[i*col_per_thread + j]];
acc = (j == 0) ? val : do_acc(acc, val);
}
// reduce within warp
for(int k = std::min<int>(mts, 32)/2 ; k > 0; k >>= 1)
acc = do_acc(acc, shfl_sync(acc, k));
// store warp result in shared memory
Value* ret = acc;
if(mts >= 32){
add_barrier();
store(neutral, gep(base, lane));
add_barrier();
store(acc, gep(base, warp));
add_barrier();
// reduce across warps
Value *cond = icmp_eq(warp, i32(0));
Instruction *barrier = add_barrier();
@@ -2348,23 +2361,19 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
Value* ret = load(gep(base, thread));
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
Value *current = shfl_sync(ret, i);
ret = load(gep(base, thread));
for(int k = (mts/32)/2; k > 0; k >>= 1){
Value *current = shfl_sync(ret, k);
ret = do_acc(ret, current);
}
store(ret, gep(base, thread));
// store first warp done
builder_->SetInsertPoint(barrier->getParent());
ret = load(base);
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = ret;
ret = load(gep(base, warp));
}
vals_[x][idxs_[x][i]] = ret;
}
}
/**
* \brief Code Generation for `reduce` (ND case)
*/
void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
ir::value *arg = x->get_operand(0);
Type *ty = cvt(x->get_type()->get_scalar_ty());
@@ -2462,8 +2471,9 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
if(arg->get_type()->get_tile_rank() == 1)
visit_reduce1d_inst(x, do_acc, neutral);
analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
if(scanline && scanline->get_order()[0] == x->get_axis())
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
}

View File

@@ -676,9 +676,16 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@pytest.mark.parametrize("dtype_str, shape, axis", [
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
# Single-row (1, 1024) inputs, one case per dtype, reduced along axis 1.
# Each entry is a (dtype_str, shape, axis) tuple for pytest parametrization.
reduce_configs1 = [
    (dtype_name, (1, 1024), reduce_axis)
    for dtype_name in ('float32', 'uint32')
    for reduce_axis in (1,)
]
# float32 inputs over several 2D shapes, always reduced along axis 1.
# Each entry is a (dtype_str, shape, axis) tuple for pytest parametrization.
reduce_configs2 = [
    ('float32', (n_rows, n_cols), 1)
    for n_rows, n_cols in (
        (2, 32),
        (4, 128),
        (32, 64),
        (64, 128),
        (128, 256),
        (32, 1024),
    )
]
@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@triton.jit