Allow JITFunction to return multiple results

This commit is contained in:
Yan Da
2022-04-15 15:38:19 +08:00
parent 1c52bd587d
commit 9e304cf79d
2 changed files with 40 additions and 7 deletions

View File

@@ -159,14 +159,17 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Return(self, node):
    """Emit a return op for ``node`` and report the type(s) being returned.

    ``node.value`` is visited first; what happens next depends on the result:

    * ``None``: a ``ret`` with no operands is emitted and ``None`` is
      returned.
    * ``tuple``: every element is converted with ``_to_tensor``, a single
      ``ret`` carrying all of their handles is emitted, and a tuple with
      one type per element is returned.
    * anything else: the value is converted with ``_to_tensor``, a ``ret``
      with that tensor's handle is emitted, and its type is returned.
    """
    ret_value = self.visit(node.value)
    if ret_value is None:
        self.builder.ret([])
        return None
    if isinstance(ret_value, tuple):
        ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value]
        self.builder.ret([v.handle for v in ret_values])
        return tuple(v.type for v in ret_values)
    ret = triton.language.core._to_tensor(ret_value, self.builder)
    # Return the handle of the converted tensor, not of the raw value: the
    # raw value is not guaranteed to have a ``.handle`` before conversion,
    # and the tuple branch above already uses the converted handles.
    self.builder.ret([ret.handle])
    return ret.type
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
@@ -219,7 +222,7 @@ class CodeGenerator(ast.NodeVisitor):
else:
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
self.prototype.ret_types = list(self.last_ret_type)
fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
@@ -661,7 +664,10 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
raise RuntimeError("Not implemented")
results = []
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)

View File

@@ -0,0 +1,27 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
# Caller kernel: unpacks the two results of `maxmin` and returns both of
# them, so it receives and produces multiple values.
@triton.jit
def foo(a, b):
    larger, smaller = maxmin(a, b)
    return larger, smaller
# Callee kernel: returns the pair (tl.maximum(a, b), tl.minimum(a, b)).
@triton.jit
def maxmin(a, b):
    upper = tl.maximum(a, b)
    lower = tl.minimum(a, b)
    return upper, lower
# Compile `foo` via `compile_to_ttir`, passing 3 and 4 for `a` and `b` and a
# one-element grid; the call yields the module and its context.
mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
# The freshly generated module must pass verification before any pass runs.
assert mod.verify()
mod.dump()
# Build a pass manager on the same context holding only the inliner pass,
# then run it over the module.
# NOTE(review): presumably this inlines `maxmin` into `foo` - confirm
# against the dumped output.
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
# The module must still verify after the pass manager has run.
assert mod.verify()
mod.dump()