[LANG] Added support for constexpr (#361)

This commit is contained in:
Philippe Tillet
2021-10-30 00:32:58 -07:00
committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions

View File

@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel
# --------------------------
from triton.language.core import constexpr
import torch
import triton
import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -37,8 +38,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# multiple of the block size
# Load x and y from DRAM, masking out any extra elements in case
# the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y