More progress on WhileOp codegen

This commit is contained in:
Yan Da
2022-04-05 15:55:48 +08:00
parent 76d9249724
commit c7ad928e60
5 changed files with 145 additions and 45 deletions

18
rewrite-test/jit/while.py Normal file
View File

@@ -0,0 +1,18 @@
import triton
import triton.language as tl
import torch
@triton.jit
def atomic(lock):
while tl.atomic_cas(lock, 0, 1) == 1:
pass
@triton.jit
def generic_while(lb, value):
c = -1
while c <= 0:
c += 1
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
mod_atomic.dump()