diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 61f67ffd9..33d1a53af 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -833,6 +833,16 @@ def jit(fn):
 def cdiv(x, y):
     return (x + y - 1) // y
 
+def next_power_of_2(n):
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n += 1
+    return n
 
 
 ######
diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index 07b405a92..15efa7c81 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -72,7 +72,6 @@ def softmax_kernel(
     BLOCK_SIZE = meta['BLOCK_SIZE']
     # The stride represents how much we need to increase the pointer to advance 1 row
     row_start_ptr = input_ptr + row_idx * input_row_stride
-
    # The block size is the next power of two greater than n_cols, so we can fit each
     # row in a single block
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -94,23 +93,10 @@ def softmax_kernel(
 
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
-
-def next_power_of_2(n):
-    """Return the smallest power of 2 greater than or equal to n"""
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n += 1
-    return n
-
-
 def softmax(x):
     n_rows, n_cols = x.shape
     # The block size is the smallest power of two greater than the number of columns in `x`
-    BLOCK_SIZE = next_power_of_2(n_cols)
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
     # Another trick we can use is to ask the compiler to use more threads per row by
     # increasing the number of warps (`num_warps`) over which each row is distributed.
     # You will see in the next tutorial how to auto-tune this value in a more natural
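
For reference, below is a minimal standalone sketch of the bit-smearing trick used by the helper being moved into the triton module; it is not part of the diff. The function body is copied from the change above; the assertions and example values are illustrative additions. The idea: smearing the highest set bit of n - 1 into every lower bit yields 2^k - 1, and adding 1 then gives the smallest power of two >= n.

# Standalone sketch of the bit-smearing trick behind triton.next_power_of_2
# (illustration only; the >> 16 step covers 32-bit values, plenty for column counts).
def next_power_of_2(n):
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1            # so that exact powers of two map to themselves
    n |= n >> 1       # copy the highest set bit into the bit below it
    n |= n >> 2       # ... now the top 4 bits below it are set
    n |= n >> 4       # ... top 8 bits
    n |= n >> 8       # ... top 16 bits
    n |= n >> 16      # ... all 32 bits below the highest set bit
    n += 1            # 2^k - 1  ->  2^k
    return n

# e.g. a softmax over 1000 columns gets a 1024-wide block:
assert next_power_of_2(1000) == 1024
assert next_power_of_2(1024) == 1024
assert next_power_of_2(3) == 4

With the helper hoisted into code_gen.py, tutorial 02 calls triton.next_power_of_2(n_cols) instead of defining the function locally, so other kernels can reuse it when picking a block size.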