[BACKEND] Restored reduction bugfixes

This commit is contained in:
Philippe Tillet
2022-06-03 11:38:52 -07:00
parent a60374a597
commit 8876e53206
11 changed files with 173 additions and 65 deletions

View File

@@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
rs = RandomState(17)
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x[:] = 1
# numpy result
z_ref = np.sum(x).astype(getattr(np, dtype_str))
# triton result
@@ -1132,3 +1133,25 @@ def test_constexpr_shape():
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
kernel[(1,)](x_tri)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
# -------------
# test if
# -------------
def test_if():
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret):
pid = tl.program_id(0)
cond = tl.load(Cond)
if pid % 2:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret)