[LANG] Added support for constexpr (#361)

This commit is contained in:
Philippe Tillet
2021-10-30 00:32:58 -07:00
committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions

View File

@@ -65,11 +65,11 @@ import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr
):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
BLOCK_SIZE = meta['BLOCK_SIZE']
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each