[LANG] Added support for constexpr (#361)

This commit is contained in:
Philippe Tillet
2021-10-30 00:32:58 -07:00
committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions

View File

@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
# Compute Kernel
# --------------------------
from triton.language.core import constexpr
import torch
import triton
import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
BLOCK_SIZE = meta['BLOCK_SIZE'] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
@@ -37,8 +38,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extar elements in case the input is not a
# multiple of the block size
# Load x and y from DRAM, masking out any extra elements in case
# the input is not a multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y

View File

@@ -65,11 +65,11 @@ import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each

View File

@@ -182,17 +182,13 @@ def matmul_kernel(
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
**meta,
):
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# extract meta-parameters
BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
GROUP_SIZE_M = 8
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse