diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index cd091d821..2153e3e1a 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -167,6 +167,7 @@ public:
   void visit_dot_inst(ir::dot_inst*);
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
+  Value* shfl_sync(Value* acc, int32_t i);
   void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reduce_inst(ir::reduce_inst*);
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 47b96aa08..d00984167 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
   return result;
 }
 
+inline Value* generator::shfl_sync(Value* acc, int32_t i){
+  Type* ty = acc->getType();
+  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
+  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
+  if(ty->getPrimitiveSizeInBits() <= 32)
+    return call(shfl, {acc, i32(i)});
+  acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
+  Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
+  Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
+  Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
+  ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
+  ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
+  return builder_->CreateBitCast(ret, ty);
+}
+
 /**
  * \brief Code Generation for `reduce` (1D case)
  */
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
     acc = !acc ? val : do_acc(acc, val);
   }
   // reduce within warps
-  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
-  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false), asm_str, "=f,f,r", false);
   for(int i = 16; i > 0; i >>= 1)
-    acc = do_acc(acc, call(shfl, {acc, i32(i)}));
+    acc = do_acc(acc, shfl_sync(acc, i));
   // pointers
   unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
   Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
   builder_->SetInsertPoint(term);
   Value* ret = load(gep(base, thread));
   for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
-    Value *current = call(shfl, {ret, i32(i)});
+    Value *current = shfl_sync(ret, i);
     ret = do_acc(ret, current);
   }
   store(ret, gep(base, thread));
diff --git a/python/test/test_language.py b/python/test/test_language.py
index 71a151f47..e15bbc6bb 100644
--- a/python/test/test_language.py
+++ b/python/test/test_language.py
@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     z_ref = x.to(z_tri.dtype)
     assert z_tri == z_ref
 
+# ---------------
+# test reduce
+# ---------------
+@pytest.mark.parametrize("dtype, shape",
+                         [(dtype, shape) \
+                             for dtype in dtypes\
+                             for shape in [128, 512]])
+def test_reduce1d(dtype, shape, device='cuda'):
+    dtype = cvt[dtype]
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, **meta):
+        x = tl.load(X + tl.arange(0, meta['BLOCK']))
+        tl.store(Z, tl.sum(x, axis=0))
+
+    x = triton.testing.random((shape,), dtype=dtype, device=device)
+    # triton result
+    z_tri = triton.testing.random((1,), dtype=dtype, device=device)
+    kernel[(1,)](x, z_tri, BLOCK=shape)
+    # torch result
+    z_ref = torch.sum(x).to(dtype)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+
+
+
 # ---------------
 # test load