[LANG] Added support for constexpr (#361)
@@ -12,6 +12,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # Compute Kernel
 # --------------------------

+from triton.language.core import constexpr
 import torch
 import triton
 import triton.language as tl
@@ -23,9 +24,9 @@ def add_kernel(
     y_ptr,  # *Pointer* to second input vector
     output_ptr,  # *Pointer* to output vector
     n_elements,  # Size of the vector
-    **meta,  # Optional meta-parameters for the kernel
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
+    # NOTE: `constexpr` so it can be used as a shape value
 ):
-    BLOCK_SIZE = meta['BLOCK_SIZE']  # How many inputs each program should process
     # There are multiple 'program's processing different data. We identify which program
     # we are here
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
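With BLOCK_SIZE declared as tl.constexpr, the host passes it as a keyword argument at launch time rather than routing it through **meta. A minimal host-side launch sketch, assuming the add_kernel above (the wrapper name `add` and the value 1024 are illustrative, not part of this commit):

import torch
import triton

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The grid callable still receives the meta-parameters, constexpr ones included
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output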
@@ -37,8 +38,8 @@ def add_kernel(
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     # Create a mask to guard memory operations against out-of-bounds accesses
     mask = offsets < n_elements
-    # Load x and y from DRAM, masking out any extar elements in case the input is not a
-    # multiple of the block size
+    # Load x and y from DRAM, masking out any extra elements in case
+    # the input is not a multiple of the block size
     x = tl.load(x_ptr + offsets, mask=mask)
     y = tl.load(y_ptr + offsets, mask=mask)
     output = x + y
@@ -65,11 +65,11 @@ import triton.language as tl

 @triton.jit
 def softmax_kernel(
-    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
+    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
+    BLOCK_SIZE: tl.constexpr
 ):
     # The rows of the softmax are independent, so we parallelize across those
     row_idx = tl.program_id(0)
-    BLOCK_SIZE = meta['BLOCK_SIZE']
     # The stride represents how much we need to increase the pointer to advance 1 row
     row_start_ptr = input_ptr + row_idx * input_row_stride
     # The block size is the next power of two greater than n_cols, so we can fit each
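The softmax launcher changes the same way: BLOCK_SIZE is supplied directly as a constexpr keyword argument. A hedged sketch assuming the kernel signature above (the wrapper name `softmax` is illustrative):

import torch
import triton

def softmax(x: torch.Tensor):
    n_rows, n_cols = x.shape
    # BLOCK_SIZE is the next power of two >= n_cols so one program covers a whole row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel[(n_rows,)](
        y, x, x.stride(0), y.stride(0), n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y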
@@ -182,17 +182,13 @@ def matmul_kernel(
     stride_bk, stride_bn,
     stride_cm, stride_cn,
     # Meta-parameters
-    **meta,
-):
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """
-    # extract meta-parameters
-    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
-    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
-    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
-    GROUP_SIZE_M = 8
-
     # -----------------------------------------------------------
     # Map program ids `pid` to the block of C it should compute.
     # This is done in a grouped ordering to promote L2 data reuse
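The matmul meta-parameters likewise move from **meta to explicit constexpr arguments. A hedged launch sketch, assuming the full matmul_kernel signature from the tutorial (the pointer, shape, and leading stride arguments are not shown in this hunk) and illustrative tile sizes; in the tutorial these values come from @triton.autotune configs:

import torch
import triton

def matmul(a: torch.Tensor, b: torch.Tensor):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
        ACTIVATION=None,  # no activation fused in this sketch
    )
    return c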