[PYTHON] Added support for tuples (#116)
This commit is contained in:
committed by
Philippe Tillet
parent
b5dcac484d
commit
3ab121dbdb
@@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'):
|
||||
triton.testing.assert_allclose(z_ref, z_tri)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test tuples
|
||||
# ---------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fn(a, b):
|
||||
return a + b, \
|
||||
a - b, \
|
||||
a * b
|
||||
|
||||
|
||||
def test_tuples():
|
||||
device = 'cuda'
|
||||
|
||||
@triton.jit
|
||||
def with_fn(X, Y, A, B, C):
|
||||
x = tl.load(X)
|
||||
y = tl.load(Y)
|
||||
a, b, c = fn(x, y)
|
||||
tl.store(A, a)
|
||||
tl.store(B, b)
|
||||
tl.store(C, c)
|
||||
|
||||
@triton.jit
|
||||
def without_fn(X, Y, A, B, C):
|
||||
x = tl.load(X)
|
||||
y = tl.load(Y)
|
||||
a, b, c = x + y, x - y, x * y
|
||||
tl.store(A, a)
|
||||
tl.store(B, b)
|
||||
tl.store(C, c)
|
||||
|
||||
x = torch.tensor([1.3], device=device, dtype=torch.float32)
|
||||
y = torch.tensor([1.9], device=device, dtype=torch.float32)
|
||||
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||
for kernel in [with_fn, without_fn]:
|
||||
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
|
||||
a_ref, b_ref, c_ref = x + y, x - y, x * y
|
||||
assert a_tri == a_ref
|
||||
assert b_tri == b_ref
|
||||
assert c_tri == c_ref
|
||||
|
||||
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
|
@@ -11,7 +11,7 @@ import triton._C.libtriton.triton as _triton
|
||||
import triton
|
||||
import sys
|
||||
import textwrap
|
||||
from abc import ABC, abstractmethod
|
||||
import collections
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -134,15 +134,20 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return node.arg
|
||||
|
||||
def visit_Assign(self, node):
|
||||
names = []
|
||||
_names = []
|
||||
for target in node.targets:
|
||||
names += [self.visit(target)]
|
||||
assert len(names) == 1
|
||||
name = names[0]
|
||||
value = self.visit(node.value)
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language._to_ir(value, self.builder)
|
||||
self.set_value(names[0], value)
|
||||
_names += [self.visit(target)]
|
||||
assert len(_names) == 1
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(names, tuple):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
for name, value in zip(names, values):
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language._to_ir(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
name = node.target.id
|
||||
|
Reference in New Issue
Block a user