[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)
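
The warp-level reduction in visit_reduce1d_inst previously built the shfl.sync.bfly.b32 inline-asm call directly on the accumulator. A b32 shuffle only exchanges 32 bits per lane, so the lowering was broken for 64-bit data-types. The shuffle is now factored into a generator::shfl_sync helper: values of 32 bits or less keep the single b32 butterfly shuffle, while wider values are bit-cast to a <2 x float> vector, each 32-bit half is shuffled separately, and the halves are reassembled. Both shuffle sites in visit_reduce1d_inst (the intra-warp loop and the cross-warp loop) now go through the helper, and a test_reduce1d case that runs over all supported dtypes is added to the Python tests.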
@@ -167,6 +167,7 @@ public:
   void visit_dot_inst(ir::dot_inst*);
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
+  Value* shfl_sync(Value* acc, int32_t i);
   void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reduce_inst(ir::reduce_inst*);

@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
   return result;
 }
 
+inline Value* generator::shfl_sync(Value* acc, int32_t i){
+  Type* ty = acc->getType();
+  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
+  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
+  if(ty->getPrimitiveSizeInBits() <= 32)
+    return call(shfl, {acc, i32(i)});
+  acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
+  Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
+  Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
+  Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
+  ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
+  ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
+  return builder_->CreateBitCast(ret, ty);
+}
+
 /**
  * \brief Code Generation for `reduce` (1D case)
  */
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
     acc = !acc ? val : do_acc(acc, val);
   }
   // reduce within wrap
-  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
-                                   "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
   for(int i = 16; i > 0; i >>= 1)
-    acc = do_acc(acc, call(shfl, {acc, i32(i)}));
+    acc = do_acc(acc, shfl_sync(acc, i));
   // pointers
   unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
   Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
   builder_->SetInsertPoint(term);
   Value* ret = load(gep(base, thread));
   for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
-    Value *current = call(shfl, {ret, i32(i)});
+    Value *current = shfl_sync(ret, i);
     ret = do_acc(ret, current);
   }
   store(ret, gep(base, thread));
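
For the device-level picture, the sketch below shows in plain CUDA the pattern that the 64-bit path of shfl_sync corresponds to. It is an illustration only, not code from this commit, and the helper names shfl_xor_f64 and warp_sum_f64 are invented for the example: a 64-bit accumulator is reinterpreted as two 32-bit words, each word is exchanged with a 32-bit butterfly shuffle (__shfl_xor_sync compiles to the same shfl.sync.bfly.b32 used in the inline asm above), and the halves are reassembled before accumulation.

#include <cstdint>

// Illustration only: the 64-bit shuffle split performed by shfl_sync, written
// as CUDA device functions. Helper names are invented for this sketch.
__device__ double shfl_xor_f64(double acc, int i) {
    // Reinterpret the 64-bit value as two 32-bit words.
    uint64_t bits = static_cast<uint64_t>(__double_as_longlong(acc));
    uint32_t lo = static_cast<uint32_t>(bits);
    uint32_t hi = static_cast<uint32_t>(bits >> 32);
    // Exchange each word with a 32-bit butterfly shuffle across the warp.
    lo = __shfl_xor_sync(0xffffffff, lo, i);
    hi = __shfl_xor_sync(0xffffffff, hi, i);
    // Reassemble the shuffled 64-bit value.
    bits = (static_cast<uint64_t>(hi) << 32) | lo;
    return __longlong_as_double(static_cast<long long>(bits));
}

__device__ double warp_sum_f64(double acc) {
    // Butterfly reduction over the 32 lanes of a warp, mirroring the
    // for(int i = 16; i > 0; i >>= 1) loop in visit_reduce1d_inst.
    for (int i = 16; i > 0; i >>= 1)
        acc += shfl_xor_f64(acc, i);
    return acc;
}

In the committed helper the split happens on a <2 x float> vector rather than integer words because the inline asm uses the "f" register constraint; each recursive shfl_sync call on a half then takes the 32-bit path.
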
@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
         z_ref = x.to(z_tri.dtype)
     assert z_tri == z_ref
 
+# ---------------
+# test reduce
+# ---------------
+@pytest.mark.parametrize("dtype, shape",
+  [(dtype, shape) \
+        for dtype in dtypes\
+        for shape in [128, 512]])
+def test_reduce1d(dtype, shape, device='cuda'):
+    dtype = cvt[dtype]
+
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, **meta):
+        x = tl.load(X + tl.arange(0, meta['BLOCK']))
+        tl.store(Z, tl.sum(x, axis=0))
+
+    x = triton.testing.random((shape,), dtype=dtype, device=device)
+    # triton result
+    z_tri = triton.testing.random((1,), dtype=dtype, device=device)
+    kernel[(1,)](x, z_tri, BLOCK=shape)
+    # torch result
+    z_ref = torch.sum(x).to(dtype)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+
+
+
 
 # ---------------
 # test load