[LANG] Added support for constexpr (#361)
This commit is contained in:
@@ -65,11 +65,11 @@ import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta
|
||||
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
# The rows of the softmax are independent, so we parallelize across those
|
||||
row_idx = tl.program_id(0)
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE']
|
||||
# The stride represents how much we need to increase the pointer to advance 1 row
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
# The block size is the next power of two greater than n_cols, so we can fit each
|
||||
|
Reference in New Issue
Block a user