examples

rewrite-test/jit/if-else/vecadd-cond.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import triton


@triton.jit
def if_else(lb, ub, value):
    if value > lb:
        a = 0.0
    else:
        a = 1.0
    c = a + a


@triton.jit
def only_if(lb, ub, value):
    a = -1.0
    if value > lb:
        a = 0.0
    c = a + a


@triton.jit
def only_if_invalid(lb, ub, value):
    # Invalid: `a` is assigned on only one branch, so the use below is
    # undefined whenever the condition is false. Compilation should fail.
    if value > lb:
        a = 0.0
    c = a + a


@triton.jit
def nested_if(lb, ub, value):
    if value > lb:
        if value < ub:
            a = 2.0
        else:
            a = 1.0
    else:
        a = 0.0
    c = a + a


mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,))
mod_if_else.dump()

mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_only_if.dump()

try:
    mod_only_if_invalid, ctx_only_if_invalid = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,))
    mod_only_if_invalid.dump()
except Exception:
    print('value error')

mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_nested_if.dump()
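
Not part of the commit: a plain-Python sketch of the scalar semantics these kernels encode, handy for sanity-checking the dumped TTIR against the arguments (2, 4, 3) used above.

    # Hypothetical reference, not part of the commit: mirrors nested_if.
    def nested_if_ref(lb, ub, value):
        if value > lb:
            a = 2.0 if value < ub else 1.0
        else:
            a = 0.0
        return a + a

    assert nested_if_ref(2, 4, 3) == 4.0  # 2 < 3 < 4 takes the inner then-branch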

rewrite-test/jit/vecadd-loop.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    K,  # Loop bound: the kernel accumulates over ceil(K / 32) iterations
    stride  # Element offset by which the input pointers advance each iteration
    # BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'programs' processing different data; we identify
    # which program we are here.
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, with a vector of length 256 and a block size of 64, the
    # programs would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a block of indices.
    block_start = pid * 256
    offsets = block_start + tl.arange(0, 256)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    x_ptrs = x_ptr + offsets
    y_ptrs = y_ptr + offsets
    output = tl.zeros((256,), dtype=tl.float32)
    for k in range(0, K, 32):
        x = tl.load(x_ptrs, mask=mask, other=0.0)
        y = tl.load(y_ptrs, mask=mask, other=0.0)
        output += x + y

        # Advance the input pointers for the next iteration.
        x_ptrs += stride
        y_ptrs += stride

    # Write the accumulated sums back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, 128, 8, grid=(1,))
mod.get_context()
mod.dump()
# print(mod)
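
Not part of the commit: the script above only lowers add_kernel to TTIR; nothing is launched. If it were run with the arguments above (K=128, stride=8), the loop would execute 4 times, each iteration loading a window of 256 elements shifted by `stride` and accumulating x + y. A rough torch reference for a single program, under those assumptions:

    # Hypothetical reference, not part of the commit: mirrors add_kernel for pid 0.
    def add_kernel_ref(x, y, n_elements, K, stride):
        out = torch.zeros(256, device=x.device)
        offsets = torch.arange(256, device=x.device)
        mask = offsets < n_elements  # the mask is fixed; only the pointers move
        for k in range(0, K, 32):
            idx = offsets + (k // 32) * stride
            out += torch.where(mask, x[idx] + y[idx], torch.zeros_like(out))
        return out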