Allow JITFunction to return multiple results
This commit is contained in:
@@ -159,14 +159,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
ret_value = self.visit(node.value)
|
||||
if ret_value is None:
|
||||
return self.builder.ret([])
|
||||
if isinstance(ret_value, list):
|
||||
assert False, "returing a tuple is not supported"
|
||||
self.builder.ret([])
|
||||
return None
|
||||
if isinstance(ret_value, tuple):
|
||||
ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value]
|
||||
ret_types = [v.type for v in ret_values]
|
||||
self.builder.ret([v.handle for v in ret_values])
|
||||
return tuple(ret_types)
|
||||
else:
|
||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||
ret_type = ret.type
|
||||
self.builder.ret([ret_value.handle])
|
||||
return ret_type
|
||||
return ret.type
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
@@ -219,7 +222,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
|
||||
self.prototype.ret_types = list(self.last_ret_type)
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
@@ -661,7 +664,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
|
||||
else:
|
||||
# should return a tuple of tl.tensor
|
||||
raise RuntimeError("Not implemented")
|
||||
results = []
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
|
27
rewrite-test/jit/multi-return.py
Normal file
27
rewrite-test/jit/multi-return.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def foo(a, b):
|
||||
max, min = maxmin(a, b)
|
||||
return max, min
|
||||
|
||||
@triton.jit
|
||||
def maxmin(a, b):
|
||||
max = tl.maximum(a, b)
|
||||
min = tl.minimum(a, b)
|
||||
return max, min
|
||||
|
||||
|
||||
mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
|
||||
assert mod.verify()
|
||||
mod.dump()
|
||||
|
||||
|
||||
pm = _triton.ir.pass_manager(ctx)
|
||||
pm.add_inliner_pass()
|
||||
pm.run(mod)
|
||||
assert mod.verify()
|
||||
mod.dump()
|
Reference in New Issue
Block a user