More progress on WhileOp codegen
This commit is contained in:
18
rewrite-test/jit/while.py
Normal file
18
rewrite-test/jit/while.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
@triton.jit
|
||||
def atomic(lock):
|
||||
while tl.atomic_cas(lock, 0, 1) == 1:
|
||||
pass
|
||||
|
||||
@triton.jit
|
||||
def generic_while(lb, value):
|
||||
c = -1
|
||||
while c <= 0:
|
||||
c += 1
|
||||
|
||||
locks = torch.zeros(32, dtype=torch.int32, device='cuda')
|
||||
mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
|
||||
mod_atomic.dump()
|
Reference in New Issue
Block a user