[PYTHON] Added support for tuples (#116)
This commit is contained in:
committed by
Philippe Tillet
parent
b5dcac484d
commit
3ab121dbdb
@@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'):
|
|||||||
triton.testing.assert_allclose(z_ref, z_tri)
|
triton.testing.assert_allclose(z_ref, z_tri)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------
|
||||||
|
# test tuples
|
||||||
|
# ---------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def fn(a, b):
|
||||||
|
return a + b, \
|
||||||
|
a - b, \
|
||||||
|
a * b
|
||||||
|
|
||||||
|
|
||||||
|
def test_tuples():
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def with_fn(X, Y, A, B, C):
|
||||||
|
x = tl.load(X)
|
||||||
|
y = tl.load(Y)
|
||||||
|
a, b, c = fn(x, y)
|
||||||
|
tl.store(A, a)
|
||||||
|
tl.store(B, b)
|
||||||
|
tl.store(C, c)
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def without_fn(X, Y, A, B, C):
|
||||||
|
x = tl.load(X)
|
||||||
|
y = tl.load(Y)
|
||||||
|
a, b, c = x + y, x - y, x * y
|
||||||
|
tl.store(A, a)
|
||||||
|
tl.store(B, b)
|
||||||
|
tl.store(C, c)
|
||||||
|
|
||||||
|
x = torch.tensor([1.3], device=device, dtype=torch.float32)
|
||||||
|
y = torch.tensor([1.9], device=device, dtype=torch.float32)
|
||||||
|
a_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||||
|
b_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||||
|
c_tri = torch.tensor([0], device=device, dtype=torch.float32)
|
||||||
|
for kernel in [with_fn, without_fn]:
|
||||||
|
kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
|
||||||
|
a_ref, b_ref, c_ref = x + y, x - y, x * y
|
||||||
|
assert a_tri == a_ref
|
||||||
|
assert b_tri == b_ref
|
||||||
|
assert c_tri == c_ref
|
||||||
|
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test atomics
|
# test atomics
|
||||||
# ---------------
|
# ---------------
|
||||||
|
@@ -11,7 +11,7 @@ import triton._C.libtriton.triton as _triton
|
|||||||
import triton
|
import triton
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from abc import ABC, abstractmethod
|
import collections
|
||||||
|
|
||||||
|
|
||||||
class CodeGenerator(ast.NodeVisitor):
|
class CodeGenerator(ast.NodeVisitor):
|
||||||
@@ -134,15 +134,20 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
return node.arg
|
return node.arg
|
||||||
|
|
||||||
def visit_Assign(self, node):
|
def visit_Assign(self, node):
|
||||||
names = []
|
_names = []
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
names += [self.visit(target)]
|
_names += [self.visit(target)]
|
||||||
assert len(names) == 1
|
assert len(_names) == 1
|
||||||
name = names[0]
|
names = _names[0]
|
||||||
value = self.visit(node.value)
|
values = self.visit(node.value)
|
||||||
if not isinstance(value, triton.language.block):
|
if not isinstance(names, tuple):
|
||||||
value = triton.language._to_ir(value, self.builder)
|
names = [names]
|
||||||
self.set_value(names[0], value)
|
if not isinstance(values, tuple):
|
||||||
|
values = [values]
|
||||||
|
for name, value in zip(names, values):
|
||||||
|
if not isinstance(value, triton.language.block):
|
||||||
|
value = triton.language._to_ir(value, self.builder)
|
||||||
|
self.set_value(name, value)
|
||||||
|
|
||||||
def visit_AugAssign(self, node):
|
def visit_AugAssign(self, node):
|
||||||
name = node.target.id
|
name = node.target.id
|
||||||
|
Reference in New Issue
Block a user