diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 59e26a3cc..c231666bf 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -159,14 +159,17 @@ class CodeGenerator(ast.NodeVisitor): def visit_Return(self, node): ret_value = self.visit(node.value) if ret_value is None: - return self.builder.ret([]) - if isinstance(ret_value, list): - assert False, "returing a tuple is not supported" + self.builder.ret([]) + return None + if isinstance(ret_value, tuple): + ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + return tuple(ret_types) else: ret = triton.language.core._to_tensor(ret_value, self.builder) - ret_type = ret.type self.builder.ret([ret_value.handle]) - return ret_type + return ret.type def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) @@ -219,7 +222,7 @@ class CodeGenerator(ast.NodeVisitor): else: # update return type if isinstance(self.last_ret_type, tuple): - self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type] + self.prototype.ret_types = list(self.last_ret_type) fn.reset_type(self.prototype.to_ir(self.builder)) else: self.prototype.ret_types = [self.last_ret_type] @@ -661,7 +664,10 @@ class CodeGenerator(ast.NodeVisitor): return triton.language.tensor(call_op.get_result(0), callee_ret_type) else: # should return a tuple of tl.tensor - raise RuntimeError("Not implemented") + results = [] + for i in range(call_op.get_num_results()): + results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) diff --git a/rewrite-test/jit/multi-return.py b/rewrite-test/jit/multi-return.py new file mode 100644 index 000000000..00588bf0b --- /dev/null +++ b/rewrite-test/jit/multi-return.py @@ -0,0 +1,27 @@ +import triton +import triton.language as tl +import triton._C.libtriton.triton as _triton + + +@triton.jit +def foo(a, b): + max, min = maxmin(a, b) + return max, min + +@triton.jit +def maxmin(a, b): + max = tl.maximum(a, b) + min = tl.minimum(a, b) + return max, min + + +mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,)) +assert mod.verify() +mod.dump() + + +pm = _triton.ir.pass_manager(ctx) +pm.add_inliner_pass() +pm.run(mod) +assert mod.verify() +mod.dump()