[PYTHON] Added support for tuples (#116)

Commit 3ab121dbdb, committed by Philippe Tillet (parent commit b5dcac484d).
@@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'):
|
||||
triton.testing.assert_allclose(z_ref, z_tri)
|
||||
|
||||
|
||||
# ---------------
# test tuples
# ---------------


# NOTE: fn is deliberately a *non-kernel* @triton.jit helper: it is meant to
# be called from inside other jit'd kernels (see test_tuples below) to check
# that a jit'd function can return a tuple of values.
@triton.jit
def fn(a, b):
    # Return a 3-tuple (sum, difference, product) in one statement so the
    # caller can unpack multiple results from a single helper call.
    return a + b, \
        a - b, \
        a * b
|
||||
|
||||
|
||||
def test_tuples(device='cuda'):
    """Check that a @triton.jit function can return a tuple.

    Runs two kernels that must produce identical results: ``with_fn`` obtains
    (a, b, c) by unpacking the tuple returned from the jit'd helper ``fn``,
    while ``without_fn`` computes the same three values inline. Both store
    their results, which are compared against the torch reference.

    :param device: device on which tensors are allocated and kernels run
        (defaults to 'cuda', matching the other tests in this file).
    """

    @triton.jit
    def with_fn(X, Y, A, B, C):
        # Unpack the 3-tuple returned by the jit'd helper `fn`.
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = fn(x, y)
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    @triton.jit
    def without_fn(X, Y, A, B, C):
        # Same computation as with_fn, but without the helper call —
        # the tuple is built and unpacked directly.
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = x + y, x - y, x * y
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    # Single-element inputs are enough to exercise tuple return/unpacking.
    x = torch.tensor([1.3], device=device, dtype=torch.float32)
    y = torch.tensor([1.9], device=device, dtype=torch.float32)
    a_tri = torch.tensor([0], device=device, dtype=torch.float32)
    b_tri = torch.tensor([0], device=device, dtype=torch.float32)
    c_tri = torch.tensor([0], device=device, dtype=torch.float32)
    # Both kernel variants must agree with the torch reference.
    for kernel in [with_fn, without_fn]:
        kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
        a_ref, b_ref, c_ref = x + y, x - y, x * y
        assert a_tri == a_ref
        assert b_tri == b_ref
        assert c_tri == c_ref
|
||||
|
||||
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
|
Reference in New Issue
Block a user