[LANG] Added support for constexpr (#361)
@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # Compute Kernel
 # --------------------------
 
+from triton.language.core import constexpr
 import torch
 import triton
 import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
     y_ptr, # *Pointer* to second input vector
     output_ptr, # *Pointer* to output vector
     n_elements, # Size of the vector
-    **meta, # Optional meta-parameters for the kernel
+    BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
+    # NOTE: `constexpr` so it can be used as a shape value
 ):
-    BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
     # There are multiple 'program's processing different data. We identify which program
     # we are here
     pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -37,8 +38,8 @@ def add_kernel(
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
-    # Load x and y from DRAM, masking out any extar elements in case the input is not a
-    # multiple of the block size
+    # Load x and y from DRAM, masking out any extra elements in case
+    # the input is not a multiple of the block size
     x = tl.load(x_ptr + offsets, mask=mask)
     y = tl.load(y_ptr + offsets, mask=mask)
     output = x + y
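
For context, a minimal sketch of how the host-side launch pairs with the new kernel signature. The `add` wrapper, the grid lambda, and the BLOCK_SIZE=1024 value follow the surrounding tutorial and are illustrative, not part of this commit:

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The grid callable still receives the launch kwargs, so it can size
    # the grid from the same BLOCK_SIZE the kernel sees as a constexpr
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # BLOCK_SIZE now binds to the `tl.constexpr` parameter as an ordinary
    # keyword argument instead of being collected into **meta
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output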