[DOCS] Matmul and vecadd working examples
This commit is contained in:
committed by
Philippe Tillet
parent
ce4a4728f5
commit
32819dea51
@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
|
||||
"""
|
||||
# create callable kernel for the source-code
|
||||
# options: 4 warps and a -DTILE=1024
|
||||
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
|
||||
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
|
||||
|
||||
# Forward pass
|
||||
@staticmethod
|
||||
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
|
||||
zb = add(x, y)
|
||||
diff = (za - zb).abs().max()
|
||||
print(diff)
|
||||
print(torch.allclose(za,zb))
|
||||
|
||||
Executing the above code will:
|
||||
|
||||
@@ -97,3 +98,5 @@ Executing the above code will:
|
||||
- Call the resulting custom op
|
||||
|
||||
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
|
||||
|
||||
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
|
||||
|
@@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu
|
||||
Compute Kernel
|
||||
==============
|
||||
|
||||
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
|
||||
Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
@@ -24,7 +24,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
|
||||
int rm[TM] = pm * TM + 0 ... TM; //(3)
|
||||
int rn[TN] = pn * TN + 0 ... TN; //(4)
|
||||
int rk[TK] = 0 ... TK; //(5)
|
||||
// initialize accumulator
|
||||
// initialize accumulator
|
||||
float c[TM, TN] = 0; //(6)
|
||||
// pointers to operands
|
||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
|
||||
@@ -32,10 +32,10 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
// fetch operands
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE b[TK, TN] = *pb; //(10)
|
||||
// matrix-multiply accumulate
|
||||
c += dot(a, b); //(11)
|
||||
c += a @ b; //(11)
|
||||
// increment pointers
|
||||
pa = pa + TK * 1; //(12)
|
||||
pb = pb + TK * ldb; //(13)
|
||||
@@ -85,10 +85,10 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
|
||||
.. code-block:: C
|
||||
|
||||
// pre-fetch operands
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE a[TM, TK] = *pa; //(9)
|
||||
TYPE b[TK, TN] = *pb; //(10)
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += dot(a, b);
|
||||
c += a @ b;
|
||||
pa = pa + TK * 1;
|
||||
pb = pb + TK * ldb;
|
||||
// don't prefetch last iteration
|
||||
@@ -113,7 +113,7 @@ Rematerialization
|
||||
int rcm[TM] = pm * TM + 0 ... TM;
|
||||
int rcn[TN] = pn * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
|
||||
*pc = c;
|
||||
*pc = c;
|
||||
|
||||
|
||||
------------------------------------
|
||||
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
|
||||
TYPE b[SHAPE_B] = (*pb);
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += dot(USE_A, USE_B);
|
||||
c += USE_A @ USE_B;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
a = *pa;
|
||||
@@ -163,16 +163,16 @@ All matrix multiplications variants can then be retrieved using the following co
|
||||
.. code-block:: C
|
||||
|
||||
// A is not transposed
|
||||
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
|
||||
-DUSE_A=a -DSTRIDE_AK=1-DSTRIDE_AM=lda
|
||||
-DBROADCAST_AK=newaxis,: -DBROADCAST_AN=:,newaxis -DSHAPE_A=TM,TK
|
||||
// A is transposed
|
||||
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
|
||||
-DUSE_A=^a -DSTRIDE_AK=lda-DSTRIDE_AM=1
|
||||
-DBROADCAST_AK=:,newaxis -DBROADCAST_AN=newaxis,: -DSHAPE_A=TK,TM
|
||||
// B is not transpose
|
||||
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
|
||||
-DUSE_B=b -DSTRIDE_BK=ldb-DSTRIDE_BN=1
|
||||
-DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
|
||||
// B is transpose
|
||||
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
|
||||
-DUSE_B=^b -DSTRIDE_BK=1-DSTRIDE_BN=ldb
|
||||
-DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK
|
||||
|
||||
|
||||
@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:
|
||||
|
||||
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
|
||||
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
|
||||
|
||||
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.
|
||||
|
@@ -3,7 +3,7 @@ Matrix Transpositions
|
||||
*********************
|
||||
|
||||
|
||||
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
|
||||
Transpositions are (relatively) hard to efficiently write in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.
|
||||
|
||||
Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity interferes with auto-tuning.
|
||||
|
||||
@@ -16,7 +16,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
|
||||
.. code-block:: C
|
||||
|
||||
// launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
|
||||
__global__ void transpose(TYPE * X, TYPE * Y,
|
||||
__global__ void transpose(TYPE * X, TYPE * Y,
|
||||
int M, int N, int ldx, int ldy) {
|
||||
// extract program ID
|
||||
int pidm = get_program_id(0); //(1)
|
||||
@@ -30,7 +30,7 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
|
||||
// write back using the transposition operator '^'
|
||||
*py = ^(*px); //(7)
|
||||
}
|
||||
|
||||
|
||||
At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:
|
||||
|
||||
- Statements (1) and (2) extract the coordinates the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TN:3TN-1]` holds the values:
|
||||
@@ -54,7 +54,7 @@ which will be used in statements (5) and (6) to construct tiles of pointers
|
||||
- Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:
|
||||
|
||||
::
|
||||
|
||||
|
||||
│ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │
|
||||
│ ⋮ ⋮ │
|
||||
│ ⋮ ⋮ │
|
||||
@@ -83,19 +83,19 @@ For this reason, Triton provides a __multipleof(N) attributes for variables that
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
|
||||
int ldx __multipleof(8),
|
||||
__global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
|
||||
int ldx __multipleof(8),
|
||||
int ldy __multipleof(8)) {
|
||||
// ...
|
||||
}
|
||||
|
||||
|
||||
|
||||
==========================
|
||||
Bounds Checking
|
||||
==========================
|
||||
|
||||
|
||||
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
|
||||
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
|
||||
|
||||
.. code-block:: C
|
||||
|
||||
@@ -108,6 +108,6 @@ You might have noticed that the above code will fail when `M` and `N` are not mu
|
||||
// conditional write-back using the conditional dereferencing operatior '*?()'
|
||||
*?(checky)py = ^(*?(checkx)px); //(7)
|
||||
}
|
||||
|
||||
|
||||
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
|
||||
|
||||
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
|
||||
|
@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
|
||||
'TYPE' : dtype,
|
||||
'AT' : transpose_a,
|
||||
'BT' : transpose_b,
|
||||
'TM' : [32, 64, 128]
|
||||
'TN' : [32, 64, 128]
|
||||
'TK' : [8]
|
||||
'TM' : [32, 64, 128],
|
||||
'TN' : [32, 64, 128],
|
||||
'TK' : [8],
|
||||
# handle A transposition
|
||||
'USE_A' : '^a' if transpose_a else 'a',
|
||||
'STRIDE_AK' : 'lda' if transpose_a else '1',
|
||||
@@ -177,4 +177,4 @@ A callable operation can be created using the :code:`apply` method of the :code:
|
||||
dot = _dot.apply
|
||||
|
||||
|
||||
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
|
||||
And that's it! In just ~100 lines of pure python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source~
|
||||
|
84
python/examples/tutorials/mat_mul.py
Normal file
84
python/examples/tutorials/mat_mul.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
int pm = get_program_id(0);
|
||||
int pn = get_program_id(1);
|
||||
|
||||
// ranges
|
||||
int rm[TM] = pm * TM + 0 ... TM;
|
||||
int rn[TN] = pn * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
|
||||
// accumulator
|
||||
float c[TM, TN] = 0;
|
||||
|
||||
//pointers
|
||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
||||
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
||||
|
||||
for(int k=K; k>0; k-=TK) {
|
||||
TYPE a[TM, TK] = *pa;
|
||||
TYPE b[TK, TN] = *pb;
|
||||
|
||||
c += a @ b;
|
||||
|
||||
pa = pa + TK * 1;
|
||||
pb = pb + TK * ldb;
|
||||
}
|
||||
|
||||
TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
c = _dot._call(a,b)
|
||||
return c
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b):
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
|
||||
lda = K
|
||||
ldb = N
|
||||
ldc = N
|
||||
|
||||
dtype = a.dtype
|
||||
|
||||
c = triton.empty([M,N], dtype=dtype)
|
||||
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
|
||||
defines= {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [32,64,128],
|
||||
'TN' : [32,64,128],
|
||||
'TK' : [8],
|
||||
}
|
||||
|
||||
_dot.kernel = triton.kernel(_dot.src, defines=defines)
|
||||
_dot.kernel(a, b, c, M, N, K, lda, ldb, ldc,
|
||||
grid=grid, num_warps=4, defines=defines)
|
||||
return c
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
M, N, K = 128, 512, 256
|
||||
a = torch.rand((M, K)).cuda()
|
||||
b = torch.rand((K, N)).cuda()
|
||||
|
||||
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
|
||||
print(torch.allclose(zc, zc_))
|
43
python/examples/tutorials/vec_add.py
Normal file
43
python/examples/tutorials/vec_add.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _add(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N) {
|
||||
|
||||
int pid = get_program_id(0);
|
||||
|
||||
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
|
||||
bool check[TILE] = offset < N;
|
||||
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
z = torch.empty_like(x).cuda()
|
||||
|
||||
N = x.numel()
|
||||
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
|
||||
|
||||
_add.kernel(z,x,y, N, grid=grid)
|
||||
|
||||
return z
|
||||
|
||||
add = _add.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432).cuda()
|
||||
y = torch.rand(98432).cuda()
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
|
||||
print(torch.allclose(za,zb))
|
Reference in New Issue
Block a user