[DOCS] Matmul and vecadd working examples

This commit is contained in:
jack-willturner
2020-05-04 16:25:17 +01:00
committed by Philippe Tillet
parent ce4a4728f5
commit 32819dea51
6 changed files with 159 additions and 27 deletions

View File

@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
"""
# create callable kernel for the source-code
# options: 4 warps and a -DTILE=1024
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
# Forward pass
@staticmethod
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
zb = add(x, y)
diff = (za - zb).abs().max()
print(diff)
print(torch.allclose(za,zb))
Executing the above code will:
@@ -97,3 +98,5 @@ Executing the above code will:
- Call the resulting custom op
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.

View File

@@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu
Compute Kernel
==============
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
.. code-block:: C
@@ -24,7 +24,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
int rm[TM] = pm * TM + 0 ... TM; //(3)
int rn[TN] = pn * TN + 0 ... TN; //(4)
int rk[TK] = 0 ... TK; //(5)
// initialize accumulator
// initialize accumulator
float c[TM, TN] = 0; //(6)
// pointers to operands
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
@@ -32,10 +32,10 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
// reduction loop
for(int k = K; k > 0; k-= TK){
// fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += dot(a, b); //(11)
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
@@ -85,10 +85,10 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
.. code-block:: C
// pre-fetch operands
TYPE a[TM, TK] = *pa; //(9)
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += dot(a, b);
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
@@ -113,7 +113,7 @@ Rematerialization
int rcm[TM] = pm * TM + 0 ... TM;
int rcn[TN] = pn * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
*pc = c;
*pc = c;
------------------------------------
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += dot(USE_A, USE_B);
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
@@ -163,16 +163,16 @@ All matrix multiplications variants can then be retrieved using the following co
.. code-block:: C
// A is not transposed
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
// A is transposed
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
// B is not transpose
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
// B is transpose
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.

View File

@@ -3,7 +3,7 @@ Matrix Transpositions
*********************
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity interferes with auto-tuning.
@@ -16,7 +16,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
.. code-block:: C
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
__global__ void transpose(TYPE * X, TYPE * Y,
__global__ void transpose(TYPE * X, TYPE * Y,
int M, int N, int ldx, int ldy) {
// extract program ID
int pidm = get_program_id(0); //(1)
@@ -30,7 +30,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
@@ -54,7 +54,7 @@ which will be used in statements (5) and (6) to construct tiles of pointers
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
::
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
│ ⋮ ⋮ │
│ ⋮ ⋮ │
@@ -83,19 +83,19 @@ For this reason, Triton provides a __multipleof(N) attributes for variables that
.. code-block:: C
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
int ldx __multipleof(8),
int ldy __multipleof(8)) {
// ...
}
==========================
Bounds Checking
==========================
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
.. code-block:: C
@@ -108,6 +108,6 @@ You might have noticed that the above code will fail when `M` and `N` are not mu
// conditional write-back using the conditional dereferencing operatior '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.

View File

@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
'TYPE' : dtype,
'AT' : transpose_a,
'BT' : transpose_b,
'TM' : [32, 64, 128]
'TN' : [32, 64, 128]
'TK' : [8]
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [8],
# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
@@ -177,4 +177,4 @@ A callable operation can be created using the :code:`apply` method of the :code:
dot = _dot.apply
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~