Allow JITFunction to return multiple results
This commit is contained in:
@@ -159,14 +159,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
ret_value = self.visit(node.value)
|
||||
if ret_value is None:
|
||||
return self.builder.ret([])
|
||||
if isinstance(ret_value, list):
|
||||
assert False, "returing a tuple is not supported"
|
||||
self.builder.ret([])
|
||||
return None
|
||||
if isinstance(ret_value, tuple):
|
||||
ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value]
|
||||
ret_types = [v.type for v in ret_values]
|
||||
self.builder.ret([v.handle for v in ret_values])
|
||||
return tuple(ret_types)
|
||||
else:
|
||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||
ret_type = ret.type
|
||||
self.builder.ret([ret_value.handle])
|
||||
return ret_type
|
||||
return ret.type
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
@@ -219,7 +222,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
|
||||
self.prototype.ret_types = list(self.last_ret_type)
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
@@ -661,7 +664,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
|
||||
else:
|
||||
# should return a tuple of tl.tensor
|
||||
raise RuntimeError("Not implemented")
|
||||
results = []
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
|
Reference in New Issue
Block a user