From 3ab121dbdb98ba80ea1925e255173dca4b6ce58f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 20 May 2021 14:12:04 -0400 Subject: [PATCH] [PYTHON] Added support for tuples (#116) --- python/test/test_language.py | 46 ++++++++++++++++++++++++++++++++++++ python/triton/code_gen.py | 23 +++++++++++------- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/python/test/test_language.py b/python/test/test_language.py index 4b447af78..4ef7302b4 100644 --- a/python/test/test_language.py +++ b/python/test/test_language.py @@ -189,6 +189,52 @@ def test_index1d(expr, device='cuda'): triton.testing.assert_allclose(z_ref, z_tri) +# --------------- +# test tuples +# --------------- + + +@triton.jit +def fn(a, b): + return a + b, \ + a - b, \ + a * b + + +def test_tuples(): + device = 'cuda' + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + # --------------- # test atomics # --------------- diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 014104b99..ce4e5e12b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -11,7 +11,7 @@ import triton._C.libtriton.triton as _triton import triton import sys import textwrap -from abc import ABC, abstractmethod +import collections class CodeGenerator(ast.NodeVisitor): @@ -134,15 +134,20 @@ class CodeGenerator(ast.NodeVisitor): return node.arg def visit_Assign(self, node): - names = [] + _names = [] for target in node.targets: - names += [self.visit(target)] - assert len(names) == 1 - name = names[0] - value = self.visit(node.value) - if not isinstance(value, triton.language.block): - value = triton.language._to_ir(value, self.builder) - self.set_value(names[0], value) + _names += [self.visit(target)] + assert len(_names) == 1 + names = _names[0] + values = self.visit(node.value) + if not isinstance(names, tuple): + names = [names] + if not isinstance(values, tuple): + values = [values] + for name, value in zip(names, values): + if not isinstance(value, triton.language.block): + value = triton.language._to_ir(value, self.builder) + self.set_value(name, value) def visit_AugAssign(self, node): name = node.target.id