From d35617bea13755327dff9e96a4390fd500e2f8ad Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 14 May 2022 15:26:13 -0700 Subject: [PATCH] [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) --- include/triton/codegen/selection/generator.h | 1 + lib/codegen/selection/generator.cc | 98 +++++++++++--------- python/test/unit/language/test_core.py | 13 ++- 3 files changed, 65 insertions(+), 47 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index d855d3eca..945b9b074 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -177,6 +177,7 @@ public: void visit_sqrt_inst(ir::sqrt_inst*); Value* shfl_sync(Value* acc, int32_t i); void visit_reduce1d_inst(ir::reduce_inst*, std::function, Value*); + void visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral); void visit_reducend_inst(ir::reduce_inst*, std::function, Value*); void visit_reduce_inst(ir::reduce_inst*); void visit_select_inst(ir::select_inst*); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 9abf86df8..cf51a3b4c 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -2311,60 +2311,69 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){ } /** - * \brief Code Generation for `reduce` (1D case) + * \brief Code Generation for `reduce` (ND case) */ -void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { - std::map partial; +void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function do_acc, Value *neutral){ + // ir::value *arg = x->get_operand(0); + analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline(); + std::vector shapes = layout->get_shape(); + std::vector order = layout->get_order(); + unsigned mts = layout->mts(order[0]); + unsigned nts = layout->nts(order[0]); + unsigned col_per_thread = shapes[order[0]] / mts; + auto idxs = idxs_.at(arg); + size_t n_elts = idxs.size(); + // Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); - Value *acc = nullptr; - - // reduce within thread - for(indices_t idx: idxs_.at(arg)){ - Value *val = vals_[arg][idx]; - acc = !acc ? val : do_acc(acc, val); - } - // reduce within wrap - for(int i = 16; i > 0; i >>= 1) - acc = do_acc(acc, shfl_sync(acc, i)); - // pointers unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* warp = udiv(thread, i32(32)); Value* lane = urem(thread, i32(32)); - // store warp result in shared memory - add_barrier(); - store(neutral, gep(base, lane)); - add_barrier(); - store(acc, gep(base, warp)); - add_barrier(); + size_t warps_per_inner = std::max(mts/32, 1); + Value* warp_i = udiv(warp, i32(warps_per_inner)); + unsigned row_per_thread = std::max(32/mts, 1); - // reduce across warps - Value *cond = icmp_eq(warp, i32(0)); - Instruction *barrier = add_barrier(); - builder_->SetInsertPoint(barrier->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - Value* ret = load(gep(base, thread)); - for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ - Value *current = shfl_sync(ret, i); - ret = do_acc(ret, current); + for(size_t i = 0; i < n_elts/col_per_thread; i++){ + Value* acc; + // reduce within thread + for(size_t j = 0; j < col_per_thread; j++){ + Value* val = vals_[arg][idxs[i*col_per_thread + j]]; + acc = (j == 0) ? val : do_acc(acc, val); + } + // reduce within warp + for(int k = std::min(mts, 32)/2 ; k > 0; k >>= 1) + acc = do_acc(acc, shfl_sync(acc, k)); + // store warp result in shared memory + Value* ret = acc; + if(mts >= 32){ + add_barrier(); + store(neutral, gep(base, lane)); + add_barrier(); + store(acc, gep(base, warp)); + add_barrier(); + // reduce across warps + Value *cond = icmp_eq(warp, i32(0)); + Instruction *barrier = add_barrier(); + builder_->SetInsertPoint(barrier->getParent()); + Instruction* dummy = builder_->CreateRet(nullptr); + Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + dummy->removeFromParent(); + builder_->SetInsertPoint(term); + ret = load(gep(base, thread)); + for(int k = (mts/32)/2; k > 0; k >>= 1){ + Value *current = shfl_sync(ret, k); + ret = do_acc(ret, current); + } + store(ret, gep(base, thread)); + builder_->SetInsertPoint(barrier->getParent()); + ret = load(gep(base, warp)); + } + vals_[x][idxs_[x][i]] = ret; } - store(ret, gep(base, thread)); - - // store first warp done - builder_->SetInsertPoint(barrier->getParent()); - ret = load(base); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ret; } -/** - * \brief Code Generation for `reduce` (ND case) - */ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { ir::value *arg = x->get_operand(0); Type *ty = cvt(x->get_type()->get_scalar_ty()); @@ -2462,8 +2471,9 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { default: throw std::runtime_error("unreachable"); } ir::value *arg = x->get_operand(0); - if(arg->get_type()->get_tile_rank() == 1) - visit_reduce1d_inst(x, do_acc, neutral); + analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline(); + if(scanline && scanline->get_order()[0] == x->get_axis()) + visit_reducend_inst_fast(x, do_acc, neutral); else visit_reducend_inst(x, do_acc, neutral); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9a997d661..77a870eea 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -676,9 +676,16 @@ def test_reduce1d(dtype_str, shape, device='cuda'): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) -@pytest.mark.parametrize("dtype_str, shape, axis", [ - (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32'] -]) +reduce_configs1 = [ + (dtype, (1, 1024), axis) for dtype in ['float32', 'uint32'] + for axis in [1] +] +reduce_configs2 = [ + ('float32', shape, 1) for shape in [(2, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)] +] + + +@pytest.mark.parametrize("dtype_str, shape, axis", reduce_configs1 + reduce_configs2) def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @triton.jit