[PYTHON] Added support for tuples (#116)

This commit is contained in:
Philippe Tillet
2021-05-20 14:12:04 -04:00
committed by Philippe Tillet
parent b5dcac484d
commit 3ab121dbdb
2 changed files with 60 additions and 9 deletions

View File

@@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'):
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test tuples
# ---------------
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
def test_tuples():
device = 'cuda'
@triton.jit
def with_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = fn(x, y)
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
@triton.jit
def without_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = x + y, x - y, x * y
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
x = torch.tensor([1.3], device=device, dtype=torch.float32)
y = torch.tensor([1.9], device=device, dtype=torch.float32)
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
for kernel in [with_fn, without_fn]:
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
a_ref, b_ref, c_ref = x + y, x - y, x * y
assert a_tri == a_ref
assert b_tri == b_ref
assert c_tri == c_ref
# ---------------
# test atomics
# ---------------