[PYTHON] Now providing triton.next_power_of_2 (#273)

Philippe Tillet
2021-09-10 11:05:44 -07:00
committed by GitHub
parent 43723ccb95
commit ac10551d55
2 changed files with 11 additions and 15 deletions
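
The change promotes the helper to the top-level `triton` namespace, next to the existing `triton.cdiv` visible in the first hunk below. A quick usage sketch, assuming a Triton build that includes this commit:

import triton

print(triton.cdiv(10, 3))            # 4: ceiling division, (x + y - 1) // y
print(triton.next_power_of_2(1000))  # 1024: smallest power of 2 >= 1000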


@@ -833,6 +833,16 @@ def jit(fn):
 def cdiv(x, y):
     return (x + y - 1) // y
+def next_power_of_2(n):
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n += 1
+    return n
 ######
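
The new helper is the classic bit-twiddling idiom: decrement n, then OR-cascade the highest set bit into every lower position, so the final increment rolls the all-ones pattern over to the next power of two. Below is an annotated standalone copy; the caveat on the input range is an observation, not something stated in the diff (the shifts stop at 16 bits, so the result is only guaranteed for 0 < n < 2**32):

def next_power_of_2(n):
    """Smallest power of 2 >= n, assuming 0 < n < 2**32."""
    n -= 1        # so that exact powers of two map to themselves
    n |= n >> 1   # copy the highest set bit 1 position down
    n |= n >> 2   # the top 2 bits are now set; copy them 2 positions down
    n |= n >> 4   # ... 4 positions
    n |= n >> 8   # ... 8 positions
    n |= n >> 16  # every bit at or below the original MSB is now set
    n += 1        # all-ones + 1 is a single set bit: the next power of two
    return n

assert [next_power_of_2(n) for n in (1, 2, 3, 5, 1000)] == [1, 2, 4, 8, 1024]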


@@ -72,7 +72,6 @@ def softmax_kernel(
     BLOCK_SIZE = meta['BLOCK_SIZE']
     # The stride represents how much we need to increase the pointer to advance 1 row
     row_start_ptr = input_ptr + row_idx * input_row_stride
     # The block size is the next power of two greater than n_cols, so we can fit each
     # row in a single block
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -94,23 +93,10 @@ def softmax_kernel(
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
-def next_power_of_2(n):
-    """Return the smallest power of 2 greater than or equal to n"""
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n += 1
-    return n
 def softmax(x):
     n_rows, n_cols = x.shape
     # The block size is the smallest power of two greater than the number of columns in `x`
-    BLOCK_SIZE = next_power_of_2(n_cols)
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
     # Another trick we can use is to ask the compiler to use more threads per row by
     # increasing the number of warps (`num_warps`) over which each row is distributed.
     # You will see in the next tutorial how to auto-tune this value in a more natural
     # way so you don't have to come up with manual heuristics yourself
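
To make the `num_warps` remark concrete, here is a minimal sketch of how the enqueue helper can pick both launch parameters; the thresholds 2048 and 4096 are illustrative assumptions, not values taken from this commit:

import triton

def launch_params(n_cols):
    # One Triton program handles one full row, so the block must span every column
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Hypothetical heuristic: distribute wider rows over more warps
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    return BLOCK_SIZE, num_warps

print(launch_params(1000))  # (1024, 4)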