Files
triton/rewrite-test/jit/if-else/if-else.py
2022-04-10 15:49:09 +08:00

52 lines
913 B
Python

import triton
@triton.jit
def if_else(lb, ub, value):
    # Test kernel for if/else lowering: `a` is assigned in BOTH arms of the
    # conditional and then read after it.
    # `ub` is accepted but not used by this kernel.
    if value > lb:
        a = 0.0
    else:
        a = 1.0
    # Read `a` after the conditional so the value chosen by either arm is
    # actually consumed. `c` itself is never stored or returned.
    c = a + a
@triton.jit
def only_if(lb, ub, value):
    # Test kernel for an `if` without `else`: `a` has a value before the
    # conditional and is optionally overwritten inside it.
    # `ub` is accepted but not used by this kernel.
    a = -1.0
    if value > lb:
        a = 0.0
    # `a` is defined on every path here (initial value or the overwrite).
    # `c` itself is never stored or returned.
    c = a + a
@triton.jit
def only_if_invalid(lb, ub, value):
    # Negative test kernel: `a` is assigned ONLY inside the `if`, with no prior
    # value and no `else`, and is then read unconditionally.
    # `ub` is accepted but not used by this kernel.
    if value > lb:
        a = 0.0
    # Intentionally reads `a` on a path where it may never have been assigned.
    # NOTE(review): the driver in this file wraps compilation of this kernel in
    # try/except, so it is presumably expected to be rejected - confirm.
    c = a + a
@triton.jit
def nested_if(lb, ub, value):
    # Test kernel for nested conditionals: an if/else nested inside the true
    # arm of an outer if/else. `a` is assigned on all three paths.
    if value > lb:
        if value < ub:
            # lb < value < ub
            a = 2.0
        else:
            # value > lb but not below ub
            a = 1.0
    else:
        # value <= lb
        a = 0.0
    # Read `a` after both levels of branching so the selected value is
    # consumed. `c` itself is never stored or returned.
    c = a + a
# Driver: compile each kernel defined above with compile_to_ttir and dump the
# resulting module. Every call passes the positional arguments 2, 4, 3
# (presumably lb, ub, value in declaration order - confirm) and grid=(1,).
mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,))
mod_if_else.dump()

mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_only_if.dump()

# only_if_invalid reads a variable that is not assigned on every path, so a
# failure here is tolerated and reported instead of aborting the script.
# The context is bound to its own name so it cannot clobber `ctx_only_if`.
try:
    mod_only_if_invalid, ctx_only_if_invalid = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,))
    mod_only_if_invalid.dump()
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt and SystemExit still propagate.
    print('value error')

mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_nested_if.dump()
# Also print the textual form returned by str() for the nested case.
print(mod_nested_if.str())