[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)
This commit is contained in:
		| @@ -167,6 +167,7 @@ public: | ||||
|   void visit_dot_inst(ir::dot_inst*); | ||||
|   void visit_trans_inst(ir::trans_inst*); | ||||
|   void visit_sqrt_inst(ir::sqrt_inst*); | ||||
|   Value* shfl_sync(Value* acc, int32_t i); | ||||
|   void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); | ||||
|   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); | ||||
|   void visit_reduce_inst(ir::reduce_inst*); | ||||
|   | ||||
| @@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| inline Value* generator::shfl_sync(Value* acc, int32_t i){ | ||||
|   Type* ty = acc->getType(); | ||||
|   std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;"; | ||||
|   InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false); | ||||
|   if(ty->getPrimitiveSizeInBits() <= 32) | ||||
|     return call(shfl, {acc, i32(i)}); | ||||
|   acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2)); | ||||
|   Value* acc0 = builder_->CreateExtractElement(acc, i32(0)); | ||||
|   Value* acc1 = builder_->CreateExtractElement(acc, i32(1)); | ||||
|   Value* ret = UndefValue::get(vec_ty(f32_ty, 2)); | ||||
|   ret = insert_elt(ret, shfl_sync(acc0, i), i32(0)); | ||||
|   ret = insert_elt(ret, shfl_sync(acc1, i), i32(1)); | ||||
|   return builder_->CreateBitCast(ret, ty); | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * \brief Code Generation for `reduce` (1D case) | ||||
|  */ | ||||
| @@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val | ||||
|     acc = !acc ? val : do_acc(acc, val); | ||||
|   } | ||||
|   // reduce within wrap | ||||
|   InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false), | ||||
|                                    "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false); | ||||
|   for(int i = 16; i > 0; i >>= 1) | ||||
|     acc = do_acc(acc, call(shfl, {acc, i32(i)})); | ||||
|     acc = do_acc(acc, shfl_sync(acc, i)); | ||||
|   // pointers | ||||
|   unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); | ||||
|   Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); | ||||
| @@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val | ||||
|   builder_->SetInsertPoint(term); | ||||
|   Value* ret = load(gep(base, thread)); | ||||
|   for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ | ||||
|     Value *current = call(shfl, {ret, i32(i)}); | ||||
|     Value *current = shfl_sync(ret, i); | ||||
|     ret = do_acc(ret, current); | ||||
|   } | ||||
|   store(ret, gep(base, thread)); | ||||
|   | ||||
| @@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): | ||||
|         z_ref = x.to(z_tri.dtype) | ||||
|     assert z_tri == z_ref | ||||
|  | ||||
| # --------------- | ||||
| # test reduce | ||||
| # --------------- | ||||
| @pytest.mark.parametrize("dtype, shape",  | ||||
|   [(dtype, shape) \ | ||||
|         for dtype in dtypes\ | ||||
|         for shape in [128, 512]]) | ||||
| def test_reduce1d(dtype, shape, device='cuda'): | ||||
|     dtype = cvt[dtype] | ||||
|  | ||||
|     # triton kernel | ||||
|     @triton.jit | ||||
|     def kernel(X, Z, **meta): | ||||
|         x = tl.load(X + tl.arange(0, meta['BLOCK'])) | ||||
|         tl.store(Z, tl.sum(x, axis=0)) | ||||
|  | ||||
|     x = triton.testing.random((shape,), dtype=dtype, device=device) | ||||
|     # triton result | ||||
|     z_tri = triton.testing.random((1,), dtype=dtype, device=device) | ||||
|     kernel[(1,)](x, z_tri, BLOCK=shape) | ||||
|     # torch result | ||||
|     z_ref = torch.sum(x).to(dtype) | ||||
|     # compare | ||||
|     triton.testing.assert_almost_equal(z_tri, z_ref) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| # --------------- | ||||
| # test load | ||||
|   | ||||
		Reference in New Issue
	
	Block a user