From 70e28ff38021648d24ba6a1382f15ac544624a3b Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Wed, 11 Aug 2021 18:59:15 -0700
Subject: [PATCH] [DOCS] Minor modifications of the matmul tutorial (#199)

Making the code more compact and fixing inconsistencies between text
variable names and the final python program.
---
 python/triton/language.py                    |  11 ++
 python/tutorials/03-matrix-multiplication.py | 184 +++++++++----------
 2 files changed, 99 insertions(+), 96 deletions(-)

diff --git a/python/triton/language.py b/python/triton/language.py
index bc14b4235..6a1a44d60 100644
--- a/python/triton/language.py
+++ b/python/triton/language.py
@@ -649,6 +649,17 @@ def max_contiguous(input, value, builder=None):
 # Standard library
 # -----------------------
 
+@triton.jit
+def cdiv(x, div):
+    """
+    Computes the ceiling division of :code:`x` by :code:`div`
+
+    :param x: the input number
+    :type x: Block
+    :param div: the divisor
+    :type div: Block
+    """
+    return (x + div - 1) // div
 
 @triton.jit
 def minimum(x, y):
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index e71fae2d6..80207d8cf 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -23,7 +23,7 @@ You will specifically learn about:
 # yourself with Triton, in a way that is easy to customize and extend.
 #
 # Roughly speaking, the kernel that we will write will implement the following blocked
-# algorithm to multiply a (MxK) by a (KxN) matrix:
+# algorithm to multiply a (M, K) by a (K, N) matrix:
 #
 # .. code-block:: python
 #
@@ -38,7 +38,7 @@ You will specifically learn about:
 #         acc += dot(a, b)
 #       C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
 #
-# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
+# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
 
 # %%
 # Compute Kernel
@@ -53,35 +53,31 @@ You will specifically learn about:
 # ~~~~~~~~~~~~~~~~~~~~
 #
 # For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
-# y :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
+# y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
 # Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
 # :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
 #
 # .. code-block:: python
 #
-#    &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
-#    &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
+#    &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
+#    &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
 #
 # Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
 #
 # .. code-block:: python
 #
-#    pid_m = triton.program_id(0)
-#    pid_n = triton.program_id(1)
-#    rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
-#    rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
-#    rk = triton.arange(0, BLOCK_SIZE_K)
-#    // pointer for A operand
-#    pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
-#    // pointer for B operand
-#    pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
+#    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+#    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+#    offs_k = tl.arange(0, BLOCK_SIZE_K)
+#    a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
+#    b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
 #
 # And then updated in the inner loop as follows:
 #
 # .. code-block:: python
 #
-#    pa += BLOCK_SIZE_K * stride_a_1;
-#    pb += BLOCK_SIZE_K * stride_b_0;
+#    a_ptrs += BLOCK_SIZE_K * stride_ak;
+#    b_ptrs += BLOCK_SIZE_K * stride_bk;
 #
 #
 # L2 Cache Optimizations
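
The broadcast pointer arithmetic above can be sanity-checked outside of Triton: the same
index math, applied to a flattened row-major NumPy array, gathers exactly the intended tile.
This is only a toy sketch; the sizes, the block coordinates, and the `offsets`/`tile` names
are made up for illustration.

    import numpy as np

    # Hypothetical sizes, chosen only to illustrate the broadcast used for `a_ptrs` above.
    M, K, BLOCK_SIZE_M, BLOCK_SIZE_K = 8, 12, 4, 3
    A = np.arange(M * K, dtype=np.float32).reshape(M, K)    # row-major: strides are (K, 1) in elements
    stride_am, stride_ak = K, 1
    pid_m, k = 1, 3                                         # block-row 1, second step along K
    offs_am = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)
    offs_k = k + np.arange(BLOCK_SIZE_K)
    offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    tile = A.ravel()[offsets]                               # gathers the BLOCK_SIZE_M x BLOCK_SIZE_K tile
    assert (tile == A[4:8, 3:6]).all()
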
@@ -109,13 +105,25 @@
 #
 # .. code-block:: python
 #
-#    pid = triton.program_id(0);
-#    width = GROUP_M * grid_n;
-#    group_id = pid // width;
-#    # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
-#    group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
-#    pid_m = group_id * GROUP_M + (pid % group_size);
-#    pid_n = (pid % width) // (group_size);
+#    # program ID
+#    pid = tl.program_id(axis=0)
+#    # number of program ids along the M axis
+#    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+#    # number of program ids along the N axis
+#    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+#    # number of programs in group
+#    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+#    # id of the group this program is in
+#    group_id = pid // num_pid_in_group
+#    # row-id of the first program in the group
+#    first_pid_m = group_id * GROUP_SIZE_M
+#    # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
+#    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+#    # *within groups*, programs are ordered in a column-major order
+#    # row-id of the program in the *launch grid*
+#    pid_m = first_pid_m + (pid % group_size_m)
+#    # col-id of the program in the *launch grid*
+#    pid_n = (pid % num_pid_in_group) // group_size_m
 #
 # For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
 # we can see that if we compute the output in row-major ordering, we need to load 90
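
The grouped ordering above is easiest to see by running the same index arithmetic in plain
Python. The 9-by-9 grid of blocks mirrors the example mentioned just above; `GROUP_SIZE_M = 3`
and the helper name `grouped_pid` are arbitrary choices made for this sketch.

    # Toy model of the grouped mapping from a linear `pid` to (pid_m, pid_n).
    def grouped_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        return pid_m, pid_n

    launch_order = [grouped_pid(pid, 9, 9, 3) for pid in range(81)]
    # The first 9 programs cover a 3x3 square of C blocks instead of a full row, so the
    # rows of A and columns of B that they read can stay resident in the L2 cache:
    print(launch_order[:9])  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 2), (1, 2), (2, 2)]
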
@@ -164,26 +172,19 @@ import triton.language as tl
 @triton.jit
 def matmul_kernel(
     # Pointers to matrices
-    a_ptr,
-    b_ptr,
-    c_ptr,
+    a_ptr, b_ptr, c_ptr,
     # Matrix dimensions
-    M,
-    N,
-    K,
+    M, N, K,
     # The stride variables represent how much to increase the ptr by when moving by 1
     # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
     # by to get the element one row down (A has M rows)
-    stride_am,
-    stride_ak,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    # Meta-parameters
     **meta,
 ):
-    """Kernel for computing the matmul AB = C
-
+    """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """
     # extract meta-parameters
@@ -191,67 +192,65 @@ def matmul_kernel(
     BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
     BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
     GROUP_SIZE_M = 8
+
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse
+    # See above `L2 Cache Optimizations` section for details
     pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
 
-    # the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block
-    # Note that this will lead to some quantization in performance where time-taken jumps
-    # when you need to add a new block
-    n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
-    n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # see above `Pointer Arithmetics` section for details
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
+    b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
 
-    # Map PIDs to the block they should compute. This is done in a grouped ordering
-    # to promote L2 cache reuse.
-    n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
-    group_id = pid // n_output_blocks_in_group
-    first_m_block_in_group = group_id * GROUP_SIZE_M
-
-    # If the number of blocks is not divisible by the group size, the last group is smaller
-    group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
-
-    # Within a group, we compute in col-major ordering, block_m and block_n are the
-    # output row and col that this program is computing in terms of blocks
-    block_m = first_m_block_in_group + (pid % group_size_m)
-    block_n = (pid % n_output_blocks_in_group) // group_size_m
-
-    # Convert from block indices back to element indices
-    m_start = block_m * BLOCK_SIZE_M
-    n_start = block_n * BLOCK_SIZE_N
-
-    # Expand out to all the offsets for each of the elements in this block.
-    m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
-    n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
-    k_offsets = tl.arange(0, BLOCK_SIZE_K)
-
-    # Get the pointers for the first block of each. We will advance this pointer
-    # as we move in the K direction and accumulate.
-    # a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
-    a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
-    # b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
-    b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
-    # We accumulate internally in fp32, but the output is written out in the dtype
-    # of the tensor when it is stored
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, K, BLOCK_SIZE_K):
-        # Note that for simplicity, we don't apply a mask here. This means that if K is
-        # not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
-        # accumulate it incorrectly.
+        # Note that for simplicity, we don't apply a mask here.
+        # This means that if K is not a multiple of BLOCK_SIZE_K,
+        # this will access out-of-bounds memory and produce an
+        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
-        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
 
-    # triton can accept arbitrary activation function via metaparameters!
-    if meta['ACTIVATION']:
+    # you can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32 !
+    if meta['ACTIVATION']:
        accumulator = meta['ACTIVATION'](accumulator)
+    c = accumulator.to(tl.float16)
 
-    m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
-    n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
-    c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
-    mask = (m_offsets_c < M) & (n_offsets_c < N)
-    tl.store(c_ptrs, accumulator, mask=mask)
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
 
 
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
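
The structure of the kernel above can be mirrored in plain NumPy, which makes the numerics
explicit: one float32 accumulator per output block, a cast to float16 only at write-back, and,
like the kernel, the assumption that K is a multiple of BLOCK_SIZE_K. The function name, block
sizes and test shapes below are arbitrary values chosen for the sketch.

    import numpy as np

    def matmul_blocked_reference(a, b, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32):
        M, K = a.shape
        _, N = b.shape
        assert K % BLOCK_SIZE_K == 0, "like the kernel, this assumes K % BLOCK_SIZE_K == 0"
        c = np.empty((M, N), dtype=np.float16)
        for m in range(0, M, BLOCK_SIZE_M):
            for n in range(0, N, BLOCK_SIZE_N):
                # float32 accumulator for one [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C
                acc = np.zeros((min(BLOCK_SIZE_M, M - m), min(BLOCK_SIZE_N, N - n)), np.float32)
                for k in range(0, K, BLOCK_SIZE_K):
                    acc += a[m:m + BLOCK_SIZE_M, k:k + BLOCK_SIZE_K].astype(np.float32) @ \
                           b[k:k + BLOCK_SIZE_K, n:n + BLOCK_SIZE_N].astype(np.float32)
                c[m:m + BLOCK_SIZE_M, n:n + BLOCK_SIZE_N] = acc  # implicit cast to float16
        return c

    a = np.random.randn(100, 64).astype(np.float16)   # M and N deliberately not multiples of 32
    b = np.random.randn(64, 90).astype(np.float16)
    ref = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
    assert np.allclose(matmul_blocked_reference(a, b), ref, atol=1e-2, rtol=1e-2)
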
@@ -282,18 +281,11 @@ def matmul(a, b, activation=None):
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
     )
     matmul_kernel[grid](
-        a,
-        b,
-        c,
-        M,
-        N,
-        K,
-        a.stride(0),
-        a.stride(1),
-        b.stride(0),
-        b.stride(1),
-        c.stride(0),
-        c.stride(1),
+        a, b, c,
+        M, N, K,
+        a.stride(0), a.stride(1),
+        b.stride(0), b.stride(1),
+        c.stride(0), c.stride(1),
         ACTIVATION=activation,
     )
     return c
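
With the call site compacted as above, a quick smoke test of the wrapper is to compare it
against torch.matmul on random float16 inputs. This is a minimal sketch: it assumes a
CUDA-capable GPU, the `matmul` wrapper defined in the tutorial, and deliberately loose
tolerances to absorb float16 rounding and differing accumulation orders; the shapes are
arbitrary.

    import torch

    torch.manual_seed(0)
    a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    c_triton = matmul(a, b)        # the tutorial's wrapper defined above
    c_torch = torch.matmul(a, b)
    print(f"max abs error: {(c_triton - c_torch).abs().max().item():.4f}")
    assert torch.allclose(c_triton, c_torch, atol=1e-1, rtol=1e-2)
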