Allow JITFunction to return multiple results
This commit is contained in:
27
rewrite-test/jit/multi-return.py
Normal file
27
rewrite-test/jit/multi-return.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def foo(a, b):
|
||||
max, min = maxmin(a, b)
|
||||
return max, min
|
||||
|
||||
@triton.jit
|
||||
def maxmin(a, b):
|
||||
max = tl.maximum(a, b)
|
||||
min = tl.minimum(a, b)
|
||||
return max, min
|
||||
|
||||
|
||||
mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
|
||||
assert mod.verify()
|
||||
mod.dump()
|
||||
|
||||
|
||||
pm = _triton.ir.pass_manager(ctx)
|
||||
pm.add_inliner_pass()
|
||||
pm.run(mod)
|
||||
assert mod.verify()
|
||||
mod.dump()
|
Reference in New Issue
Block a user