[CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207)
This commit is contained in:
@@ -167,6 +167,7 @@ public:
|
||||
void visit_dot_inst(ir::dot_inst*);
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
|
@@ -1723,6 +1723,21 @@ Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vec
|
||||
return result;
|
||||
}
|
||||
|
||||
inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
Type* ty = acc->getType();
|
||||
std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||
if(ty->getPrimitiveSizeInBits() <= 32)
|
||||
return call(shfl, {acc, i32(i)});
|
||||
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
|
||||
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
|
||||
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
|
||||
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
|
||||
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
|
||||
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
|
||||
return builder_->CreateBitCast(ret, ty);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (1D case)
|
||||
*/
|
||||
@@ -1738,10 +1753,8 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
acc = !acc ? val : do_acc(acc, val);
|
||||
}
|
||||
// reduce within wrap
|
||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ret_ty, {ret_ty, i32_ty}, false),
|
||||
"shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
|
||||
for(int i = 16; i > 0; i >>= 1)
|
||||
acc = do_acc(acc, call(shfl, {acc, i32(i)}));
|
||||
acc = do_acc(acc, shfl_sync(acc, i));
|
||||
// pointers
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
@@ -1765,7 +1778,7 @@ void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
builder_->SetInsertPoint(term);
|
||||
Value* ret = load(gep(base, thread));
|
||||
for(int i = (num_warps_+1)/2; i > 0; i >>= 1){
|
||||
Value *current = call(shfl, {ret, i32(i)});
|
||||
Value *current = shfl_sync(ret, i);
|
||||
ret = do_acc(ret, current);
|
||||
}
|
||||
store(ret, gep(base, thread));
|
||||
|
@@ -337,6 +337,33 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
z_ref = x.to(z_tri.dtype)
|
||||
assert z_tri == z_ref
|
||||
|
||||
# ---------------
|
||||
# test reduce
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype, shape",
|
||||
[(dtype, shape) \
|
||||
for dtype in dtypes\
|
||||
for shape in [128, 512]])
|
||||
def test_reduce1d(dtype, shape, device='cuda'):
|
||||
dtype = cvt[dtype]
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, **meta):
|
||||
x = tl.load(X + tl.arange(0, meta['BLOCK']))
|
||||
tl.store(Z, tl.sum(x, axis=0))
|
||||
|
||||
x = triton.testing.random((shape,), dtype=dtype, device=device)
|
||||
# triton result
|
||||
z_tri = triton.testing.random((1,), dtype=dtype, device=device)
|
||||
kernel[(1,)](x, z_tri, BLOCK=shape)
|
||||
# torch result
|
||||
z_ref = torch.sum(x).to(dtype)
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
|
Reference in New Issue
Block a user