# NOTE(review): BLOCKSIZE is never used below; this looks like an accidental
# IDE auto-import. Kept so that nothing outside this view breaks.
from tarfile import BLOCKSIZE

import torch

import triton
import triton.language as tl


# Element-wise vector addition: output[i] = x[i] + y[i] for i < n_elements.
# Each program instance handles one contiguous chunk of BLOCK_SIZE elements.
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # NOTE: `constexpr` so it can be used as a shape value
):
    # Several programs run this kernel, each on a different slice of the
    # data. The launch grid is 1D, so the program index lives on axis 0.
    pid = tl.program_id(axis=0)

    # Work out which elements this program owns. For a vector of length 256
    # and BLOCK_SIZE of 64, the programs access [0:64], [64:128], [128:192]
    # and [192:256] respectively.
    # `offsets` holds element *indices*; it only becomes a block of pointers
    # once it is added to one of the base pointers below.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Guard every memory operation against out-of-bounds accesses: the last
    # block may extend past the end of the vector when n_elements is not a
    # multiple of BLOCK_SIZE.
    mask = offsets < n_elements

    # Masked-off lanes read as 0.0 instead of touching memory.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    output = x + y

    # Write x + y back, again only for the in-bounds lanes.
    tl.store(output_ptr + offsets, output, mask=mask)


# --- Driver: build sample inputs and dump the kernel's IR ------------------

size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)

# To execute the kernel rather than only inspect it, launch it through the
# usual subscript syntax, e.g. `add_kernel[grid](x, y, z, size, BLOCK_SIZE=256)`.
#
# NOTE(review): `compile_to_ttir` is not defined in this file; presumably it
# comes from the Triton build in use and lowers the kernel to Triton IR for
# these arguments -- confirm against that build.
# NOTE(review): grid=(1,) means a single program. With BLOCK_SIZE=256 that
# program covers only offsets 0..255 of the 1024 elements. That is harmless
# if this call only produces IR, but a real launch would need
# ceil(size / BLOCK_SIZE) programs to fill all of `z`.
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
mod.dump()