[BACKEND] Restored reduction bugfixes
This commit is contained in:
@@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
|
||||
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
|
||||
x[:] = 1
|
||||
# numpy result
|
||||
z_ref = np.sum(x).astype(getattr(np, dtype_str))
|
||||
# triton result
|
||||
@@ -1132,3 +1133,25 @@ def test_constexpr_shape():
|
||||
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||
kernel[(1,)](x_tri)
|
||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
|
||||
|
||||
# -------------
|
||||
# test if
|
||||
# -------------
|
||||
|
||||
|
||||
def test_if():
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret):
|
||||
pid = tl.program_id(0)
|
||||
cond = tl.load(Cond)
|
||||
if pid % 2:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret)
|
||||
|
Reference in New Issue
Block a user