Files
triton/rewrite-test/jit/while.py
2022-04-10 16:28:45 +08:00

39 lines
790 B
Python

import triton
import triton.language as tl
import torch
# Test kernel whose whole body is a `while` loop driven by an atomic op.
# The loop keeps calling `tl.atomic_cas(lock, 0, 1)` for as long as the call
# compares equal to 1; the loop body itself is empty.
# NOTE(review): presumably a spin-lock acquire (compare `*lock` with 0, swap in
# 1, return the previous value) -- confirm against the `tl.atomic_cas` docs.
# Nothing in this kernel ever writes the lock back, so there is no release.
@triton.jit
def atomic(lock):
    while tl.atomic_cas(lock, 0, 1) == 1:
        pass
# Test kernel containing a plain counting `while` loop.
# `lb` and `value` are accepted but never read, and nothing is stored or
# returned: the only state is the local counter `c`.
@triton.jit
def generic_while(lb, value):
    c = -1
    # Read as ordinary Python arithmetic this runs twice: -1 -> 0 -> 1.
    while c <= 0:
        c += 1
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
# mod_atomic.dump()
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
# mod_generic_while.dump()
# Test kernel that mixes an if/else with `for` and `while` loops.
# `X` is accepted but never read, and nothing is stored or returned: `a` is a
# purely local accumulator.
# NOTE(review): the listing this was taken from had lost its indentation; the
# nesting below is the most plausible reading -- confirm against the original.
@triton.jit
def nested_cf(X, lb, ub, Z):
    a = 0.0
    if lb < ub:
        # The loop variable `z` is not used in the body; the loop only repeats
        # the addition once per value in range(0, Z).
        for z in range(0, Z):
            a += 2.0
    else:
        # NOTE(review): `a` is still 0.0 when this branch is entered, so read as
        # ordinary arithmetic, doubling it never reaches 1.2. The only use of
        # this kernel in this file is a `compile_to_ttir` call -- confirm the
        # kernel is never actually launched with lb >= ub.
        while a < 1.2:
            a *= 2.0
        # Step of 2: iterates over every other value in range(0, Z).
        for _ in range(0, Z, 2):
            a *= -3.3
    a -= 1.0
# Compile `nested_cf` with sample arguments on a one-element grid, check the
# resulting module, then print it.
# NOTE(review): `compile_to_ttir` presumably lowers the kernel to Triton IR and
# returns (module, context); only the first element is used here -- confirm.
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
# Explicit check rather than an `assert` statement: asserts are removed under
# `python -O`, which would silently skip verification. AssertionError is kept
# so a failure surfaces with the same exception type and message (the textual
# form of the module) as before.
if not mod.verify():
    raise AssertionError(mod.str())
mod.dump()