Allow JITFunction to return multiple results

This commit is contained in:
Yan Da
2022-04-15 15:38:19 +08:00
parent 1c52bd587d
commit 9e304cf79d
2 changed files with 40 additions and 7 deletions

View File

@@ -0,0 +1,27 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
@triton.jit
def foo(a, b):
max, min = maxmin(a, b)
return max, min
@triton.jit
def maxmin(a, b):
max = tl.maximum(a, b)
min = tl.minimum(a, b)
return max, min
mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
assert mod.verify()
mod.dump()
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
assert mod.verify()
mod.dump()