diff --git a/docs/tutorials/custom-operation.rst b/docs/tutorials/custom-operation.rst index e6abffb12..28c724ede 100644 --- a/docs/tutorials/custom-operation.rst +++ b/docs/tutorials/custom-operation.rst @@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just """ # create callable kernel for the source-code # options: 4 warps and a -DTILE=1024 - kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4]) + kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4]) # Forward pass @staticmethod @@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just zb = add(x, y) diff = (za - zb).abs().max() print(diff) + print(torch.allclose(za,zb)) Executing the above code will: @@ -97,3 +98,5 @@ Executing the above code will: - Call the resulting custom op In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation. + +A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/matrix-multiplication.rst b/docs/tutorials/matrix-multiplication.rst index 8cf5c79e5..fbdd3eabb 100644 --- a/docs/tutorials/matrix-multiplication.rst +++ b/docs/tutorials/matrix-multiplication.rst @@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu Compute Kernel ============== -Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below: +Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below: .. code-block:: C @@ -24,7 +24,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa int rm[TM] = pm * TM + 0 ... TM; //(3) int rn[TN] = pn * TN + 0 ... TN; //(4) int rk[TK] = 0 ... TK; //(5) - // initialize accumulator + // initialize accumulator float c[TM, TN] = 0; //(6) // pointers to operands TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7) @@ -32,10 +32,10 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa // reduction loop for(int k = K; k > 0; k-= TK){ // fetch operands - TYPE a[TM, TK] = *pa; //(9) + TYPE a[TM, TK] = *pa; //(9) TYPE b[TK, TN] = *pb; //(10) // matrix-multiply accumulate - c += dot(a, b); //(11) + c += a @ b; //(11) // increment pointers pa = pa + TK * 1; //(12) pb = pb + TK * ldb; //(13) @@ -85,10 +85,10 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with .. code-block:: C // pre-fetch operands - TYPE a[TM, TK] = *pa; //(9) + TYPE a[TM, TK] = *pa; //(9) TYPE b[TK, TN] = *pb; //(10) for(int k = K; k > 0; k-= TK){ - c += dot(a, b); + c += a @ b; pa = pa + TK * 1; pb = pb + TK * ldb; // don't prefetch last iteration @@ -113,7 +113,7 @@ Rematerialization int rcm[TM] = pm * TM + 0 ... TM; int rcn[TN] = pn * TN + 0 ... TN; TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc; - *pc = c; + *pc = c; ------------------------------------ @@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to TYPE b[SHAPE_B] = (*pb); // reduction loop for(int k = K; k > 0; k-= TK){ - c += dot(USE_A, USE_B); + c += USE_A @ USE_B; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; a = *pa; @@ -163,16 +163,16 @@ All matrix multiplications variants can then be retrieved using the following co .. code-block:: C // A is not transposed - -DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda + -DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda -DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK // A is transposed - -DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1 + -DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1 -DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM // B is not transpose - -DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1 + -DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1 -DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN // B is transpose - -DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb + -DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK @@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros: // Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16} -DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16] + +A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/matrix-transposition.rst b/docs/tutorials/matrix-transposition.rst index 22894e7d8..d51ff6b41 100644 --- a/docs/tutorials/matrix-transposition.rst +++ b/docs/tutorials/matrix-transposition.rst @@ -3,7 +3,7 @@ Matrix Transpositions ********************* -Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM. +Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM. Of course, this can be fixed by using shared memory as shown `here `_, but this comes at the cost of simplicity interferes with auto-tuning. @@ -16,7 +16,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d .. code-block:: C // launched on a grid of (M / TM) x (N / TN) programs of 1 thread each - __global__ void transpose(TYPE * X, TYPE * Y, + __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) { // extract program ID int pidm = get_program_id(0); //(1) @@ -25,12 +25,12 @@ In Triton, however, kernels are single-threaded and the compiler automatically d int rm[TM] = pidm * TM + 0 ... TM; //(3) int rn[TN] = pidn * TN + 0 ... TN; //(4) // create 2D array of pointers - TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5) - TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6) + TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5) + TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6) // write back using the transposition operator '^' *py = ^(*px); //(7) } - + At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows: - Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values: @@ -54,7 +54,7 @@ which will be used in statements (5) and (6) to construct tiles of pointers - Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics: :: - + │ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │ │ ⋮ ⋮ │ │ ⋮ ⋮ │ @@ -73,6 +73,65 @@ which will be used in statements (5) and (6) to construct tiles of pointers - Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`. +================================== +A Note on Numpy-style Broadcasting +================================== + +The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example. + +First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix. + +.. code-block:: python + + import numpy as np + + vec = np.linspace(0,11,12) + mat = vec.reshape((4,3)) + +Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom). + +:: + + [[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]] + +Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values: + +:: + + [ 0, 1, 2], + [ 3, 4, 5], + +We construct ranges `rm` and `rn` as: +:: + + rm = [0, 1] + rn = [0, 1, 2] + +Using numpy-style broadcasting, we can add these together to create a matrix: + +:: + + rm[:, np.newaxis] + rn[np.newaxis, :] + + rn -> [0, 1, 2] + rm -> [0., [[0, 1, 2], + 1.] [1, 2, 3]] + +The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index +of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need +to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension). +In this case this is 3, so what we really need is: + +:: + + ldx = 3 + px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:] + +`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example). + ========================== The __multipleof attribute ========================== @@ -83,19 +142,19 @@ For this reason, Triton provides a __multipleof(N) attributes for variables that .. code-block:: C - __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, - int ldx __multipleof(8), + __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, + int ldx __multipleof(8), int ldy __multipleof(8)) { // ... } - + ========================== Bounds Checking ========================== -You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below: +You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below: .. code-block:: C @@ -108,6 +167,8 @@ You might have noticed that the above code will fail when `M` and `N` are not mu // conditional write-back using the conditional dereferencing operatior '*?()' *?(checky)py = ^(*?(checkx)px); //(7) } - -Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`. \ No newline at end of file + +Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`. + +A runnable version of this kernel is available `here `_. diff --git a/docs/tutorials/putting-it-all-together.rst b/docs/tutorials/putting-it-all-together.rst index 4f760c83c..693a044c9 100644 --- a/docs/tutorials/putting-it-all-together.rst +++ b/docs/tutorials/putting-it-all-together.rst @@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code 'TYPE' : dtype, 'AT' : transpose_a, 'BT' : transpose_b, - 'TM' : [32, 64, 128] - 'TN' : [32, 64, 128] - 'TK' : [8] + 'TM' : [32, 64, 128], + 'TN' : [32, 64, 128], + 'TK' : [8], # handle A transposition 'USE_A' : '^a' if transpose_a else 'a', 'STRIDE_AK' : 'lda' if transpose_a else '1', @@ -177,4 +177,4 @@ A callable operation can be created using the :code:`apply` method of the :code: dot = _dot.apply -And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~ \ No newline at end of file +And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~ diff --git a/python/examples/tutorials/mat_copy.py b/python/examples/tutorials/mat_copy.py new file mode 100644 index 000000000..5eeca842f --- /dev/null +++ b/python/examples/tutorials/mat_copy.py @@ -0,0 +1,68 @@ +import torch +import triton + +class _copy(torch.autograd.Function): + src = """ + __global__ void copy(TYPE * X, TYPE * Y, + int M, int N, int ldx __multipleof(8)) { + // extract program ID + int pidm = get_program_id(0); //(1) + int pidn = get_program_id(1); //(2) + + // create 1D range along the two matrix's axes + int rm[TM] = pidm * TM + 0 ... TM; //(3) + int rn[TN] = pidn * TN + 0 ... TN; //(4) + + // create 2D array of pointers + TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5) + TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6) + + *py = *px; + } + """ + + kernel = None ### initialize later when we know the sizes + + @staticmethod + def forward(ctx, x): + + M, N = x.shape + + ldx = N; + + dtype = x.dtype + + y = torch.empty((M,N)).cuda() + + defines= { + 'TYPE' : dtype, + 'TM' : [32,64,128], + 'TN' : [32,64,128], + } + grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] + + if _copy.kernel is None: + _copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4]) + + _copy.kernel(x, y, M, N, ldx, grid=grid) + + return y + +copy = _copy.apply + +# test +torch.manual_seed(0) +x = torch.randn(8,4).cuda() + +print(x) + +ya = x +yb = copy(x) + +print() +print(ya) +print() +print(yb) +print(torch.allclose(ya, yb)) + +print(ya == yb) diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py new file mode 100644 index 000000000..a8230207f --- /dev/null +++ b/python/examples/tutorials/mat_mul.py @@ -0,0 +1,84 @@ +import torch +import triton + +class _dot(torch.autograd.Function): + src = """ + __global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K, + int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) { + int pm = get_program_id(0); + int pn = get_program_id(1); + + // ranges + int rm[TM] = pm * TM + 0 ... TM; + int rn[TN] = pn * TN + 0 ... TN; + int rk[TK] = 0 ... TK; + + // accumulator + float c[TM, TN] = 0; + + //pointers + TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; + TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; + + for(int k=K; k>0; k-=TK) { + TYPE a[TM, TK] = *pa; + TYPE b[TK, TN] = *pb; + + c += a @ b; + + pa = pa + TK * 1; + pb = pb + TK * ldb; + } + + TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc; + *pc = c; + } + """ + + @staticmethod + def forward(ctx, a, b): + c = _dot._call(a,b) + return c + + + @staticmethod + def _call(a, b): + M, K = a.shape + K, N = b.shape + + lda = K + ldb = N + ldc = N + + dtype = a.dtype + + c = triton.empty([M,N], dtype=dtype) + + grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] + + defines= { + 'TYPE' : dtype, + 'TM' : [32,64,128], + 'TN' : [32,64,128], + 'TK' : [8], + } + + _dot.kernel = triton.kernel(_dot.src, defines=defines) + _dot.kernel(a, b, c, M, N, K, lda, ldb, ldc, + grid=grid, num_warps=4, defines=defines) + return c + + +dot = _dot.apply + +torch.manual_seed(0) + +M, N, K = 128, 512, 256 +a = torch.rand((M, K)).cuda() +b = torch.rand((K, N)).cuda() + + +zc = torch.matmul(a,b) +zc_ = dot(a,b) + +print(torch.allclose(zc, zc_)) diff --git a/python/examples/tutorials/mat_transpose.py b/python/examples/tutorials/mat_transpose.py new file mode 100644 index 000000000..39f05c902 --- /dev/null +++ b/python/examples/tutorials/mat_transpose.py @@ -0,0 +1,74 @@ +import torch +import triton + +class _transpose(torch.autograd.Function): + src = """ + __global__ void transpose(TYPE * X, TYPE * Y, + int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) { + // extract program ID + int pidm = get_program_id(0); //(1) + int pidn = get_program_id(1); //(2) + + // create 1D range along the two matrix's axes + int rm[TM] = pidm * TM + 0 ... TM; //(3) + int rn[TN] = pidn * TN + 0 ... TN; //(4) + + // create 2D array of pointers + TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5) + TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6) + + // create bounds-checking mask + bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a) + bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b) + + // conditional write-back using the conditional dereferencing operatior '*?()' + *?(checky)py = ^(*?(checkx)px); //(7) + } + """ + + kernel = None ### initialize later when we know the sizes + + @staticmethod + def forward(ctx, x): + + M, N = x.shape + + ldx = N + ldy = M + + dtype = x.dtype + + y = torch.empty((N,M)).cuda() + + defines= { + 'TYPE' : dtype, + 'TM' : [32,64,128], + 'TN' : [32,64,128], + } + + grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))] + + if _transpose.kernel is None: + _transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4]) + + _transpose.kernel(x, y, M, N, ldx, ldy, grid=grid) + + return y + +transpose = _transpose.apply + +# test +torch.manual_seed(0) +x = torch.randn(1024,128).cuda() + +print(x) + +ya = torch.t(x) +yb = transpose(x) +print() +print(ya) +print() +print(yb) +print(torch.allclose(ya, yb)) + +print(ya == yb) diff --git a/python/examples/tutorials/vec_add.py b/python/examples/tutorials/vec_add.py new file mode 100644 index 000000000..cbf63ff08 --- /dev/null +++ b/python/examples/tutorials/vec_add.py @@ -0,0 +1,43 @@ +import torch +import triton + +class _add(torch.autograd.Function): + src = """ +__global__ void add(float* z, float* x, float* y, int N) { + + int pid = get_program_id(0); + + int offset[TILE] = pid * TILE + 0 ... TILE; + float* pz[TILE] = z + offset; + float* px[TILE] = x + offset; + float* py[TILE] = y + offset; + + bool check[TILE] = offset < N; + + *?(check)pz = *?(check)px + *?(check)py; +} + """ + + kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4]) + + @staticmethod + def forward(ctx, x, y): + z = torch.empty_like(x).cuda() + + N = x.numel() + grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),) + + _add.kernel(z,x,y, N, grid=grid) + + return z + +add = _add.apply + +# test +torch.manual_seed(0) +x = torch.rand(98432).cuda() +y = torch.rand(98432).cuda() +za = x + y +zb = add(x, y) + +print(torch.allclose(za,zb))