[PYTHON] Added support for tuples (#116)

This commit is contained in:
Philippe Tillet
2021-05-20 14:12:04 -04:00
committed by Philippe Tillet
parent b5dcac484d
commit 3ab121dbdb
2 changed files with 60 additions and 9 deletions

View File

@@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'):
triton.testing.assert_allclose(z_ref, z_tri)
# ---------------
# test tuples
# ---------------
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
def test_tuples():
device = 'cuda'
@triton.jit
def with_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = fn(x, y)
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
@triton.jit
def without_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = x + y, x - y, x * y
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
x = torch.tensor([1.3], device=device, dtype=torch.float32)
y = torch.tensor([1.9], device=device, dtype=torch.float32)
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
for kernel in [with_fn, without_fn]:
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
a_ref, b_ref, c_ref = x + y, x - y, x * y
assert a_tri == a_ref
assert b_tri == b_ref
assert c_tri == c_ref
# ---------------
# test atomics
# ---------------

View File

@@ -11,7 +11,7 @@ import triton._C.libtriton.triton as _triton
import triton
import sys
import textwrap
from abc import ABC, abstractmethod
import collections
class CodeGenerator(ast.NodeVisitor):
@@ -134,15 +134,20 @@ class CodeGenerator(ast.NodeVisitor):
return node.arg
def visit_Assign(self, node):
names = []
_names = []
for target in node.targets:
names += [self.visit(target)]
assert len(names) == 1
name = names[0]
value = self.visit(node.value)
if not isinstance(value, triton.language.block):
value = triton.language._to_ir(value, self.builder)
self.set_value(names[0], value)
_names += [self.visit(target)]
assert len(_names) == 1
names = _names[0]
values = self.visit(node.value)
if not isinstance(names, tuple):
names = [names]
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
if not isinstance(value, triton.language.block):
value = triton.language._to_ir(value, self.builder)
self.set_value(name, value)
def visit_AugAssign(self, node):
name = node.target.id