diff --git a/rewrite-test/jit/if-else/vecadd-cond.py b/rewrite-test/jit/if-else/vecadd-cond.py new file mode 100644 index 000000000..fd39098b4 --- /dev/null +++ b/rewrite-test/jit/if-else/vecadd-cond.py @@ -0,0 +1,49 @@ +import triton + +@triton.jit +def if_else(lb, ub, value): + if value > lb: + a = 0.0 + else: + a = 1.0 + c = a + a + +@triton.jit +def only_if(lb, ub, value): + a = -1.0 + if value > lb: + a = 0.0 + c = a + a + +@triton.jit +def only_if_invalid(lb, ub, value): + if value > lb: + a = 0.0 + c = a + a + +@triton.jit +def nested_if(lb, ub, value): + if value > lb: + if value < ub: + a = 2.0 + else: + a = 1.0 + else: + a = 0.0 + c = a + a + + +mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,)) +mod_if_else.dump() + +mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,)) +mod_only_if.dump() + +try: + mod_only_if_invalid, ctx_only_if = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,)) + mod_only_if_invalid.dump() +except: + print('value error') + +mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,)) +mod_nested_if.dump() diff --git a/rewrite-test/jit/vecadd-loop.py b/rewrite-test/jit/vecadd-loop.py new file mode 100644 index 000000000..1e02c27e0 --- /dev/null +++ b/rewrite-test/jit/vecadd-loop.py @@ -0,0 +1,52 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector + y_ptr, # *Pointer* to second input vector + output_ptr, # *Pointer* to output vector + n_elements, # Size of the vector + K, + stride + # BLOCK_SIZE: tl.constexpr, # Number of elements each program should process + # # NOTE: `constexpr` so it can be used as a shape value +): + # There are multiple 'program's processing different data. We identify which program + # we are here + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 + # This program will process inputs that are offset from the initial data. + # for instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers + block_start = pid * 256 + offsets = block_start + tl.arange(0, 256) + # Create a mask to guard memory operations against out-of-bounds accesses + mask = offsets < n_elements + + x_ptrs = x_ptr + offsets + y_ptrs = y_ptr + offsets + output = tl.zeros((256,), dtype=tl.float32) + for k in range(0, K, 32): + x = tl.load(x_ptrs, mask=mask, other=0.0) + y = tl.load(y_ptrs, mask=mask, other=0.0) + output += x + y + + x_ptrs += stride + y_ptrs += stride + + # Write x + y back to DRAM + tl.store(output_ptr + offsets, output, mask=mask) + +size = 1024 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') +z = torch.empty_like(x) +# add_kernel[(1,)](x, y, z, size, 256) +# print(add_kernel[(1,)].kernel.compile_to_ttir()) +mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, 128, 8, grid=(1,)) +mod.get_context() +mod.dump() +# print(mod)